//
// Course: BU CAS CS 520, Fall 2010
// Instructor: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
// Lecture on Tuesday, Sep. 28, 2010
//

(* ****** ****** *)

#include "../MISC/BUCASCS520.hats"

(* ****** ****** *)

staload "lambda.sats"

(* ****** ****** *)
//
#define nil list0_nil
#define cons list0_cons
#define :: list0_cons
//
(* ****** ****** *)

//
// [fprint_term] prints a constructor-style rendering of term [t]
// to [out]. Variables are printed as their bare name, integers and
// booleans via fprint_int/fprint_bool, and every other form as
// CONSTRUCTOR(arg, ...). For TMopr, the operator name is separated
// from its operand list by "; ".
//
implement
fprint_term
  (out, t) = case+ t of
  | TMvar x => fprint_string (out, x)
  | TMlam (x, t1) => () where {
      val () = fprint_string (out, "TMlam(")
      val () = fprint_string (out, x)
      val () = fprint_string (out, ", ")
      val () = fprint_term (out, t1)
      val () = fprint_string (out, ")")
    } // end of [TMlam]
  | TMapp (t1, t2) => () where {
      val () = fprint_string (out, "TMapp(")
      val () = fprint_term (out, t1)
      val () = fprint_string (out, ", ")
      val () = fprint_term (out, t2)
      val () = fprint_string (out, ")")
    } // end of [TMapp]
  | TMint (i) => fprint_int (out, i)
  | TMbool (b) => fprint_bool (out, b)
  | TMopr (opr, ts) => () where {
      val () = fprint_string (out, "TMopr(")
      val () = fprint_string (out, opr)
      val () = fprint_string (out, "; ")
      // operands are printed comma-separated by [fprint_termlst]
      val () = fprint_termlst (out, ts)
      val () = fprint_string (out, ")")
    } // end of [TMopr]
  | TMif (t1, t2, t3) => () where {
      val () = fprint_string (out, "TMif(")
      val () = fprint_term (out, t1)
      val () = fprint_string (out, ", ")
      val () = fprint_term (out, t2)
      val () = fprint_string (out, ", ")
      val () = fprint_term (out, t3)
      val () = fprint_string (out, ")")
    } // end of [TMif]
  | TMlet (x, t1, t2) => () where {
      val () = fprint_string (out, "TMlet(")
      val () = fprint_string (out, x)
      val () = fprint_string (out, ", ")
      val () = fprint_term (out, t1)
      val () = fprint_string (out, ", ")
      val () = fprint_term (out, t2)
      val () = fprint_string (out, ")")
    } // end of [TMlet]
// end of [fprint_term]

//
// [fprint_termlst] prints the terms of [ts] to [out], separated by
// ", ". Nothing is printed for an empty list.
//
implement fprint_termlst
  (out, ts) = let
  // [first] is true until the first element has been printed
  fun aux (out: FILEref, xs: termlst, first: bool): void =
    case+ xs of
    | list0_nil () => ()
    | list0_cons (x, rest) => let
        val () = if first then () else fprint_string (out, ", ")
        val () = fprint_term (out, x)
      in
        aux (out, rest, false)
      end // end of [list0_cons]
  // end of [aux]
in
  aux (out, ts, true)
end // end of [fprint_termlst]

implement print_term (t) = fprint_term (stdout_ref, t)
implement prerr_term (t) = fprint_term (stderr_ref, t)

(* ****** ****** *)

//
// [subst0_lst] applies [subst0] (substitute [v] for [x]) to every
// term in [ts], preserving order.
//
fun subst0_lst
  (ts: termlst, x: string, v: term): termlst =
  case+ ts of
  | nil () => nil ()
  | cons (hd, tl) => let
      val hd_s = subst0 (hd, x, v)
      val tl_s = subst0_lst (tl, x, v)
    in
      cons (hd_s, tl_s)
    end // end of [cons]
// end of [subst0_lst]

//
// [subst0 (t, x, v)] replaces the free occurrences of variable [x]
// in [t] with [v]. Binders (TMlam, TMlet) that rebind [x] stop the
// substitution in their bodies.
//
// Binders are never renamed, so this substitution is not
// capture-avoiding. It is only safe when [v] has no free variables
// that a binder inside [t] could capture, for example a closed [v].
//
implement
subst0 (t, x, v) = case+ t of
  | TMint _ => t
  | TMbool _ => t
  | TMvar y => if y = x then v else t
  | TMlam (y, body) =>
      // a binder for [x] shadows it: leave the body alone
      if y = x then t else TMlam (y, subst0 (body, x, v))
  | TMapp (fn, arg) => let
      val fn_s = subst0 (fn, x, v)
      val arg_s = subst0 (arg, x, v)
    in
      TMapp (fn_s, arg_s)
    end // end of [TMapp]
  | TMopr (opr, args) => TMopr (opr, subst0_lst (args, x, v))
  | TMif (c, tb, fb) => let
      val c_s = subst0 (c, x, v)
      val tb_s = subst0 (tb, x, v)
      val fb_s = subst0 (fb, x, v)
    in
      TMif (c_s, tb_s, fb_s)
    end // end of [TMif]
  | TMlet (y, rhs, body) => let
      // [y] is not in scope in [rhs], so [rhs] is always substituted
      val rhs_s = subst0 (rhs, x, v)
      val body_s = if y = x then body else subst0 (body, x, v)
    in
      TMlet (y, rhs_s, body_s)
    end // end of [TMlet]
