//
// Course: BU CAS CS 520
// Instructor: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
//

(* ****** ****** *)

//
// How to compile:
//   atscc -o numeral numeral.dats
//
// How to test:
//   ./numeral
//

(* ****** ****** *)

// [tid a] is the type of (closure) functions from [a] to [a]
typedef tid (a: type) = a -<cloref> a
// a Church numeral: for every type [a], it turns a function on [a]
// into another function on [a]; the numeral for n is expected to map
// [f] to [f] iterated n times (see [numeral_make])
typedef numeral = {a:type} tid (tid a)

(* ****** ****** *)

//
// successor: the result applies [f] one more time than [n_] does
//
fn succ_numeral (n_: numeral):<> numeral =
  lam {a:type} (f: tid a) =<cloref>
    lam (x: a): a =<cloref> f (n_ {a} f x)
  // end of [lam]
// end of [succ_numeral]

(* ****** ****** *)

(*
** this one is rather complicated!
**
** The predecessor is computed with a pairing trick: starting from the
** pair (x, x), each step maps a pair (u, v) to the pair (f u, u). If
** [n_] performs n steps, the final pair is (f^n x, f^(n-1) x) for
** n > 0 and is still (x, x) for n = 0. Selecting the second component
** thus applies [f] one time fewer than [n_] would (and yields [x]
** itself when [n_] is zero, so the predecessor of zero is zero).
*)
fn pred_numeral (n_: numeral):<> numeral = let
  // a pair of two [a]-values, encoded as a closure that hands both
  // components to a given selector function and returns its result
  typedef pair (a:type) = {t:type} ((a, a) -<> t) -<cloref> t
  // builds the encoded pair of [x] and [y]
  fn pair {a:type} (x: a, y: a):<> pair a = lam z => z (x, y)
  // selectors: [fst] picks the first component and [snd] the second
  val fst = lam {a:type} (x: a, y: a): a =<> x
  val snd = lam {a:type} (x: a, y: a): a =<> y
  // one step of the trick: (u, v) |-> (f u, u); the old second
  // component [v] is discarded
  val pairf = lam {a:type}
    (f: tid a) (xy: pair a): pair a =<cloref> let val x = xy (fst {a}) in pair (f x, x) end
  // end of [pairf]
in
  // run [n_] at the type [pair a]: iterate [pairf f] starting from
  // (x, x), then project out the second component
  lam {a:type} (f: tid a) =<cloref>
    lam (x: a): a =<cloref> n_ {pair a} (pairf {a} f) (pair {a} (x, x)) (snd {a})
  // end of [lam]
end // end of [pred_numeral]

(* ****** ****** *)

//
// addition: first iterate [f] on [x] as [n_] prescribes, and then
// keep iterating [f] on that result as [m_] prescribes
//
fn add_numeral_numeral (m_: numeral, n_: numeral):<> numeral =
  lam {a:type} (f: tid a) =<cloref>
    lam (x: a): a =<cloref> m_ {a} f (n_ {a} f x)
  // end of [lam]
// end of [add_numeral_numeral]

//
// subtraction: [n_] is instantiated at the type [numeral] itself and
// used to apply [pred_numeral] to [m_] as many times as [n_] indicates.
// As the predecessor of zero is zero, the result is zero whenever [n_]
// is not less than [m_] (truncated subtraction).
//
fn sub_numeral_numeral
  (m_: numeral, n_: numeral):<> numeral = n_ {numeral} (lam x =<cloref> pred_numeral x) m_
// end of [sub_numeral_numeral]

(* ****** ****** *)

//
// multiplication: [m_] iterates the function that is itself [f]
// iterated as many times as [n_] prescribes
//
fn mul_numeral_numeral (m_: numeral, n_: numeral):<> numeral =
  lam {a:type} (f: tid a): tid a =<cloref> let
    val nf = n_ {a} f // [f] iterated according to [n_]
  in
    m_ {a} nf
  end // end of [lam]
