// //////////////////////////////////////////////////////////////////////////
// Andrei Lapets // with a few minor modifications by Hongwei Xi
//
// CS 520, Assignment #2, Exercise #7
//
// Fall 2008
//
// //////////////////////////////////////////////////////////////////////////

//
// How to compile:
//   atscc -o test alnf_subst.dats
// How to test:
//   ./test
//

staload _(*anonymous*) = "prelude/DATS/list0.dats"

// Shorthand for the list0 constructors: [nil] stands for list0_nil and
// [::] stands for list0_cons (used infix in the patterns and expressions
// below).
#define nil list0_nil
#define :: list0_cons

// Lambda terms with named variables.
datatype term =
  | TmVar of string          // a variable, identified by its name
  | TmLam of (string, term)  // an abstraction: bound variable name and body
  | TmApp of (term, term)    // an application: function term and argument

// Lambda terms in which bound variables are replaced by integer indices
// and abstractions no longer carry a variable name.
datatype term1 =
  | TmVar1 of string          // a variable that is kept by name
  | TmInd1 of int             // a variable replaced by an integer index
  | TmLam1 of term1           // an abstraction (nameless): just the body
  | TmApp1 of (term1, term1)  // an application: function term and argument

// Prints a named term to standard output.
//   - a variable is printed as its name;
//   - an abstraction is printed as "(\<name>.<body>)";
//   - an application is printed as the two sub-terms separated by a single
//     space, with no extra parentheses.
fun printTerm (t:term): void_t0ype = case+ t of
 | TmVar id => print id
 | TmLam (param, body) => begin
     print "("; print "\\"; print param; print ".";
     printTerm body;
     print ")"
   end
 | TmApp (rator, rand) => begin
     printTerm rator; print " "; printTerm rand
   end

// Prints an index-based term to standard output.
//   - a named variable is printed as its name;
//   - an index is printed as a decimal integer;
//   - an abstraction is printed as "(\<body>)" since it has no binder name;
//   - an application is printed as the two sub-terms separated by a single
//     space, with no extra parentheses.
fun printTerm1 (t:term1): void_t0ype = case+ t of
 | TmVar1 id => print id
 | TmInd1 n => printf ("%i", @(n))
 | TmLam1 body => begin
     print "("; print "\\";
     printTerm1 body;
     print ")"
   end
 | TmApp1 (rator, rand) => begin
     printTerm1 rator; print " "; printTerm1 rand
   end

// A variable name is a plain string.
typedef name = string
// A collection of names stored as a list0; the functions [found], [insert]
// and [unite] below treat it as a set.
typedef set = list0 name

// Looks up the name [v] in the list [l], counting positions from [i].
// Returns [TmInd1 k] where k is [i] plus the number of entries skipped
// before the first match; if [v] does not occur in [l], the variable is
// kept by name as [TmVar1 v].
fun index (l: set, v: name, i: int): term1 = case+ l of
 | nil () => TmVar1 v
 | hd :: rest =>
     if v = hd then TmInd1 i else index (rest, v, i+1)

// Converts a named term into its index-based form.
//   g : names of the enclosing binders, innermost first
//   t : the term to convert
// A variable is resolved against [g] with positions counted from 1, so a
// variable bound by the innermost enclosing abstraction becomes index 1;
// variables not found in [g] stay named.
fun nf_alpha (g: set, t: term): term1 = case+ t of
 | TmVar x => index (g, x, 1)
 | TmLam (x, body) => let
     // the new binder becomes the innermost entry of the context
     val g1 = x :: g
   in
     TmLam1 (nf_alpha (g1, body))
   end
 | TmApp (rator, rand) => let
     val rator1 = nf_alpha (g, rator)
     val rand1 = nf_alpha (g, rand)
   in
     TmApp1 (rator1, rand1)
   end

// Membership test: returns true if the name [v] occurs in the list [l],
// and false otherwise (in particular for the empty list).
fun found (v: name, l: set): bool = case+ l of
 | nil () => false
 | hd :: rest =>
     if v = hd then true else found (v, rest)

