//
// Course: BU CAS CS 520, Fall 2010
// Instructor: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
// Lecture on Thursday, Sep. 16, 2010
//

(* ****** ****** *)

#include "../MISC/BUCASCS520.hats"

(* ****** ****** *)

staload "lambda.sats"

(* ****** ****** *)

//
// [fprint_term (out, t)] writes a textual form of [t] to [out]:
//   a variable is written as its bare name,
//   an abstraction as TMlam(<name>, <body>),
//   an application as TMapp(<function>, <argument>).
//
implement
fprint_term
  (out, t) = case+ t of
  | TMvar x => fprint_string (out, x)
  | TMlam (x, t1) => () where {
      val () = fprint_string (out, "TMlam(")
      val () = fprint_string (out, x)
      val () = fprint_string (out, ", ")
      val () = fprint_term (out, t1)
      val () = fprint_string (out, ")")
    } // end of [TMlam]
  | TMapp (t1, t2) => () where {
      val () = fprint_string (out, "TMapp(")
      val () = fprint_term (out, t1)
      val () = fprint_string (out, ", ")
      val () = fprint_term (out, t2)
      val () = fprint_string (out, ")")
    } // end of [TMapp]
// end of [fprint_term]

//
// [print_term] and [prerr_term] send the textual form produced by
// [fprint_term] to standard output and standard error, respectively.
//
implement print_term (t) = let
  val out = stdout_ref
in
  fprint_term (out, t)
end // end of [print_term]
implement prerr_term (t) = let
  val out = stderr_ref
in
  fprint_term (out, t)
end // end of [prerr_term]

(* ****** ****** *)

// I = \x. x  (identity)
implement I = TMlam ("x", TMvar "x")
// K = \x. \y. x  (returns its first argument)
implement K = let
  val vx = TMvar "x"
in
  TMlam ("x", TMlam ("y", vx))
end // end of [K]
// K' = \x. \y. y  (returns its second argument)
implement K' = let
  val vy = TMvar "y"
in
  TMlam ("x", TMlam ("y", vy))
end // end of [K']
// S = \x. \y. \z. (x z) (y z)
implement S = let
  val xz = TMapp (TMvar "x", TMvar "z")
  val yz = TMapp (TMvar "y", TMvar "z")
  val body = TMapp (xz, yz)
in
  TMlam ("x", TMlam ("y", TMlam ("z", body)))
end // end of [S]

(* ****** ****** *)

//
// [size (t)] counts the abstraction and application nodes of [t];
// every TMlam and every TMapp contributes 1, a variable contributes 0.
//
implement
size (t) = case+ t of
  | TMapp (t_fun, t_arg) => size (t_fun) + size (t_arg) + 1
  | TMlam (_, t_body) => size (t_body) + 1
  | TMvar _ => 0
// end of [size]

(* ****** ****** *)

//
// [subterm (t, ps)] walks down [t] following the position list [ps] and
// returns the subterm reached when [ps] is exhausted. At an application,
// step 0 selects the function part and step 1 the argument; at an
// abstraction, step 0 selects the body. Any other step (including any
// step at a variable) prints an error message to stderr and terminates
// the program with exit status 1.
//
implement
subterm (t, ps) = let
  fun abort (): term = (
    prerr "subterm: illegal position"; prerr_newline (); exit (1)
  ) // end of [abort]
in
  case+ ps of
  | list0_nil () => t
  | list0_cons (p, rest) => let
      val next: term = (case+ t of
        | TMapp (t1, t2) =>
            (if p = 0 then t1 else if p = 1 then t2 else abort ()) : term
        | TMlam (_, t1) =>
            (if p = 0 then t1 else abort ()) : term
        | TMvar _ => abort ()
      ) // end of [case]
    in
      subterm (next, rest)
    end // end of [list0_cons]
end // end of [subterm]

(* ****** ****** *)

