(*
** Course: Concepts of Programming Languages (BU CAS CS 320)
** Semester: Summer I, 2009
** Instructor: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
*)

//
// Author: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
// Time: June, 2009
//

(* ****** ****** *)

#include "BUCASCS320.hats"

(* ****** ****** *)

//
// abstract syntax trees for types in STFPL
//

//
// [ty] is the syntax of STFPL types:
//   TYbase (name)   : a base type identified by its name
//   TYtup (t1, t2)  : the type of a pair whose components have types t1 and t2
//   TYfun (t1, t2)  : the type of a function from argument type t1 to result type t2
//
datatype ty =
  | TYbase of string(*name*)
  | TYtup of (ty, ty)
  | TYfun of (ty, ty)
// end of [datatype ty]

// a list of types; used below for the argument types of a constant
typedef tylst = list0 ty

(* ****** ****** *)

// the two base types used by the type checker: integers and booleans
val ty_int = TYbase "int"
val ty_bool = TYbase "bool"

(* ****** ****** *)

//
// [eq_ty_ty (ty1, ty2)] tests whether two STFPL types are structurally equal:
// base types are equal if their names are equal; pair types and function
// types are equal if their corresponding components are equal. Types built
// with different constructors are never equal.
//
fun eq_ty_ty (ty1: ty, ty2: ty): bool =
  case+ ty1 of
  | TYbase nm1 => begin case+ ty2 of
    | TYbase nm2 => nm1 = nm2
    | _ => false
    end // end of [TYbase]
  | TYtup (fst1, snd1) => begin case+ ty2 of
    | TYtup (fst2, snd2) =>
        // the second components are only compared if the first ones agree
        if eq_ty_ty (fst1, fst2) then eq_ty_ty (snd1, snd2) else false
    | _ => false
    end // end of [TYtup]
  | TYfun (arg1, res1) => begin case+ ty2 of
    | TYfun (arg2, res2) =>
        // argument types are compared with the operands swapped
        if eq_ty_ty (arg2, arg1) then eq_ty_ty (res1, res2) else false
    | _ => false
    end // end of [TYfun]
// end of [eq_ty_ty]

//
// [fprint_ty (out, ty)] writes a textual form of [ty] to [out]:
// a base type is printed as its name; pair and function types are
// printed as TYtup(t1, t2) and TYfun(t1, t2), respectively.
//
fun fprint_ty
  (out: FILEref, ty: ty) = case+ ty of
  | TYfun (ty_arg, ty_res) => (
      fprint_string (out, "TYfun(");
      fprint_ty (out, ty_arg);
      fprint_string (out, ", ");
      fprint_ty (out, ty_res);
      fprint_string (out, ")")
    ) // end of [TYfun]
  | TYtup (ty_fst, ty_snd) => (
      fprint_string (out, "TYtup(");
      fprint_ty (out, ty_fst);
      fprint_string (out, ", ");
      fprint_ty (out, ty_snd);
      fprint_string (out, ")")
    ) // end of [TYtup]
  | TYbase name => fprint_string (out, name)
// end of [fprint_ty]

(* ****** ****** *)

// shorthands for the list0 constructors, so that lists can be
// written as [x1 :: x2 :: nil ()] in both expressions and patterns
#define :: list0_cons
#define cons list0_cons
#define nil list0_nil

//
// A table assigning to each built-in constant (by name) the list of
// its argument types and its result type.
//
typedef ConstTypeMap_t =
  list0 @(string(*name*), tylst, ty)

//
// Every constant listed here takes two integers: the arithmetic
// operators return an integer and the comparisons return a boolean.
//
val theConstTypeMap : ConstTypeMap_t = let
  // the argument types shared by all of the entries
  val tys_int2: tylst = ty_int :: ty_int :: nil ()
in
  // arithmetic
  ("+", tys_int2, ty_int) ::
  ("-", tys_int2, ty_int) ::
  ("*", tys_int2, ty_int) ::
  ("/", tys_int2, ty_int) ::
  // comparison
  (">", tys_int2, ty_bool) ::
  (">=", tys_int2, ty_bool) ::
  ("<", tys_int2, ty_bool) ::
  ("<=", tys_int2, ty_bool) ::
  ("=", tys_int2, ty_bool) ::
  ("<>", tys_int2, ty_bool) ::
  nil ()
