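(*
** Code used in the lecture: a call-by-value interpreter for a small
** lambda calculus with integers, primitive operators, conditionals and
** recursive functions (fixed points), plus a few example terms.
*)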

(* ****** ****** *)

staload _ = "prelude/DATS/list.dats"
staload _ = "prelude/DATS/list_vt.dats"
staload _ = "prelude/DATS/list0.dats"

staload _ = "prelude/DATS/reference.dats"

(* ****** ****** *)

staload "lambda.sats"

(* ****** ****** *)

#define nil list0_nil
#define cons list0_cons
#define :: list0_cons

(* ****** ****** *)

(*
** [fprint_term] writes a constructor-style rendering of [t] to [out].
** Variables, lambdas, applications and integers are printed in full;
** the arguments of [TERMopr] and the contents of [TERMif] / [TERMfix]
** are elided as "...".
*)
implement
fprint_term (out, t) =
  case+ t of
  | TERMvar (name) => fprintf (out, "TERMvar(%s)", @(name))
  | TERMlam (name, body) => begin
      fprint (out, "TERMlam(");
      fprint (out, name); fprint (out, ", "); fprint (out, body);
      fprint (out, ")")
    end // end of [TERMlam]
  | TERMapp (t_fun, t_arg) => begin
      fprint (out, "TERMapp(");
      fprint (out, t_fun); fprint (out, ", "); fprint (out, t_arg);
      fprint (out, ")")
    end // end of [TERMapp]
  | TERMint (n) => fprintf (out, "TERMint(%i)", @(n))
  | TERMopr (name, _) => fprintf (out, "TERMopr(%s, ...)", @(name))
  | TERMif _ => fprintf (out, "TERMif(...)", @())
  | TERMfix _ => fprintf (out, "TERMfix(...)", @())
// end of [fprint_term]

(* ****** ****** *)

(*
** [fprint_value] writes a constructor-style rendering of [v] to [out].
** Integers and booleans are printed with their payload; closures and
** reference cells are elided as "...".
*)
implement
fprint_value (out, v) =
  case+ v of
  | VALint (n) => fprintf (out, "VALint(%i)", @(n))
  | VALbool (b) => begin
      fprint (out, "VALbool("); fprint (out, b); fprint (out, ")")
    end // end of [VALbool]
  | VALclo _ => fprint (out, "VALclo(...)")
  | VALref _ => fprint (out, "VALref(...)")
// end of [fprint_value]


(* ****** ****** *)

(*
** Exceptions raised by the evaluator below:
** [UnboundVariable]     -- [eval_var] finds no binding for a variable;
** [UnsupportedOperator] -- [eval_opr] is given an operator name it does
**                          not know;
** [TypeError]           -- an operator receives arguments that are not
**                          two integers, or a non-function is applied.
*)
exception UnboundVariable
exception UnsupportedOperator
exception TypeError

(* ****** ****** *)

(*
** [eval_var] returns the value of the first entry in [env] whose name
** equals [x0], scanning from the head of the list; raises
** [UnboundVariable] when no entry matches.
*)
fun eval_var
  (env: env, x0: string): value =
  case+ env of
  | list0_nil () => $raise UnboundVariable ()
  | list0_cons ((name, v1), rest) =>
      if name = x0 then v1 else eval_var (rest, x0)
// end of [eval_var]

typedef valuelst = list0 (value)

(*
** Raised by [eval_opr] when "/" is applied with a zero divisor. Without
** this check the operands would be handed directly to native integer
** division, which does not handle a zero divisor gracefully.
*)
exception DivisionByZero

(*
** [eval_opr_int2] returns the two integers carried by an argument list
** of exactly the form [VALint i1, VALint i2]; any other list (wrong
** length or a non-integer element) raises [TypeError].
*)
fn eval_opr_int2
  (vs: valuelst): @(int, int) =
  case+ vs of
  | cons (VALint i1, cons (VALint i2, nil ())) => @(i1, i2)
  | _ => $raise TypeError ()
// end of [eval_opr_int2]

(*
** [eval_opr] applies the primitive operator [opr] to the already
** evaluated arguments [vs].
**   "+", "-", "*", "/"   yield a [VALint];
**   "<", "<=", ">", ">=" yield a [VALbool].
** Every operator expects exactly two integer arguments and raises
** [TypeError] otherwise. "/" raises [DivisionByZero] on a zero divisor.
** Any other operator name raises [UnsupportedOperator] (checked before
** the arguments are inspected).
*)
fun eval_opr (opr: string, vs: valuelst): value =
  case+ opr of
//
  | "+" => let
      val @(i1, i2) = eval_opr_int2 (vs) in VALint (i1 + i2)
    end
  | "-" => let
      val @(i1, i2) = eval_opr_int2 (vs) in VALint (i1 - i2)
    end
  | "*" => let
      val @(i1, i2) = eval_opr_int2 (vs) in VALint (i1 * i2)
    end
  | "/" => let
      val @(i1, i2) = eval_opr_int2 (vs)
    in
      // guard the native division against a zero divisor
      if i2 = 0 then $raise DivisionByZero () else VALint (i1 / i2)
    end
//
  | "<" => let
      val @(i1, i2) = eval_opr_int2 (vs) in VALbool (i1 < i2)
    end
  | "<=" => let
      val @(i1, i2) = eval_opr_int2 (vs) in VALbool (i1 <= i2)
    end
  | ">" => let
      val @(i1, i2) = eval_opr_int2 (vs) in VALbool (i1 > i2)
    end
  | ">=" => let
      val @(i1, i2) = eval_opr_int2 (vs) in VALbool (i1 >= i2)
    end
//
  | _ => $raise UnsupportedOperator ()
// end of [eval_opr]

(* ****** ****** *)

(*
** [eval_appclo] applies a closure to the argument value [v]. The closure
** is given as its captured environment [env1] and its lambda term
** [t_lam]. The lambda body is evaluated in [env1] extended, at the head,
** with the lambda's parameter bound to [v]. [t_lam] is expected to be a
** [TERMlam]; the match below is unchecked, so anything else fails at
** runtime.
*)
fun eval_appclo (
  env1: env, t_lam: term, v: value
) : value =
  case- t_lam of
  | TERMlam (param, body) => eval ((param, v) :: env1, body)
// end of [eval_appclo]

(* ****** ****** *)

(*
** [eval] evaluates [t] under [env] with call-by-value semantics: in an
** application the function and the argument are both evaluated before
** the call.
**
** - A lambda evaluates to a closure [VALclo (env, t)] capturing the
**   current environment.
** - [TERMfix (f, x, body)] builds a closure for [lam x. body] whose
**   environment binds [f] to a [VALref] cell. That cell is then updated
**   to hold the closure itself, which is how recursive calls through
**   [f] find the function.
**
** Raises [UnboundVariable] for a free variable. Raises [TypeError] when
** a non-function is applied or an [if] condition is not a boolean.
** Operator applications raise whatever [eval_opr] raises.
*)
implement eval (env, t) =
  case+ t of
  | TERMvar x => eval_var (env, x)
  | TERMlam _ => VALclo (env, t)
  | TERMapp (t1, t2) => let
      val v1 = eval (env, t1)
      val v2 = eval (env, t2) // call-by-value
    in
      case+ v1 of
      | VALclo (env1, t_lam) => eval_appclo (env1, t_lam, v2)
      | VALref r => let
          // the cell created by [TERMfix]: it is overwritten with the
          // recursive closure before that closure is returned, so it
          // normally holds a [VALclo] when looked up here
          val v_fix = !r
        in
          case+ v_fix of
          | VALclo (env1, t_lam) => eval_appclo (env1, t_lam, v2)
          | _ => $raise TypeError ()
        end // end of [VALref]
      | _ => $raise TypeError ()
    end // end of [TERMapp]
  | TERMint i => VALint (i)
  | TERMopr (opr, ts) => let
      val vs = list0_map_cloref (ts, lam t => eval (env, t))
    in
      eval_opr (opr, vs)
    end // end of [TERMopr]
  | TERMif (t1, t2, t3) => let
      val v_cond = eval (env, t1)
    in
      case+ v_cond of
      | VALbool (b) => eval (env, if b then t2 else t3)
      | _ => $raise TypeError () // non-boolean condition
    end // end of [TERMif]
  | TERMfix (f, x, t_body) => let
      val r = ref<value> (VALint 0) // placeholder; overwritten below
      val env = list0_cons ((f, VALref r), env)
      val v_clo = VALclo (env, TERMlam (x, t_body))
      val () = !r := v_clo // tie the recursive knot
    in
      v_clo
    end // end of [TERMfix]
// end of [eval]

(* ****** ****** *)

(*
** [eval0] evaluates [t] under the empty environment (i.e. as a closed
** term).
*)
implement eval0 (t) = let
  val env0: env = nil ()
in
  eval (env0, t)
end // end of [eval0]

(* ****** ****** *)

(*
** [double]: the term [lam x. x + x]
*)
val double = let
  val vx = TERMvar "x"
in
  TERMlam ("x", TERMopr ("+", vx :: vx :: nil))
end // end of [double]

(*
** [twice]: the term [lam f. lam x. f (f x)]
*)
val twice = let
  val vf = TERMvar "f"
  val vx = TERMvar "x"
  val f_x = TERMapp (vf, vx)
in
  TERMlam ("f", TERMlam ("x", TERMapp (vf, f_x)))
end // end of [twice]

(* ****** ****** *)
//
// [show_ans] prints "ans = ", then [v], then a newline, on stdout
fn show_ans (v: value): void = begin
  print "ans = "; fprint (stdout_ref, v); print_newline ()
end // end of [show_ans]
//
// twice double = double o double; applied to 1 gives 4
val f1 = TERMapp (twice, double)
val ans = eval0 (TERMapp (f1, TERMint 1))
val () = show_ans (ans)
//
// (twice twice) double = double applied 4 times; applied to 1 gives 16
val f2 = TERMapp (TERMapp (twice, twice), double)
val ans = eval0 (TERMapp (f2, TERMint 1))
val () = show_ans (ans)
//
// ((twice twice) twice) double = double applied 16 times; gives 65536
val f3 = TERMapp (TERMapp (TERMapp (twice, twice), twice), double)
val ans = eval0 (TERMapp (f3, TERMint 1))
val () = show_ans (ans)
//
(* ****** ****** *)

(*
** [fact]: fix f. lam x. if x >= 1 then x * f (x - 1) else 1
*)
val fact = let
  macdef binop (opr, x1, x2) = TERMopr (,(opr), ,(x1) :: ,(x2) :: nil)
  val f = TERMvar "f"
  val x = TERMvar "x"
  val one = TERMint (1)
  val t_test = binop (">=", x, one)
  val t_rec = binop ("*", x, TERMapp (f, binop ("-", x, one)))
in
  TERMfix ("f", "x", TERMif (t_test, t_rec, one))
end // end of [fact]

// expected output: fact(10) = VALint(3628800)
val ans = let
  val t_call = TERMapp (fact, TERMint 10)
in
  eval0 (t_call)
end // end of [val]
val () = print "fact(10) = "
val () = fprint (stdout_ref, ans)
val () = print_newline ()

(* ****** ****** *)

(*
** [f91]: McCarthy's 91 function,
**   fix f. lam n. if n >= 101 then n - 10 else f (f (n + 11))
*)
val f91 = let
  macdef binop (opr, x1, x2) = TERMopr (,(opr), ,(x1) :: ,(x2) :: nil)
  val f = TERMvar "f"
  val n = TERMvar "n"
  val t_test = binop (">=", n, TERMint (101))
  val t_base = binop ("-", n, TERMint (10))
  val t_rec = TERMapp (f, TERMapp (f, binop ("+", n, TERMint (11))))
in
  TERMfix ("f", "n", TERMif (t_test, t_base, t_rec))
end // end of [f91]

// expected output: f91(10) = VALint(91)
val ans = let
  val t_call = TERMapp (f91, TERMint 10)
in
  eval0 (t_call)
end // end of [val]
val () = print "f91(10) = "
val () = fprint (stdout_ref, ans)
val () = print_newline ()

(* ****** ****** *)

(*
** [main] does nothing itself: the demo output in this file comes from
** the top-level [val] bindings above.
** NOTE(review): this presumably relies on those bindings being evaluated
** during module initialization, before [main] runs -- confirm for the
** ATS version in use.
*)
implement main () = ()

(* ****** ****** *)

(* end of [lambda.dats] *)