// end of [subst0]

(* ****** ****** *)

//
// [eval0_opr] evaluates the primitive operator [opr] applied to the
// operand list [ts].
//
// Every operand is first evaluated with [eval0]. Each supported
// operator then expects the first two evaluated operands to be
// TMint; any further operands are ignored. Arithmetic operators
// return a TMint and comparisons return a TMbool. If the operand
// shapes do not match, the [val-] pattern match fails at runtime.
// Division by zero in "/" is not checked here.
//
// An unrecognized operator prints a message to stderr and exits
// with status 1.
//
fun eval0_opr
  (opr: string, ts: termlst): term = let
  val ts = list0_map_fun (ts, eval0)
in
  case+ (opr) of
  | "+" => let
      val- TMint i1 :: TMint i2 :: _ = ts in TMint (i1+i2)
    end // end of ["+"]
  | "-" => let
      val- TMint i1 :: TMint i2 :: _ = ts in TMint (i1-i2)
    end // end of ["-"]
  | "*" => let
      val- TMint i1 :: TMint i2 :: _ = ts in TMint (i1*i2)
    end // end of ["*"]
  | "/" => let
      // NOTE(review): i2 = 0 is not guarded against
      val- TMint i1 :: TMint i2 :: _ = ts in TMint (i1/i2)
    end // end of ["/"]
  | "<" => let
      val- TMint i1 :: TMint i2 :: _ = ts in TMbool (i1 < i2)
    end // end of ["<"]
  | "<=" => let
      val- TMint i1 :: TMint i2 :: _ = ts in TMbool (i1 <= i2)
    end // end of ["<="]
  | ">" => let
      val- TMint i1 :: TMint i2 :: _ = ts in TMbool (i1 > i2)
    end // end of [">"]
  | ">=" => let
      val- TMint i1 :: TMint i2 :: _ = ts in TMbool (i1 >= i2)
    end // end of [">="]
  | "=" => let
      // integer equality only; TMbool operands are not accepted
      val- TMint i1 :: TMint i2 :: _ = ts in TMbool (i1 = i2)
    end // end of ["="]
  | "<>" => let
      val- TMint i1 :: TMint i2 :: _ = ts in TMbool (i1 <> i2)
    end // end of ["<>"]
  | _ => (
      prerr "eval0_opr: unrecognized operator: opr = "; prerr opr; prerr_newline (); exit(1)
    ) // end of [_]
end // end of [eval0_opr]

//
// [eval0] evaluates term [t] to a value using call-by-name.
//
// - Lambdas, integers and booleans are already values.
// - In an application, only the function position is evaluated.
//   The argument is substituted unevaluated into the lambda body,
//   and the result is evaluated again.
// - A let-binding evaluates its right-hand side first and
//   substitutes the resulting value into the body.
// - A conditional evaluates its test, which must produce a TMbool,
//   and then evaluates only the selected branch.
//
implement eval0 (t) = case+ t of
  | TMlam _ => t
  | TMapp (t1, t2) => let
      val t1 = eval0 (t1) in
      case+ t1 of
      | TMlam (x, t1_body) => eval0 (subst0 (t1_body, x, t2))
      | _ => TMapp (t1, t2) // stuck: the head is not a lambda
    end // end of [TMapp]
  | TMvar _ => t // this cannot happen if [t] is closed!
  | TMint _ => t
  | TMbool _ => t
  | TMopr (opr, ts) => eval0_opr (opr, ts)
  | TMif (t1, t2, t3) => let
      val- TMbool b = eval0 (t1)
    in
      // Select [t2] when the test holds, [t3] otherwise, and evaluate
      // the chosen branch so [eval0] always returns a value.
      if b then eval0 (t2) else eval0 (t3)
    end // end of [TMif]
  | TMlet (x, t1, t2) => let
      val v1 = eval0 (t1) in eval0 (subst0 (t2, x, v1))
    end // end of [TMlet]
// end of [eval0]

(* ****** ****** *)

// Native-integer zero and successor (\x. x + 1). Applying a Church
// numeral to [isucc] and then [izero] turns it into a TMint; this is
// how [print_numeral] decodes numerals.
val izero = TMint (0)
val isucc = TMlam ("x", TMopr ("+", TMvar("x") :: TMint(1) :: nil))