end // end of [val]

(* ****** ****** *)

//
// exceptions raised by the type checker
//
exception IllTyped of () // a type mismatch is detected
exception UnboundVariable of () // a variable is not found in the typing context
exception UnknownConstant of () // a constant is not found in [theConstTypeMap]
exception ArityError of () // a constant is given too few or too many arguments

(* ****** ****** *)

//
// [ofType_const (c0)] looks up the constant named [c0] in
// [theConstTypeMap] and returns the pair consisting of its
// argument types and its result type.
//
// If [c0] is not in the table, a message naming the constant is
// written to the standard error (in the same manner as it is done
// for an unbound variable in [ofType_var]) and the exception
// [UnknownConstant] is raised.
//
fun ofType_const
  (c0: string): @(tylst, ty) = let
  typedef map_t = ConstTypeMap_t
  // linear search; [c0] is passed explicitly so that [aux] is not a closure
  fun aux (c0: string, xs: map_t): @(tylst, ty) =
    case+ xs of
    | list0_cons (x, xs) =>
        if (x.0 = c0) then @(x.1, x.2) else aux (c0, xs)
    | list0_nil () => let
        val () = begin
          prerr "The constant ["; prerr c0; prerr "] is unknown.\n"
        end
      in
        $raise UnknownConstant ()
      end (* end of [list0_nil] *)
  // end of [aux]
in
  aux (c0, theConstTypeMap)
end // end of [ofType_const]

(* ****** ****** *)

//
// abstract syntax trees for terms in STFPL
//

// variables are represented by their names
typedef v1ar = string

//
// [term] is the syntax of STFPL terms:
//   TERMbool, TERMint : boolean and integer literals
//   TERMvar           : a variable
//   TERMcst           : a named constant applied to a list of arguments
//   TERMtup           : a pair; TERMfst and TERMsnd are its two projections
//   TERMlam           : a function with a type-annotated argument
//   TERMapp           : application of a function to an argument
//   TERMif            : a conditional
//   TERMfix           : a recursive function that can refer to itself by name;
//                       both its argument type and its result type are annotated
//
datatype term =
  | TERMbool of bool
  | TERMint of int
  | TERMvar of v1ar
  | TERMcst of (string(*name*), termlst)
  | TERMtup of (term, term)
  | TERMfst of term
  | TERMsnd of term
(*
  | TERMlam of (string(*arg*), term(*body*))
*)
  // typing a la Church
  | TERMlam of (v1ar(*arg*), ty(*argty*), term(*body*))  
  | TERMapp of (term, term)
  | TERMif of (term(*cond*), term(*then*), term(*else*))
  | TERMfix of (
      v1ar(*fun*)
    , v1ar(*arg*), ty(*argty*), term(*body*), ty(*resty*)
    )
// end of [datatype term]  
  
// a list of terms; used for the arguments of a constant
where termlst = list0 term

// a typing context: a list of (variable, type) pairs that is searched
// from the front, so that a later binding shadows an earlier one
typedef ctx = list0 @(v1ar, ty)

//
// [ofType (Gamma, tm)] returns the type of [tm] under the context [Gamma].
// It is implemented below; on failure, one of the exceptions [IllTyped],
// [UnboundVariable], [UnknownConstant] and [ArityError] is raised.
//
extern fun ofType (Gamma: ctx, tm: term): ty

//
// [ofType_var (Gamma, x)] returns the type of the first binding for [x]
// in [Gamma]. If there is none, a message naming the variable is written
// to the standard error and the exception [UnboundVariable] is raised.
//
fun ofType_var (Gamma: ctx, x: v1ar): ty =
  case+ Gamma of
  | list0_nil () => begin
      prerr "The variable ["; prerr x; prerr "] is unbound.\n";
      $raise UnboundVariable ()
    end // end of [list0_nil]
  | list0_cons (bind, rest) =>
      if bind.0 = x then bind.1 else ofType_var (rest, x)
    // end of [list0_cons]
// end of [ofType_var]

