package info.kwarc.mmt.frameit.rules

import info.kwarc.mmt.api.checking._
import info.kwarc.mmt.api.{DPath, GlobalName, LocalName, ParametricRule, Rule, RuleSet}
import info.kwarc.mmt.api.frontend.Controller
import info.kwarc.mmt.api.objects._
import Conversions._
import info.kwarc.mmt.api.uom.{Simplifiability, Simplify}
import info.kwarc.mmt.lf.{Apply, Arrow, OfType, Pi}
import info.kwarc.mmt.LFX.datatypes.{Listtype, Ls}
import info.kwarc.mmt.api.utils.URI

import scala.collection.mutable

/**
 * Constructor and extractor for `stepUntil` applications, i.e. terms of the shape
 * `OMA(stepUntil, List(tp, iv, f, tC))`, where `stepUntil` is the symbol `stepUntil` in module
 * `StepUntilRaw` of the FrameIT `frameworld` namespace.
 *
 * In the rules of this file, `iv` is used as the initial value, `f` as the step function and
 * `tC` as the stopping condition; `tp` is carried along but not inspected by those rules.
 */
object StepUntil {
  val ns: DPath = DPath(URI.http colon "mathhub.info") / "FrameIT" / "frameworld"
  val path: GlobalName = ns ? "StepUntilRaw" ? "stepUntil"
  val term: OMID = OMS(path)

  /** Builds the application `stepUntil(tp, iv, f, tC)`. */
  def apply(tp: Term, iv: Term, f: Term, tC: Term): Term = {
    val arguments = tp :: iv :: f :: tC :: Nil
    OMA(term, arguments)
  }

  /** Matches applications of [[term]] to exactly four arguments, yielding `(tp, iv, f, tC)`. */
  def unapply(tm: Term): Option[(Term, Term, Term, Term)] = tm match {
    case OMA(head, tp :: iv :: f :: tC :: Nil) if term == head => Some((tp, iv, f, tC))
    case _                                                     => None
  }
}

/**
 * Parametric rule providing computation and type inference for [[StepUntil]] terms.
 *
 * It takes exactly one argument, `truetm`: the term that the stopping condition `tC(v)` must be
 * checked equal to for stepping to stop.
 */
object StepUntilRule extends ParametricRule {
  /** Upper bound on the number of steps the computation rule performs before giving up. */
  private val DEBUG_MAX_ITER_COUNT = 100

  /**
   * Instantiates the rules for a given `truetm`.
   *
   * @param args must be the one-element list `List(truetm)`
   * @return a [[RuleSet]] containing the computation rule and the inference/typing rule
   * @throws IllegalArgumentException if `args` does not consist of exactly one term
   */
  override def apply(controller: Controller, home: Term, args: List[Term]): Rule = args match {
    case List(truetm) =>
      RuleSet(StepUntilComp(truetm), StepUntilInfTypingRule(truetm))
    case _ =>
      throw new IllegalArgumentException(
        s"StepUntilRule expects exactly one argument (the term representing 'true'), but got ${args.length}"
      )
  }

  /**
   * Simplifies `stepUntil(tp, iv, f, tC)` to the list `[iv, f(iv), f(f(iv)), ...]`, stopping
   * before (and excluding) the first value `v` for which `tC(v)` is checked equal to `truetm`.
   *
   * Every successor `f(v)` is simplified before it is tested and stored. If the condition does
   * not hold within `DEBUG_MAX_ITER_COUNT` steps, the term is left unsimplified.
   */
  case class StepUntilComp(truetm: Term) extends ComputationRule(StepUntil.path) {
    override def apply(check: CheckingCallback)(tm: Term, covered: Boolean)(implicit stack: Stack, history: History): Simplifiability = tm match {
      case StepUntil(_, iv, f, tC) =>
        // `visited` holds the values produced so far, most recent first.
        @scala.annotation.tailrec
        def step(cur: Term, visited: List[Term], remaining: Int): Simplifiability =
          if (remaining <= 0) {
            // Returning the values collected so far would present a truncated prefix as the
            // value of the whole term, so give up instead of producing a wrong result.
            history += s"stepUntil: condition not satisfied within $DEBUG_MAX_ITER_COUNT steps; not simplifying"
            Simplifiability.NoRecurse
          } else {
            check.tryToCheckWithoutDelay(Equality(stack, Apply(tC, cur), truetm, None)) match {
              case Some(true) =>
                Simplify(Ls(visited.reverse: _*))
              // An undecided check (None) is treated like a failed one: keep stepping.
              case Some(false) | None =>
                step(check.simplify(Apply(f, cur)), cur :: visited, remaining - 1)
            }
          }

        step(iv, Nil, DEBUG_MAX_ITER_COUNT)
      case _ => Simplifiability.NoRecurse
    }
  }

  /**
   * Infers `Listtype(A)` for `stepUntil(tp, iv, f, tC)`, where `A` is the inferred type of `iv`.
   *
   * When an expected type is supplied, the inferred list type must also be a subtype of it.
   * NOTE(review): `f` and `tC` are not checked for well-typedness here.
   */
  case class StepUntilInfTypingRule(truetm: Term) extends InferenceAndTypingRule(StepUntil.path, OfType.path) {
    override def apply(solver: Solver, tm: Term, tp: Option[Term], covered: Boolean)(implicit stack: Stack, history: History)
    : (Option[Term], Option[Boolean]) = tm match {
      case StepUntil(_, iv, _, _) =>
        solver.inferType(iv) match {
          case Some(elemTp) =>
            val inferred = Listtype(elemTp)
            // Without this check, any expected type would be accepted for a stepUntil term.
            val matchesExpected = tp.forall(expected => solver.check(Subtyping(stack, inferred, expected)))
            (Some(inferred), Some(matchesExpected))
          case None =>
            (None, None)
        }
      case _ =>
        (None, None)
    }

    /* NOTE(review): earlier draft of the missing well-formedness checks. It was written against a
     * former three-argument form `StepUntil(iv, f, tC)` and does not match the current
     * four-argument extractor; it is kept only as a sketch of what remains to be checked.
    val (p, tpA) = (tm, tp) match {
      case (StepUntil(iv, f, tC), Some(Listtype(tp1))) =>
        solver.check(Typing(stack, iv, tp1, None))
        solver.inferType(f).map(solver.safeSimplifyUntil(_)(Arrow.unapply)._1) match {
          case Some(Pi(x,tpA1,tpB1)) if !tpB1.freeVars.contains(x) =>
            if (!covered) {
              solver.check(Subtyping(stack, iv, tpA1))
              solver.check(Typing(stack, iv, tpB1))
            }
            (tC, tp1)
          case _ =>
            return (None,None)
        }
      case (StepUntil(iv, f, tC), _) =>
        solver.inferType(f).map(solver.safeSimplifyUntil(_)(Arrow.unapply)._1) match {
          case Some(Pi(x,tpA1,tpB1)) if !tpB1.freeVars.contains(x) =>
            solver.check(Subtyping(stack, iv, tpA1))
            solver.check(Typing(stack, iv, tpB1))
            (tC, iv)
          case _ =>
            return (None,None)
        }
      case _ =>
        return (None, None)
    }
    if (!covered) {
      val proptp = solver.inferType(truetm) match {
        case Some(t) => t
        case _ => return (None, None)
      }
      val (name, _) = Context.pickFresh((stack ++ solver.constantContext ++ solver.outerContext).context, LocalName("ls"))
      solver.check(Typing(stack++name%tpA,Apply(p,OMV(name)),proptp))
    }
    (Some(Listtype(tpA)),Some(true))
    */
  }
}