(* ****** ****** *)

//
// [genNumeral] builds the Church numeral for [n]:
//   \f.\x. f(f(...f(x)...))   with [n] applications of [f].
// A non-positive [n] yields \f.\x. x, the numeral for zero.
//
implement
genNumeral (n) = let
  val f = TMvar "f"
  // wrap [acc] in [k] more applications of [f]
  fun wrap (k: int, acc: term):<cloref1> term =
    if k > 0 then wrap (k-1, TMapp (f, acc)) else acc
  // end of [wrap]
in
  TMlam ("f", TMlam ("x", wrap (n, TMvar "x")))
end // end of [genNumeral]

(* ****** ****** *)

//
// [print_numeral] decodes Church numeral [n] by applying it to
// [isucc] and [izero], evaluates the result (expected to be a TMint),
// and prints the integer wrapped in double quotes with printf.
//
implement
print_numeral (n) = let
  val decoded = TMapp (TMapp (n, isucc), izero)
  val- TMint k = eval0 (decoded)
in
  printf ("\"%i\"", @(k))
end // end of [print_numeral]

(* ****** ****** *)

// Small Church numerals used by the examples below. _2 and _3 are
// also printed as a quick sanity check of [print_numeral].
val _0 = genNumeral (0)
val _1 = genNumeral (1)
val _2 = genNumeral (2)
val () = (print "_2 = "; print_numeral _2; print_newline ())
val _3 = genNumeral (3)
val () = (print "_3 = "; print_numeral _3; print_newline ())

(* ****** ****** *)
//
// succ = \n.\f.\x. n(f)(f(x))
//
// Church successor: applies [f] one more time than [n] does.
val succ = let
  val n = TMvar "n" and f = TMvar "f" and x = TMvar "x"
  val body = TMapp (TMapp (n, f), TMapp (f, x))
in
  TMlam ("n", TMlam ("f", TMlam ("x", body)))
end // end of [succ]

//
// plus = \m.\n. m(succ)(n)
//
// Church addition: applies [succ] to [n], [m] times.
val plus = let
  val m = TMvar "m" and n = TMvar "n"
  val body = TMapp (TMapp (m, succ), n)
in
  TMlam ("m", TMlam ("n", body))
end // end of [plus]

//
// times = \m.\n. m(plus(n))(_0)
//
// Church multiplication: adds [n] to _0, [m] times.
val times = let
  val m = TMvar "m" and n = TMvar "n"
  val add_n = TMapp (plus, n)
  val body = TMapp (TMapp (m, add_n), _0)
in
  TMlam ("m", TMlam ("n", body))
end // end of [times]

(* ****** ****** *)

//
// power = \m.\n. n(times(m))(_1) // m^n
//
// Church exponentiation: multiplies _1 by [m], [n] times.
val power = let
  val m = TMvar "m" and n = TMvar "n"
  val mul_m = TMapp (times, m) // times(m), not plus(m)
  val body = TMapp (TMapp (n, mul_m), _1)
in
  TMlam ("m", TMlam ("n", body))
end // end of [power]

(* ****** ****** *)

// Raise the GC chunk limit and remove the maximum before running the
// larger examples. Presumably this is because evaluating these
// call-by-name terms allocates heavily; TODO confirm that it is needed.
val () = gc_chunk_count_limit_set (1 << 15)
val () = gc_chunk_count_limit_max_set (~1) // no max

// Each example evaluates an arithmetic combinator on Church numerals
// and prints the decoded result.
val _5 = eval0 (TMapp (TMapp (plus, _2), _3))
val () = (print "_5 = "; print_numeral _5; print_newline ())
val _25 = eval0 (TMapp (TMapp (times, _5), _5))
val () = (print "_25 = "; print_numeral _25; print_newline ())
val _10 = genNumeral (10)
val _1024 = eval0 (TMapp (TMapp (power, _2), _10)) // 2^10
val () = (print "_1024 = "; print_numeral _1024; print_newline ())
val _100 = eval0 (TMapp (TMapp (power, _10), _2)) // 10^2
val () = (print "_100 = "; print_numeral _100; print_newline ())
val _3_5 = eval0 (TMapp (TMapp (power, _3), _5)) // 3^5
val () = (print "_3_5 = "; print_numeral _3_5; print_newline ())

(* ****** ****** *)

// let x = _10 in plus(x)(x): exercises TMlet together with [plus].
val _20 = let
  val x = TMvar "x"
  val body = TMapp (TMapp (plus, x), x)
in
  eval0 (TMlet ("x", _10, body))
end // end of [_20]
val () = (print "_20 = "; print_numeral _20; print_newline ())

(* ****** ****** *)

// All examples run in the top-level [val] bindings of this file;
// [main] itself does nothing.
implement main () = ()

(* ****** ****** *)

(* end of [lambda.dats] *)