//
// The type checker: [ofType (Gamma, tm0)] computes the type of [tm0]
// under the typing context [Gamma].
//
// Exceptions:
//   IllTyped        : some types that are required to agree do not agree
//   UnboundVariable : raised by [ofType_var]
//   UnknownConstant : raised by [ofType_const]
//   ArityError      : a constant is applied to a wrong number of arguments
//
implement ofType (Gamma, tm0) = case+ tm0 of
  | TERMbool _ => ty_bool
  | TERMint _ => ty_int
  | TERMvar x => ofType_var (Gamma, x)
  | TERMcst (c, tms) => let
      val @(tys_arg, ty_res) = ofType_const (c)
      // checks the arguments one by one, from left to right,
      // against the argument types declared for the constant
      fun check_args
        (tms: termlst, tys: tylst):<cloref1> void =
        case+ (tms, tys) of
        | (list0_nil (), list0_nil ()) => ()
        | (list0_cons (tm, tms),
           list0_cons (ty, tys)) =>
            if eq_ty_ty (ofType (Gamma, tm), ty)
              then check_args (tms, tys) else $raise IllTyped ()
            // end of [if]
        | (_, _) => $raise ArityError () // the two lists differ in length
      // end of [check_args]
      val () = check_args (tms, tys_arg)
    in
      ty_res
    end // end of [TERMcst]
  | TERMtup (tm_fst, tm_snd) => let
      val ty_fst = ofType (Gamma, tm_fst)
      val ty_snd = ofType (Gamma, tm_snd)
    in
      TYtup (ty_fst, ty_snd)
    end // end of [TERMtup]
  | TERMfst (tm) => begin
      // the projected term must have a pair type
      case+ ofType (Gamma, tm) of
      | TYtup (ty_fst, _) => ty_fst
      | _ => $raise IllTyped ()
    end // end of [TERMfst]
  | TERMsnd (tm) => begin
      case+ ofType (Gamma, tm) of
      | TYtup (_, ty_snd) => ty_snd
      | _ => $raise IllTyped ()
    end // end of [TERMsnd]
  | TERMlam (x, ty_arg, tm_body) => let
      // the body is checked with the argument added to the context
      val Gamma_x = list0_cons ((x, ty_arg), Gamma)
    in
      TYfun (ty_arg, ofType (Gamma_x, tm_body))
    end // end of [TERMlam]
  | TERMapp (t_fun, t_arg) => let
      val ty_fun = ofType (Gamma, t_fun)
      val ty_arg = ofType (Gamma, t_arg)
    in
      case+ ty_fun of
      | TYfun (ty_dom, ty_cod) =>
          // the type of the argument must be the one expected by the function
          if eq_ty_ty (ty_arg, ty_dom) then ty_cod else $raise IllTyped ()
      | _ => $raise IllTyped () // a non-function is applied
    end // end of [TERMapp]
  | TERMif (tm_cond, tm_then, tm_else) => let
      val ty_cond = ofType (Gamma, tm_cond)
    in
      // the branches are only checked if the condition is a boolean
      if eq_ty_ty (ty_cond, ty_bool) then let
        val ty_then = ofType (Gamma, tm_then)
        val ty_else = ofType (Gamma, tm_else)
      in
        // the two branches must have the same type
        if eq_ty_ty (ty_then, ty_else) then ty_then else $raise IllTyped ()
      end else $raise IllTyped ()
    end // end of [TERMif]
  | TERMfix ( // fix f(x).t
      f, x, ty_arg, tm_body, ty_res
    ) => let
      val ty_fun = TYfun (ty_arg, ty_res)
      // the body is checked with both [f] and [x] in the context;
      // [x] is added last and thus shadows [f] if the two names coincide
      val Gamma_fx =
        list0_cons ((x, ty_arg), list0_cons ((f, ty_fun), Gamma))
      val ty_body = ofType (Gamma_fx, tm_body)
    in
      // the type of the body must match the declared result type
      if eq_ty_ty (ty_res, ty_body) then ty_fun else $raise IllTyped ()
    end // end of [TERMfix]
(*
  | _ => exit (1)
*)

(* ****** ****** *)

//
// Some examples for testing the type checker: the type of each
// well-typed term is printed onto the standard output
//

// id = lam x. x
val id_tm = TERMlam ("x", ty_int, TERMvar "x")
val id_ty = ofType (list0_nil, id_tm)
val () = (fprint_ty (stdout_ref, id_ty); print_newline ())

