package info.kwarc.mmt.frameit.rules

import info.kwarc.mmt.api.checking._
import info.kwarc.mmt.api.{DPath, GlobalName, LocalName, ParametricRule, Rule, RuleSet}
import info.kwarc.mmt.api.frontend.Controller
import info.kwarc.mmt.api.objects._
import Conversions._
import info.kwarc.mmt.api.uom.{Simplifiability, Simplify}
import info.kwarc.mmt.lf.{Apply, ApplySpine, Arrow, OfType, Pi}
import info.kwarc.mmt.LFX.datatypes.{Listtype, Ls}
import info.kwarc.mmt.api.utils.URI

import scala.collection.mutable

/**
  * Constructor/extractor pair for applications of the constant `stepUntilX`
  * (theory `StepUntilX` under namespace `mathhub.info/FrameIT/frameworld`) to exactly eight arguments.
  *
  * Argument order: `tpA`, `tpB`, `tpC`, `vB`, `vC`, `vA`, `f`, `tC`.
  * [[StepUntilXRule]] uses `vA` as the start value and steps it via `f vB vC cur`, testing each value with `tC cur`.
  * NOTE(review): `tpA`/`tpB`/`tpC` are presumably the types of `vA`/`vB`/`vC` — confirm against the MMT theory.
  */
object StepUntilX {
  /** Document path under which theory `StepUntilX` is declared. */
  val ns: DPath = DPath(URI.http.colon("mathhub.info")) / "FrameIT" / "frameworld"

  /** Global name of the constant `stepUntilX` in theory `StepUntilX`. */
  val path: GlobalName = (ns ? "StepUntilX") ? "stepUntilX"

  /** The constant as a term; the head of every application built or matched here. */
  val rawhead: OMID = OMS(path)

  /** Builds the spine application `stepUntilX tpA tpB tpC vB vC vA f tC`. */
  def apply(tpA: Term, tpB: Term, tpC: Term, vB: Term, vC: Term, vA: Term, f: Term, tC: Term): Term = {
    val arguments = List(tpA, tpB, tpC, vB, vC, vA, f, tC)
    ApplySpine(rawhead, arguments: _*)
  }

  /** Matches a spine application of `rawhead` to exactly eight arguments and returns them in order. */
  def unapply(tm: Term): Option[(Term, Term, Term, Term, Term, Term, Term, Term)] = tm match {
    // Backticks make `rawhead` a stable-identifier pattern (equality test), not a fresh binding.
    case ApplySpine(`rawhead`, List(tpA, tpB, tpC, vB, vC, vA, f, tC)) =>
      Some((tpA, tpB, tpC, vB, vC, vA, f, tC))
    case _ =>
      None
  }
}

/**
  * Parametric rule that installs [[StepUntilXRule.StepUntilXComp]], the computation rule for [[StepUntilX]].
  *
  * It takes exactly one parameter, `truetm`. This is the term that `tC cur` is checked to be equal to
  * in order to decide when stepping stops.
  */
object StepUntilXRule extends ParametricRule {
  /**
    * Upper bound on the number of values tested by [[StepUntilXComp]]. It guards against step functions
    * that never reach the stop condition. When the bound is hit, the values collected so far are returned
    * and a note is added to the checking history.
    */
  private val DEBUG_MAX_ITER_COUNT = 100

  /**
    * @param controller unused
    * @param home       unused
    * @param args       must consist of exactly one term, `truetm`
    * @return a rule set containing `StepUntilXComp(truetm)`
    * @throws IllegalArgumentException if `args` does not have exactly one element
    */
  override def apply(controller: Controller, home: Term, args: List[Term]): Rule = args match {
    case List(truetm) =>
      RuleSet(StepUntilXComp(truetm))
    case _ =>
      throw new IllegalArgumentException(
        s"StepUntilXRule expects exactly one argument (truetm), but got ${args.length}"
      )
  }

  /**
    * Computation rule that simplifies `StepUntilX(tpA, tpB, tpC, vB, vC, vA, f, tC)` to the list
    * `[x0, x1, ..., x(k-1)]`, where:
    *  - `x0` is the simplified `vA`;
    *  - `x(i+1)` is the simplified `f vB' vC' xi`, with `vB'`, `vC'` being the simplified `vB`, `vC`;
    *  - `xk` is the first value for which `tC xk` is checked equal to `truetm`. It is not included.
    *
    * At most `DEBUG_MAX_ITER_COUNT` values are tested. If none satisfies the check, the values tested so far
    * are returned (a truncated result) and a note is appended to `history`.
    *
    * @param truetm the term that `tC xi` is compared against
    */
  case class StepUntilXComp(truetm: Term) extends ComputationRule(StepUntilX.path) {
    override def applicable(t: Term): Boolean = t match {
      case StepUntilX(_, _, _, _, _, _, _, _) => true
      case _ => false
    }

    override def apply(check: CheckingCallback)(tm: Term, covered: Boolean)(implicit stack: Stack, history: History): Simplifiability = tm match {
      case StepUntilX(_, _, _, vB, vC, vA, f, tC) =>
        // vB and vC are passed unchanged to every step, so they are simplified once, up front.
        val valB = check.simplify(vB)
        val valC = check.simplify(vC)

        /* Tests `cur` and then either stops or steps to the next value.
         * `acc` holds the already-tested values in reverse order.
         * Returns them (still reversed) together with whether the stop condition was met. */
        @scala.annotation.tailrec
        def loop(cur: Term, acc: List[Term], remaining: Int): (List[Term], Boolean) =
          if (remaining <= 0) (acc, false)
          else check.tryToCheckWithoutDelay(Equality(stack, Apply(tC, cur), truetm, None)) match {
            case Some(true) =>
              (acc, true)
            // NOTE(review): a check that cannot be decided without delay (None) is treated like
            // Some(false), i.e. stepping continues — confirm this is intended.
            case Some(false) | None =>
              loop(check.simplify(ApplySpine(f, valB, valC, cur)), cur :: acc, remaining - 1)
          }

        val (reversedValues, terminated) = loop(check.simplify(vA), Nil, DEBUG_MAX_ITER_COUNT)
        if (!terminated) {
          history += s"StepUntilX: stop condition not reached after $DEBUG_MAX_ITER_COUNT steps; returning the ${reversedValues.length} values computed so far"
        }
        Simplify(Ls(reversedValues.reverse: _*))
      case _ => Simplifiability.NoRecurse
    }
  }
}