(*
** Code used in the lecture: a call-by-value evaluator
** for a small lambda calculus with integers and primitive operators
*)

(* ****** ****** *)

staload _ = "prelude/DATS/list.dats"
staload _ = "prelude/DATS/list_vt.dats"
staload _ = "prelude/DATS/list0.dats"

(* ****** ****** *)

staload "lambda.sats"

(* ****** ****** *)

#define nil list0_nil
#define cons list0_cons
#define :: list0_cons

(* ****** ****** *)

(*
** [fprint_term] writes [t] to [out] in constructor syntax,
** e.g. TERMapp(TERMvar(f), TERMint(1)).
** For TERMopr only the operator name is shown; its arguments are
** elided as "...".
*)
implement
fprint_term (out, t) =
  case+ t of
  | TERMvar (name) =>
      fprintf (out, "TERMvar(%s)", @(name))
  | TERMint (n) =>
      fprintf (out, "TERMint(%i)", @(n))
  | TERMopr (opr, _) =>
      fprintf (out, "TERMopr(%s, ...)", @(opr))
  | TERMlam (name, body) => (
      // the bound variable is a plain string, so print it with the header
      fprintf (out, "TERMlam(%s, ", @(name));
      fprint (out, body);
      fprint (out, ")")
    ) // end of [TERMlam]
  | TERMapp (t_fun, t_arg) => (
      fprint (out, "TERMapp(");
      fprint (out, t_fun);
      fprint (out, ", ");
      fprint (out, t_arg);
      fprint (out, ")")
    ) // end of [TERMapp]
// end of [fprint_term]

(* ****** ****** *)

(*
** [fprint_value] writes [v] to [out].
** Integers are printed as VALint(n). Closures are printed opaquely as
** VALclo(...): neither the captured environment nor the body is shown.
*)
implement
fprint_value (out, v) =
  case+ v of
  | VALint (n) => fprintf (out, "VALint(%i)", @(n))
  | VALclo _ => fprint (out, "VALclo(...)")
// end of [fprint_value]


(* ****** ****** *)

(*
** Runtime errors raised by the evaluator:
** - UnboundVariable: a variable has no binding in the environment
**   (raised by [eval_var]);
** - UnsupportedOperator: a primitive operator name is not recognized
**   (raised by [eval_opr]);
** - TypeError: an operator received operands of the wrong kind or
**   number, or a non-closure value was applied (raised by [eval_opr]
**   and [eval]).
*)
exception UnboundVariable
exception UnsupportedOperator
exception TypeError

(* ****** ****** *)

(*
** [eval_var] looks up [x0] in [env] and returns the value of the
** innermost binding, i.e. the first matching entry in the list.
** Raises UnboundVariable if [x0] is not bound.
*)
fun eval_var
  (env: env, x0: string): value =
  case+ env of
  | list0_nil () => $raise UnboundVariable ()
  | list0_cons ((name, v), rest) =>
      if name = x0 then v else eval_var (rest, x0)
// end of [eval_var]

// the (already evaluated) argument values handed to a primitive operator
typedef valuelst = list0 (value)

(*
** [eval_opr_int2] extracts exactly two integer operands from [vs].
** Any other shape (wrong number of arguments, or a non-integer value)
** raises TypeError.
*)
fn eval_opr_int2 (vs: valuelst): @(int, int) =
  case+ vs of
  | cons (VALint i1, cons (VALint i2, nil ())) => @(i1, i2)
  | _ => $raise TypeError ()
// end of [eval_opr_int2]

(*
** [eval_opr] applies the primitive operator named [opr] to the
** argument values [vs].
** Supported operators: "+", "-", "*", each taking exactly two integers.
** Raises UnsupportedOperator for any other operator name; this is checked
** before the arguments are inspected. Raises TypeError when the
** arguments do not match.
** NOTE: integer arithmetic is unchecked; overflow is not detected.
*)
fun eval_opr (opr: string, vs: valuelst): value =
  case+ opr of
  | "+" => let val @(i1, i2) = eval_opr_int2 (vs) in VALint (i1 + i2) end
  | "-" => let val @(i1, i2) = eval_opr_int2 (vs) in VALint (i1 - i2) end
  | "*" => let val @(i1, i2) = eval_opr_int2 (vs) in VALint (i1 * i2) end
  | _ => $raise UnsupportedOperator ()
// end of [eval_opr]

(*
** [eval] evaluates term [t] under environment [env] using call-by-value:
** in an application, the function position is evaluated first, then the
** argument, and then the closure body runs in the closure's captured
** environment extended with the parameter binding.
** Raises TypeError when a non-closure value is applied. Errors from
** [eval_var] and [eval_opr] propagate unchanged.
*)
implement eval (env, t) =
  case+ t of
  | TERMint i => VALint (i)
  | TERMvar x => eval_var (env, x)
  | TERMlam _ => VALclo (env, t) // a lambda closes over the current env
  | TERMopr (opr, args) => let
      val vs = list0_map_cloref (args, lam arg => eval (env, arg))
    in
      eval_opr (opr, vs)
    end // end of [TERMopr]
  | TERMapp (t_fun, t_arg) => let
      val v_fun = eval (env, t_fun)
      val v_arg = eval (env, t_arg) // call-by-value
    in
      case+ v_fun of
      | VALclo (env_clo, t_clo) => let
          // closures are only ever built from TERMlam (see above)
          val- TERMlam (x, t_body) = t_clo
        in
          eval (list0_cons ((x, v_arg), env_clo), t_body)
        end
      | _ => $raise TypeError ()
    end // end of [TERMapp]
// end of [eval]

(* ****** ****** *)

// [eval0] evaluates a closed term [t], starting from the empty environment
implement eval0 (t) = let
  val env0: env = list0_nil ()
in
  eval (env0, t)
end // end of [eval0]

(* ****** ****** *)

// double = lam x. x + x
val double = let
  val vx = TERMvar "x"
  val args = list0_cons (vx, list0_cons (vx, list0_nil ()))
in
  TERMlam ("x", TERMopr ("+", args))
end // end of [double]

// twice = lam f. lam x. f (f x)
val twice = let
  val vf = TERMvar "f"
  val vx = TERMvar "x"
  val body = TERMapp (vf, TERMapp (vf, vx))
in
  TERMlam ("f", TERMlam ("x", body))
end // end of [twice]

(* ****** ****** *)
//
// print a computed answer on its own line as "ans = <value>"
fn print_ans (v: value): void = (
  print "ans = "; fprint (stdout_ref, v); print_newline ()
) // end of [print_ans]
//
// twice double = double applied 2 times: 1 => 4
val f1 = TERMapp (twice, double)
val ans = eval0 (TERMapp (f1, TERMint 1))
val () = print_ans (ans)
//
// (twice twice) double = double applied 4 times: 1 => 16
val f2 = TERMapp (TERMapp (twice, twice), double)
val ans = eval0 (TERMapp (f2, TERMint 1))
val () = print_ans (ans)
//
// ((twice twice) twice) double = double applied 16 times: 1 => 65536
val f3 = TERMapp (TERMapp (TERMapp (twice, twice), twice), double)
val ans = eval0 (TERMapp (f3, TERMint 1))
val () = print_ans (ans)
//
(* ****** ****** *)

//
// The demonstration runs through the top-level [val] bindings above,
// which presumably are evaluated when the module is initialized.
// [main] itself has nothing left to do.
//
implement main () = ()

(* ****** ****** *)

(* end of [lambda.dats] *)