// double = lam x. x + x
val double_tm =
  TERMlam ("x", ty_int, TERMcst ("+", x :: x :: nil ())) where {
  val x = TERMvar "x"
} // end of [val]
val double_ty = ofType (list0_nil, double_tm)
val () = (fprint_ty (stdout_ref, double_ty); print_newline ())

// K = lam x. lam y. x 
val K_tm = TERMlam ("x", ty_int, tm_inner) where {
  val tm_inner = TERMlam ("y", ty_int, TERMvar "x")
} // end of [val]
val K_ty = ofType (list0_nil, K_tm)
val () = (fprint_ty (stdout_ref, K_ty); print_newline ())

// S = lam x. lam y. lam z. x (z) (y (z))
val S_tm = TERMlam (
  "x", ty_x, TERMlam ("y", ty_y, TERMlam ("z", ty_int, tm_body))
) where {
  val ty_y = TYfun (ty_int, ty_int) // y : int -> int
  val ty_x = TYfun (ty_int, ty_y) // x : int -> (int -> int)
  val x = TERMvar "x" and y = TERMvar "y" and z = TERMvar "z"
  val tm_body = TERMapp (TERMapp (x, z), TERMapp (y, z))
} // end of [val]
val S_ty = ofType (list0_nil, S_tm)
val () = (fprint_ty (stdout_ref, S_ty); print_newline ())

// applying [id_tm] to a boolean: the type checker is expected to reject it
val tm_illtyped = TERMapp (id_tm, TERMbool true)
val _ = try
  let val _(*ty*) = ofType (list0_nil, tm_illtyped) in () end
with
  | ~IllTyped () => begin
      prerr "Illtyped"; prerr_newline ()
    end // end of [IllTyped]
// end of [val]

// applying [double_tm] to a boolean: the type checker is expected to reject it
val tm_illtyped = TERMapp (double_tm, TERMbool true)
val _ = try
  let val _(*ty*) = ofType (list0_nil, tm_illtyped) in () end
with
  | ~IllTyped () => begin
      prerr "Illtyped"; prerr_newline ()
    end // end of [IllTyped]
// end of [val]

(* ****** ****** *)

//
// exceptions raised by the evaluator
//
exception EVAL_UnboundVariable of () // a free variable is to be evaluated
exception EVAL_Illtyped of () // a value of an unexpected form is encountered

// [eval (tm)] evaluates [tm] and returns the resulting term;
// it is implemented below
extern fun eval (tm: term): term

//
// [subst (sub, x, tm0)] replaces with [sub] the free occurrences of the
// variable [x] in [tm0]. A binder (TERMlam or TERMfix) that binds the
// name [x] again stops the substitution, and the term is returned as is.
//
// Note that no bound variable is ever renamed: a free variable in [sub]
// can thus be captured. NOTE(review): this looks to be fine as long as
// [sub] is a closed term, as is the case for the calls made in [eval]
// on closed programs -- please confirm before using it elsewhere.
//
fun subst (sub: term, x: v1ar, tm0: term): term =
  case+ tm0 of
  | TERMbool _ => tm0
  | TERMint _ => tm0
  | TERMvar y => if x = y then sub else tm0
  | TERMcst (cst, tms_arg) => let
      val tms_arg_new = list0_map_cloref<term><term>
        (tms_arg, lam (tm) => subst (sub, x, tm))
    in
      TERMcst (cst, tms_arg_new)
    end // end of [TERMcst]
  | TERMtup (tm_fst, tm_snd) => let
      val tm_fst = subst (sub, x, tm_fst)
      val tm_snd = subst (sub, x, tm_snd)
    in
      TERMtup (tm_fst, tm_snd)
    end // end of [TERMtup]
  | TERMfst (tm) => TERMfst (subst (sub, x, tm))
  | TERMsnd (tm) => TERMsnd (subst (sub, x, tm))
  | TERMlam (y, ty, tm_body) =>
      // [x] is shadowed inside the body if the argument is also named [x]
      if x = y then tm0 else TERMlam (y, ty, subst (sub, x, tm_body))
    // end of [TERMlam]
  | TERMapp (tm_fun, tm_arg) => let
      val tm_fun = subst (sub, x, tm_fun)
      val tm_arg = subst (sub, x, tm_arg)
    in
      TERMapp (tm_fun, tm_arg)
    end // end of [TERMapp]
  | TERMif (tm_cond, tm_then, tm_else) => let
      val tm_cond = subst (sub, x, tm_cond)
      val tm_then = subst (sub, x, tm_then)
      val tm_else = subst (sub, x, tm_else)
    in
      TERMif (tm_cond, tm_then, tm_else)
    end // end of [TERMif]
  | TERMfix (f, y, ty_arg, tm_body, ty_res) =>
      // both the function name and the argument name are bound in the body
      if x = f then tm0
      else if x = y then tm0
      else TERMfix (f, y, ty_arg, subst (sub, x, tm_body), ty_res)
    // end of [TERMfix]