// end of [mul_numeral_numeral]

(* ****** ****** *)

//
// exponentiation ([bas_] raised to the power [exp_]): [exp_] is
// instantiated at the type [tid a], so that what it iterates is the
// numeral [bas_] itself (viewed as a function on [tid a]); the iterated
// function is then applied to [f]
//
fn pow_numeral_numeral
  (bas_: numeral, exp_: numeral):<> numeral =
  lam {a:type} (f: tid a) =<cloref> exp_ {tid a} (bas_ {a}) (f) 
// end of [pow_numeral_numeral]

(* ****** ****** *)

//
// turns a natural number [n] into the corresponding numeral, that is,
// the one iterating a given function [n] times; the termination metric
// .<n>. decreases at each recursive call
//
fun numeral_make {n:nat} .<n>. (n: int n): numeral =
  if n > 0 then let
    val pred_ = numeral_make (n-1) // the numeral for n-1
  in
    // one more application of [f] on top of what [pred_] does
    lam {a:type} (f: tid a) =<cloref>
      lam (x: a): a =<cloref> f (pred_ {a} f x)
  end else
    // zero: [f] is not applied at all
    lam {a:type} (f: tid a) =<cloref> lam (x: a): a =<cloref> x
  // end of [if]
// end of [numeral_make]

(* ****** ****** *)

// from this point on, the name [int] refers to [intptr]
// NOTE(review): presumably this is needed because [numeral] quantifies
// over the sort [type], which the ordinary [int] does not belong to --
// please confirm
typedef int = intptr

// the constants zero and one as [intptr] values; they serve as the
// start value and the increment in [print_numeral]
val _0 = intptr_of_int 0
val _1 = intptr_of_int 1

//
// prints a numeral as a number: the numeral is run at the type [int]
// (which is [intptr] here) on the add-one function, starting from zero,
// and the resulting value is printed
//
fn print_numeral (n_: numeral): void = let
  val incr = lam (x: int): int =<cloref> x + _1
in
  print_intptr (n_ {int} incr (_0))
end // end of [print_numeral]

(* ****** ****** *)

//
// A small demonstration: each operation is applied to the numeral for
// N and the result is printed on a line of its own.
//
#define N 7

// the numeral for N itself; expected output: N = 7
val n_ = numeral_make (N)
val () = print "N = "
val () = print_numeral (n_)
val () = print_newline ()

// successor; expected output: 7 + 1 = 8
val n1_ = succ_numeral (n_)
val () = printf ("%i + 1 = ", @(N))
val () = print_numeral (n1_)
val () = print_newline ()

// predecessor; expected output: 7 - 1 = 6
val n2_ = pred_numeral (n_)
val () = printf ("%i - 1 = ", @(N))
val () = print_numeral (n2_)
val () = print_newline ()

// addition; expected output: 7 + 7 = 14
val n3_ = add_numeral_numeral (n_, n_)
val () = printf ("%i + %i = ", @(N, N))
val () = print_numeral (n3_)
val () = print_newline ()

// subtraction; expected output: 7 - 7 = 0
val n4_ = sub_numeral_numeral (n_, n_)
val () = printf ("%i - %i = ", @(N, N))
val () = print_numeral (n4_)
val () = print_newline ()

// multiplication; expected output: 7 * 7 = 49
val n5_ = mul_numeral_numeral (n_, n_)
val () = printf ("%i * %i = ", @(N, N))
val () = print_numeral (n5_)
val () = print_newline ()

// exponentiation; expected output: 7 ^ 7 = 823543
// (note that [print_numeral] counts the result up one step at a time)
val n6_ = pow_numeral_numeral (n_, n_)
val () = printf ("%i ^ %i = ", @(N, N))
val () = print_numeral (n6_)
val () = print_newline ()

(* ****** ****** *)

// [main] itself does nothing: all of the output is produced by the
// top-level [val] declarations above
implement main () = ()

(* ****** ****** *)

(* end of [numeral.dats] *)