//
// Course: BU CAS CS 520
// Instructor: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
//

(* ****** ****** *)

// [N] is the argument passed to the factorial demos ([fact] and [kfact]).
#define N 10

// [fact] computes x! by direct (non-tail) recursion.
// Any non-positive argument yields 1.
fun fact (x: int): int =
  if x <= 0 then 1 else x * fact (x - 1)
// end of [fact]

// Demo: print N! computed by the direct-style [fact].
val () = let
  val res = fact (N)
in
  printf ("fact(%i) = %i\n", @(N, res))
end // end of [val]

// [kfact] is [fact] in continuation-passing style: instead of returning
// x!, it hands x! to the continuation [k]. Each recursive step wraps [k]
// in a new closure that multiplies the eventual result by the current [x].
// Any non-positive [x] calls [k] with 1.
fun kfact (x: int, k: int -<cloref> int): int =
  if x <= 0 then k (1)
  else kfact (x - 1, lam res => k (x * res))
// end of [kfact]

// [K0] is the identity continuation: it returns its argument unchanged.
// It is the initial continuation for the CPS demos ([kfact], [kack], [kfindz]).
val K0 = lam (res: int): int =<cloref> res

// Demo: print N! computed by the CPS [kfact], started with [K0].
val () = let
  val res = kfact (N, K0)
in
  printf ("kfact(%i, \\res => res) = %i\n", @(N, res))
end // end of [val]

(* ****** ****** *)

// [ack] is the two-argument Ackermann function:
//   ack (m, n) = n + 1                     when m <= 0
//   ack (m, n) = ack (m-1, 1)              when m > 0 and n <= 0
//   ack (m, n) = ack (m-1, ack (m, n-1))   when m > 0 and n > 0
fun ack (m: int, n: int): int =
  if m <= 0 then n + 1
  else if n <= 0 then ack (m - 1, 1)
  else ack (m - 1, ack (m, n - 1))
// end of [ack]

// Demo: print ack (3, 3) computed in direct style.
val () = let
  val res = ack (3, 3)
in
  printf ("ack(3, 3) = %i\n", @(res))
end // end of [val]

// [kack] is [ack] in continuation-passing style: the value ack (m, n)
// is passed to the continuation [k] rather than returned.
fun kack (m: int, n: int, k: int -<cloref1> int): int =
  if m <= 0 then k (n+1)
  else if n > 0 then
    // first compute ack (m, n-1); its result [res] then becomes the
    // second argument of ack (m-1, -), which continues with [k]
    kack (m, n-1, lam res => kack (m-1, res, k))
  else kack (m-1, 1, k)
// end of [kack]

// Demo: print ack (3, 3) computed by the CPS [kack], started with [K0].
val () = let
  val res = kack (3, 3, K0)
in
  printf ("kack(3, 3, \\res => res) = %i\n", @(res))
end // end of [val]

(* ****** ****** *)

// [findz] returns the first [i] in 0, 1, 2, ... for which [f (i)] is 0.
// Only non-negative arguments are tried, in increasing order. If [f] has
// no non-negative zero, the search never terminates.
fun findz (f: int -<fun1> int): int = loop (f, 0) where {
  // [loop] tests [f] at [i] and moves on to [i+1] until a zero is found.
  fun loop
    (f: int -<fun1> int, i: int): int =
    if f (i) <> 0 then loop (f, i+1) else i
  // end of [loop]
} // end of [findz]

// Demo: (x - 11) * (x + 10) has zeros at -10 and 11. Only non-negative
// values are searched, so the answer is 11.
val ans = findz (lam x => (x - 11) * (x + 10))
// The backslash is doubled so that a literal '\' is printed, as in the
// other demo lines. A lone "\x" would start a hex escape sequence.
val () = printf ("findz (\\x => (x - 11) * (x + 10)) = %i\n", @(ans))

(* ****** ****** *)

// [ans_t] is the abstract final-answer type of the CPS search [kfindz].
// It is declared with the size of [int], but is only identified with
// [int] where an [assume] reveals it (the [local] demo block).
abst@ype ans_t = int
// [cont a] is a continuation: a closure that takes a value of type [a]
// and produces the final answer of type [ans_t].
typedef cont (a:t@ype) = a -<cloref1> ans_t

// [kfindz] is [findz] in continuation-passing style. The function [f]
// itself takes a continuation: [f (i, c)] must pass f(i) to [c]. The
// search tries i = 0, 1, 2, ... and passes the first [i] whose value
// is 0 to [k]. If there is no such [i], the search never terminates.
fun kfindz (
    f: (int, cont int) -<fun1> ans_t
  , k: cont int
  ) : ans_t = let
  // [search] asks [g] for its value at [i]; the continuation either
  // finishes with [k] (value 0) or moves on to [i+1].
  fun search (
      g: (int, cont int) -<fun1> ans_t
    , i: int
    , k: cont int
    ) : ans_t =
    g (i, lam res => if res = 0 then k i else search (g, i+1, k))
  // end of [search]
in
  search (f, 0, k)
end // end of [kfindz]

local
// Reveal [ans_t] as [int] in this block only, so that the identity
// continuation [K0] (int -<cloref> int) can be used as the final
// continuation of [kfindz].
assume ans_t = int
in
// Demo: the first non-negative zero of (x - 11) * (x + 10) is 11.
val ans = kfindz (lam (x, k) => k ((x - 11) * (x + 10)), K0)
val () = printf ("kfindz (\\(x, k) => k((x - 11) * (x + 10)), \\res => res) = %i\n", @(ans))
end // end of [local]

(* ****** ****** *)

// All output comes from the top-level [val () = ...] bindings above,
// so [main] itself has nothing to do.
implement main () = ()

(* ****** ****** *)

(* end of [cont.dats] *)