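(*
** Code used in the lecture: an untyped lambda calculus with integer
** literals and primitive operators, together with a term printer,
** substitution, a call-by-value evaluator, and a few small demos.
*)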

(* ****** ****** *)

staload _ = "prelude/DATS/list.dats"
staload _ = "prelude/DATS/list_vt.dats"
staload _ = "prelude/DATS/list0.dats"

(* ****** ****** *)

staload "lambda.sats"

(* ****** ****** *)

#define nil list0_nil
#define cons list0_cons
#define :: list0_cons

(* ****** ****** *)

//
// [fprint_term (out, t)] writes a constructor-style rendering of [t]
// to [out], e.g. TERMapp(TERMvar(f), TERMint(1)).
// Sub-terms of lambdas and applications are printed recursively through
// [fprint] (assumes [fprint] is overloaded with [fprint_term] in
// lambda.sats); the argument list of a TERMopr is elided as "...".
//
implement
fprint_term (out, t) = case+ t of
  | TERMint (i) => fprintf (out, "TERMint(%i)", @(i))
  | TERMvar (x) => fprintf (out, "TERMvar(%s)", @(x))
  | TERMopr (opr, _) => fprintf (out, "TERMopr(%s, ...)", @(opr))
  | TERMapp (t_fun, t_arg) => (
      fprint (out, "TERMapp("); fprint (out, t_fun);
      fprint (out, ", "); fprint (out, t_arg);
      fprint (out, ")")
    ) // end of [TERMapp]
  | TERMlam (x, t_body) => (
      fprint (out, "TERMlam("); fprint (out, x);
      fprint (out, ", "); fprint (out, t_body);
      fprint (out, ")")
    ) // end of [TERMlam]
// end of [fprint_term]

(* ****** ****** *)

(*
** Explicitly recursive variant of [subst_list], kept for reference:
fun subst_list
  (ts1: termlst, x0: string, t2: term): termlst =
  case+ ts1 of
  | cons (t1, ts1) => cons (subst (t1, x0, t2), subst_list (ts1, x0, t2))
  | nil () => nil
// end of [subst_list]
*)

//
// [subst_list (ts1, x0, t2)] applies [subst (_, x0, t2)] to every term
// of [ts1], keeping the original order. It is used for the argument
// list of a TERMopr node.
//
fun subst_list
  (ts1: termlst, x0: string, t2: term): termlst =
  case+ ts1 of
  | nil () => nil ()
  | cons (t_head, ts_tail) =>
      cons (subst (t_head, x0, t2), subst_list (ts_tail, x0, t2))
// end of [subst_list]

//
// [subst (t1, x0, t2)] replaces the free occurrences of variable [x0]
// in [t1] with [t2]. A lambda whose binder is [x0] shadows it, so its
// body is left untouched.
// NOTE(review): this substitution is NOT capture-avoiding. A binder in
// [t1] may capture a free variable of [t2]. That cannot happen when
// [t2] is closed, as is the case for the values [eval] substitutes
// while evaluating closed programs; confirm before calling [subst] on
// open terms.
//
implement
subst (t1, x0, t2) = case+ t1 of
  | TERMint _ => t1
  | TERMvar x => if x = x0 then t2 else t1
  | TERMopr (opr, ts1) => TERMopr (opr, subst_list (ts1, x0, t2))
  | TERMapp (t_fun, t_arg) => let
      val t_fun_new = subst (t_fun, x0, t2)
      val t_arg_new = subst (t_arg, x0, t2)
    in
      TERMapp (t_fun_new, t_arg_new)
    end // end of [TERMapp]
  | TERMlam (x, t_body) =>
      // a binder named [x0] shadows it: nothing to substitute below
      if x = x0 then t1 else TERMlam (x, subst (t_body, x0, t2))
// end of [subst]

(* ****** ****** *)

//
// Runtime errors raised by the evaluator:
// - TypeError: applying a non-lambda, or passing operands of the wrong
//   shape to a primitive operator;
// - UnboundVariable: evaluating a free variable;
// - UnsupportedOperator: an operator name [eval_opr] does not handle.
//
exception TypeError
exception UnboundVariable
exception UnsupportedOperator

(* ****** ****** *)

//
// [eval_opr (opr, ts)] applies the primitive operator [opr] to the
// already-evaluated argument terms [ts].
// Only "+" is supported, and it needs exactly two TERMint arguments.
// Any other argument shape raises TypeError, and any other operator
// name raises UnsupportedOperator (the operator is checked first).
//
fun eval_opr (opr: string, ts: termlst): term =
  if opr = "+" then (
    case+ ts of
    | cons (TERMint i1, cons (TERMint i2, nil ())) => TERMint (i1 + i2)
    | _ => $raise TypeError ()
  ) else
    $raise UnsupportedOperator ()
// end of [eval_opr]
  
//
// [eval (t)] is a call-by-value evaluator:
// - integers and lambdas are values and are returned as they are;
// - an application evaluates the function part, then the argument,
//   and evaluates the lambda body with the argument value substituted;
// - an operator node evaluates all of its arguments and then hands
//   them to [eval_opr].
// Raises UnboundVariable on a free variable and TypeError when the
// function part of an application does not evaluate to a lambda.
//
implement eval (t) = case+ t of
  | TERMint _ => t
  | TERMlam _ => t
  | TERMvar _ => $raise UnboundVariable ()
  | TERMopr (opr, ts) =>
      eval_opr (opr, list0_map_cloref (ts, lam t_arg => eval t_arg))
  | TERMapp (t_fun, t_arg) => let
      val v_fun = eval (t_fun)
      val v_arg = eval (t_arg) // call-by-value: argument evaluated first
    in
      case+ v_fun of
      | TERMlam (x, t_body) => eval (subst (t_body, x, v_arg))
      | _ => $raise TypeError ()
    end // end of [TERMapp]
// end of [eval]

(* ****** ****** *)

//
// [double] is the term \x. x + x, built with the primitive "+".
//
val double = let
  val x = TERMvar "x"
  val x_plus_x = TERMopr ("+", cons (x, cons (x, nil)))
in
  TERMlam ("x", x_plus_x)
end // end of [double]

//
// [twice] is the term \f. \x. f (f x).
//
val twice = let
  val f = TERMvar "f"
  val x = TERMvar "x"
  val f_of_f_of_x = TERMapp (f, TERMapp (f, x))
in
  TERMlam ("f", TERMlam ("x", f_of_f_of_x))
end // end of [twice]

(* ****** ****** *)
//
//
// [print_ans (t)] prints "ans = ", then the rendering of [t], then a
// newline. It is shared by the three demos below, which previously
// repeated this sequence verbatim.
//
fn print_ans (t: term): void = (
  print "ans = "; fprint (stdout_ref, t); print_newline ()
) // end of [print_ans]
//
// twice double applies [double] 2 times: 1 evaluates to TERMint(4)
val f1 = TERMapp (twice, double)
val ans = eval (TERMapp (f1, TERMint 1))
val () = print_ans (ans)
//
// (twice twice) double applies [double] 4 times: TERMint(16)
val f2 = TERMapp (TERMapp (twice, twice), double)
val ans = eval (TERMapp (f2, TERMint 1))
val () = print_ans (ans)
//
// ((twice twice) twice) double applies [double] 2^4 = 16 times:
// TERMint(65536)
val f3 = TERMapp (TERMapp (TERMapp (twice, twice), twice), double)
val ans = eval (TERMapp (f3, TERMint 1))
val () = print_ans (ans)
//
(* ****** ****** *)

// The demo output is produced by the top-level [val] bindings above,
// which presumably run at program initialization; [main] itself has
// nothing left to do.
implement main () = ()

(* ****** ****** *)

(* end of [lambda.dats] *)