// end of [subst]

// raised by [eval_cst] if the divisor given to the constant "/" is zero
exception EVAL_DivisionByZero of ()

//
// [eval_cst_int2 (tms)] extracts the two integers from an argument
// list of the form (TERMint i1, TERMint i2); the exception
// [EVAL_Illtyped] is raised if [tms] is of any other form.
//
fun eval_cst_int2 (tms: termlst): @(int, int) =
  case+ tms of
  | TERMint i1 :: TERMint i2 :: nil () => @(i1, i2)
  | _ => $raise EVAL_Illtyped ()
// end of [eval_cst_int2]

//
// [eval_cst (cst, tms)] applies the built-in constant named [cst] to
// the arguments [tms], which are expected to be evaluated already.
// Each supported constant takes two integer literals: the arithmetic
// operators return a TERMint and the comparisons return a TERMbool.
//
// Exceptions:
//   EVAL_Illtyped       : the arguments are not exactly two integer literals
//   EVAL_DivisionByZero : the divisor given to "/" is zero
//
// If [cst] is not a supported constant, then a message is written to
// the standard error and the program exits with the status 1.
//
fun eval_cst (cst: string, tms: termlst): term =
  case+ cst of
  | "+" => let
      val @(i1, i2) = eval_cst_int2 (tms) in TERMint (i1+i2)
    end
  | "-" => let
      val @(i1, i2) = eval_cst_int2 (tms) in TERMint (i1-i2)
    end
  | "*" => let
      val @(i1, i2) = eval_cst_int2 (tms) in TERMint (i1*i2)
    end
  | "/" => let
      val @(i1, i2) = eval_cst_int2 (tms)
    in
      // a zero divisor is reported as an exception that can be
      // handled instead of being passed to the integer division
      if i2 = 0 then $raise EVAL_DivisionByZero () else TERMint (i1/i2)
    end
  | ">" => let
      val @(i1, i2) = eval_cst_int2 (tms) in TERMbool (i1 > i2)
    end
  | ">=" => let
      val @(i1, i2) = eval_cst_int2 (tms) in TERMbool (i1 >= i2)
    end
  | "<" => let
      val @(i1, i2) = eval_cst_int2 (tms) in TERMbool (i1 < i2)
    end
  | "<=" => let
      val @(i1, i2) = eval_cst_int2 (tms) in TERMbool (i1 <= i2)
    end
  | "=" => let
      val @(i1, i2) = eval_cst_int2 (tms) in TERMbool (i1 = i2)
    end
  | "<>" => let
      val @(i1, i2) = eval_cst_int2 (tms) in TERMbool (i1 <> i2)
    end
  | _ => begin
      prerrf ("The constant [%s] is not supported.\n", @(cst));
      exit (1) ;
    end // end of [_]
// end of [eval_cst]