//
// [subst0 (t, x, v)] replaces the free occurrences of the variable [x]
// in [t] with [v]. An abstraction that rebinds [x] is returned unchanged,
// since [x] is not free underneath it.
// NOTE: this substitution is NOT capture-avoiding: a free variable of [v]
// may be captured by an abstraction in [t] binding the same name, so it
// is only safe when that cannot happen (e.g. when [v] is closed).
//
implement
subst0 (t, x, v) = case+ t of
  | TMapp (t1, t2) => let
      val t1' = subst0 (t1, x, v)
      val t2' = subst0 (t2, x, v)
    in
      TMapp (t1', t2')
    end // end of [TMapp]
  | TMlam (x1, t1) =>
      if x1 = x then t // [x] is shadowed here: nothing free to replace
      else TMlam (x1, subst0 (t1, x, v))
  | TMvar x1 => if x1 = x then v else t
// end of [subst0]

(* ****** ****** *)

//
// [eval0 (t)] evaluates [t] in call-by-name style: in an application,
// only the function part is evaluated first, and the argument is then
// substituted (unevaluated) into the body via [subst0]. An abstraction
// evaluates to itself. If the function part does not evaluate to an
// abstraction, the application is rebuilt from the evaluated function
// part and the original argument.
//
implement eval0 (t) = case+ t of
  | TMapp (t_fun, t_arg) => let
      val v_fun = eval0 (t_fun)
    in
      case+ v_fun of
      | TMlam (x, body) => let
          val reduct = subst0 (body, x, t_arg)
        in
          eval0 (reduct)
        end // end of [TMlam]
      | _ => TMapp (v_fun, t_arg)
    end // end of [TMapp]
  | TMlam _ => t
  | TMvar _ => t // this cannot happen if [t] is closed!
// end of [eval0]

(* ****** ****** *)

local
//
// print [label] followed by the term [t] and a newline on stdout
//
fn print_labeled_term (label: string, t: term): void = begin
  print label; print t; print_newline ()
end // end of [print_labeled_term]
in
val () = print_labeled_term ("I = ", I)
val () = print_labeled_term ("K = ", K)
val () = print_labeled_term ("K' = ", K')
val () = print_labeled_term ("S = ", S)
end // end of [local]

(* ****** ****** *)
//
#define nil list0_nil
#define cons list0_cons
#define :: list0_cons
//
// position 0001 in S: under the three binders, then the argument of the
// outer application, i.e. the subterm (y z)
val _0001 = cons (0, cons (0, cons (0, cons (1, nil ()))))
val () = begin
  print "S(0001) = "; print (subterm (S, _0001)); print_newline ()
end // end of [val]
//
// position 00011 in S: one more step into the argument of (y z), i.e. z
val _00011 = cons (0, cons (0, cons (0, cons (1, cons (1, nil ())))))
val () = begin
  print "S(00011) = "; print (subterm (S, _00011)); print_newline ()
end // end of [val]
//
(* ****** ****** *)

// Sanity checks for [size]: K and K' consist of two nested abstractions
// (size 2); S has three abstractions and three applications (size 6).
val () = assert_errmsg (size(K) = 2, #LOCATION)
val () = assert_errmsg (size(K') = 2, #LOCATION)
val () = assert_errmsg (size(S) = 6, #LOCATION)

(* ****** ****** *)

//
// Evaluate S K K applied to the free variable c; S K K behaves as the
// identity, so the evaluation reduces to the variable c. The term is
// evaluated before anything is printed.
//
val SKK = TMapp (TMapp (S, K), K)
val () = () where {
  val result = eval0 (TMapp (SKK, TMvar "c"))
  val () = print "eval0(SKK(c)) = "
  val () = print result
  val () = print_newline ()
} // end of [val]

(* ****** ****** *)

// All demonstration output is produced by the top-level [val] bindings
// in this file; [main] itself has nothing left to do.
implement main () = ()

(* ****** ****** *)

(* end of [lambda.dats] *)