// Adds the name [v] to the front of [l] unless it is already present, in
// which case [l] is returned unchanged.
fun insert (v: name, l: set): list0 string = let
  val present = found (v, l)
in
  if present then l else v :: l
end

// Union of two name lists: every element of [l] is inserted (without
// duplication) into [l']. When [l] is empty the result is [l'] itself.
fun unite (l: set, l': set): set = case+ l of
 | hd :: rest => let
     // merge the tail first, then add the head on top of that result
     val merged = unite (rest, l')
   in
     insert (hd, merged)
   end
 | nil () => l'

// Collects every variable name appearing in [t]: names at variable
// occurrences as well as the names introduced by abstractions.
fun vars (t: term): set = case+ t of
 | TmVar x => x :: nil
 | TmLam (x, body) => let
     val inner = vars body
   in
     insert (x, inner) // the binder name itself is included
   end
 | TmApp (rator, rand) => let
     val vs1 = vars rator
     val vs2 = vars rand
   in
     unite (vs1, vs2)
   end

// Returns a name based on [v] that does not occur in [l]: [v] itself if it
// is not in [l], otherwise [v] with as many prime characters (') appended
// as are needed to leave [l].
fun fresh (v: name, l: set): name = let
  val taken = found (v, l)
in
  if taken then fresh (v+"\'", l) else v
end

// Renames the variable [x] to [y] inside [t].
//   - a variable occurrence named [x] becomes [y]; other variables are kept;
//   - an abstraction that binds [x] itself is returned unchanged, because
//     the occurrences of [x] in its body refer to that inner binder;
//   - any other abstraction keeps its binder and has its body renamed;
//   - an application is renamed in both sub-terms.
// The caller is responsible for choosing [y] so that it does not clash
// with names already used in [t].
fun substVar (t: term, x: name, y: name): term = case+ t of
 | TmVar v       => TmVar (if v = x then y else v)
   // The body is bound to [body] (not [t]) so that [t] still denotes the
   // whole abstraction, which must be returned intact when it rebinds [x].
 | TmLam (v,body) => if v = x then t else TmLam (v,substVar(body,x,y))
 | TmApp (t1,t2) => TmApp (substVar (t1, x, y), substVar (t2, x, y))

// Substitutes the term [s] for the variable [x] in [t].
//   - a variable equal to [x] is replaced by [s]; other variables are kept;
//   - for an abstraction, the bound variable is first renamed to a fresh
//     name and the substitution is then carried out in the renamed body;
//   - an application is substituted in both sub-terms.
// Every abstraction met on the way is renamed, so binder names in the
// result generally differ from those in [t].
fun subst (t: term, x: name, s: term): term = case+ t of
 | TmVar v       => if x = v then s else t
 | TmLam (v,body) => let
     // The fresh binder name must avoid every name in the body and in [s],
     // and also [x] itself: if the fresh name were equal to [x], the
     // renamed bound occurrences would be replaced by [s] in the recursive
     // call below.
     val avoid = insert (x, unite (vars body, vars s))
     val v' = fresh ("v", avoid)
   in
     TmLam (v',subst(substVar(body,v,v'),x,s))
   end
 | TmApp (t1,t2) => TmApp (subst (t1,x,s), subst (t2,x,s))

// Test driver: builds two sample terms, then prints
//   1. an empty line,
//   2. the result of substituting [s] for "z" in [t],
//   3. the index-based form of [t] computed in an empty context.
implement main () = let
  // t = (\y.(\x.x y)) (\x.z)
  val t = TmApp (
    TmLam ("y", TmLam ("x", TmApp (TmVar "x", TmVar "y"))),
    TmLam ("x", TmVar "z")
  )
  // s = (\y.(\x.x y)) (\x.x)
  val s = TmApp (
    TmLam ("y", TmLam ("x", TmApp (TmVar "x", TmVar "y"))),
    TmLam ("x", TmVar "x")
  )
  val t_subst = subst (t, "z", s)
  val t_alnf = nf_alpha (nil, t)
in
  print_newline ();
  printTerm t_subst; print_newline ();
  printTerm1 t_alnf; print_newline ()
end // end of [main]

(* ****** ****** *)

(* end of [alnf_subst.dats] *)