//
// The evaluator: [eval (tm0)] evaluates [tm0] by substitution.
// Literals and functions (TERMlam) evaluate to themselves. For an
// application, the function part is evaluated first; the argument is
// then evaluated before it is substituted into the body of the function.
//
// Exceptions:
//   EVAL_UnboundVariable : a variable is reached (it was never substituted)
//   EVAL_Illtyped        : a value is not of the form that is needed
// In addition, [eval_cst] is called for evaluating constants.
//
implement eval (tm0) = case+ tm0 of
  | TERMbool _ => tm0
  | TERMint _ => tm0
  | TERMvar _ => $raise EVAL_UnboundVariable ()
  | TERMcst (cst, tms_arg) => let
      // all of the arguments are evaluated before the constant is applied
      val tms_arg = list0_map_fun (tms_arg, eval)
    in
      eval_cst (cst, tms_arg)
    end // end of [TERMcst]  
  | TERMtup (tm1, tm2) => TERMtup (eval tm1, eval tm2)
  | TERMfst (tm) => let
      val tm = eval (tm) in
      case+ tm of
      | TERMtup (tm1, _) => tm1
      | _ => $raise EVAL_Illtyped ()
    end // end of [TERMfst]
  | TERMsnd (tm) => let
      val tm = eval (tm) in
      case+ tm of
      | TERMtup (_, tm2) => tm2
      | _ => $raise EVAL_Illtyped ()
    end // end of [TERMsnd]
  | TERMlam _ => tm0
  | TERMapp (tm_fun, tm_arg) => let
      val tm_fun = eval tm_fun in
      case+ tm_fun of
      | TERMlam (x, _(*ty*), tm_body) => let
          val tm_arg = eval tm_arg in eval (subst (tm_arg, x, tm_body))
        end // end of [TERMlam]
      | _ => $raise EVAL_Illtyped ()
    end (* end of [TERMapp] *)
  | TERMif (tm1, tm2, tm3) => let
      val tm1 = eval tm1 in
      // only the branch that is selected by the condition is evaluated
      case+ tm1 of
      | TERMbool b => if b then eval tm2 else eval tm3
      | _ => $raise EVAL_Illtyped ()
    end // end of [TERMif]   
  | TERMfix (f, x, ty_arg, tm_body, _(*res*)) =>
      // A recursive function is unfolded once into a plain function in
      // which [f] stands for the recursive function itself. If [f] and
      // [x] are the same name, then each occurrence of this name in the
      // body refers to the argument (the type checker adds [x] to the
      // context after [f]), and so nothing is to be substituted for [f].
      if f = x then TERMlam (x, ty_arg, tm_body)
      else TERMlam (x, ty_arg, subst (tm0, f, tm_body))
    // end of [TERMfix]
// end of [eval]

//
// NOTE(review): this section repeats the examples for testing the
// type checker that already appear earlier in this file (before the
// evaluator): the same terms are rebuilt and the same output is
// produced a second time. It is probably a leftover of copy-and-paste;
// please confirm before removing it, as doing so changes what the
// program prints.
//

// id = lam x. x
val id_tm = TERMlam ("x", ty_int, TERMvar "x")
val id_ty = ofType (list0_nil, id_tm)
val () = (fprint_ty (stdout_ref, id_ty); print_newline ())

// double = lam x. x + x
val double_tm =
  TERMlam ("x", ty_int, TERMcst ("+", x :: x :: nil ())) where {
  val x = TERMvar "x"
} // end of [val]
val double_ty = ofType (list0_nil, double_tm)
val () = (fprint_ty (stdout_ref, double_ty); print_newline ())

// K = lam x. lam y. x 
val K_tm = TERMlam ("x", ty_int, tm_inner) where {
  val tm_inner = TERMlam ("y", ty_int, TERMvar "x")
} // end of [val]
val K_ty = ofType (list0_nil, K_tm)
val () = (fprint_ty (stdout_ref, K_ty); print_newline ())

// S = lam x. lam y. lam z. x (z) (y (z))
val S_tm = TERMlam (
  "x", ty_x, TERMlam ("y", ty_y, TERMlam ("z", ty_int, tm_body))
) where {
  val ty_y = TYfun (ty_int, ty_int) // y : int -> int
  val ty_x = TYfun (ty_int, ty_y) // x : int -> (int -> int)
  val x = TERMvar "x" and y = TERMvar "y" and z = TERMvar "z"
  val tm_body = TERMapp (TERMapp (x, z), TERMapp (y, z))
} // end of [val]
val S_ty = ofType (list0_nil, S_tm)
val () = (fprint_ty (stdout_ref, S_ty); print_newline ())

// applying [id_tm] to a boolean: the type checker is expected to reject it
val tm_illtyped = TERMapp (id_tm, TERMbool true)
val _ = try
  let val _(*ty*) = ofType (list0_nil, tm_illtyped) in () end
with
  | ~IllTyped () => begin
      prerr "Illtyped"; prerr_newline ()
    end // end of [IllTyped]
// end of [val]

// applying [double_tm] to a boolean: the type checker is expected to reject it
val tm_illtyped = TERMapp (double_tm, TERMbool true)
val _ = try
  let val _(*ty*) = ofType (list0_nil, tm_illtyped) in () end
with
  | ~IllTyped () => begin
      prerr "Illtyped"; prerr_newline ()
    end // end of [IllTyped]
// end of [val]

(* ****** ****** *)

//
// the factorial function as a term in STFPL:
// fact = fix f(x: int): int. if x > 0 then x * f (x - 1) else 1
//
val fact_tm = TERMfix ("f", "x", ty_int, tm_body, ty_int) where {
  val x = TERMvar "x"
  val one = TERMint 1
  // f (x - 1)
  val tm_rec = TERMapp (TERMvar "f", TERMcst ("-", x :: one :: nil ()))
  val tm_body = TERMif (
    TERMcst (">", x :: TERMint 0 :: nil ()) // cond
  , TERMcst ("*", x :: tm_rec :: nil ()) // then
  , one // else
  ) // end of [val]
} (* end of [val] *)

// type-checking [fact_tm] and then printing out its type
val fact_ty = ofType (list0_nil, fact_tm)
val () = begin
  print "fact_ty = "; fprint_ty (stdout_ref, fact_ty); print_newline ()
end // end of [val]

(* ****** ****** *)

//
// Two simple tests for the evaluator. Each result is expected to be an
// integer literal, which is printed out; if it is of any other form, the
// name of the test is written to the standard error and the program exits
// with the status 1 (as is done for the other tests in this file).
//

// 1 + 1
val plus_1_1 = TERMcst ("+", TERMint 1 :: TERMint 1 :: nil)
val () = case+ eval plus_1_1 of
  | TERMint n => printf ("eval (plus_1_1) = %i\n", @(n))
  | _ => (prerr "plus_1_1\n"; exit (1))
// end of [val]

// double (10)
val double_10 = TERMapp (double_tm, TERMint 10) 
val () = case+ eval double_10 of
  | TERMint n => printf ("eval (double_10) = %i\n", @(n))
  | _ => (prerr "double_10\n"; exit (1))
// end of [val]

(* ****** ****** *)

// fact (5): the result is expected to be an integer literal; if it is
// not, the program exits with the status 1 after reporting the test name
val fact_5 = TERMapp (fact_tm, TERMint 5)
val () = let
  val tm_res = eval (fact_5)
in
  case+ tm_res of
  | TERMint n => printf ("eval (fact_5) = %i\n", @(n))
  | _ => begin
      prerr "fact_5\n"; exit (1)
    end // end of [_]
end // end of [val]

(* ****** ****** *)

//
// f91 = fix f(x: int): int. if x >= 101 then x - 10 else f (f (x + 11))
//
val f91_tm = TERMfix ("f", "x", ty_int, tm_body, ty_int) where {
  val x = TERMvar "x"
  val f = TERMvar "f"
  // f (f (x + 11))
  val tm_inner = TERMapp (f, TERMcst ("+", x :: TERMint 11 :: nil ()))
  val tm_body = TERMif (
    TERMcst (">=", x :: TERMint 101 :: nil ()) // cond
  , TERMcst ("-", x :: TERMint 10 :: nil ()) // then
  , TERMapp (f, tm_inner) // else
  ) // end of [val]
} // end of [val f91_tm]

// [f91_tm] is type-checked here; its type is not printed out
val f91_ty = ofType (list0_nil, f91_tm)

// f91 (30): the result is expected to be an integer literal; if it is
// not, the program exits with the status 1 after reporting the test name
val f91_30 = TERMapp (f91_tm, TERMint 30)
val () = let
  val tm_res = eval (f91_30)
in
  case+ tm_res of
  | TERMint n => printf ("eval (f91_30) = %i\n", @(n))
  | _ => begin
      prerr "f91_30\n"; exit (1)
    end // end of [_]
end // end of [val]

(* ****** ****** *)
  
// [main] does nothing by itself: all of the tests above are
// top-level declarations in this file
implement main () = ()

(* ****** ****** *)

(* end of [STFPL.dats] *)