//
// Author: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
// Time: June 2, 2009
//

(* ****** ****** *)

#include "BUCASCS320.hats"

(* ****** ****** *)

//
// PART I: function vs. closure
//

(* ****** ****** *)

// [add] is an ordinary top-level function (that is, it is not a closure)
fn add (x: int, y: int): int = let
  val sum = x + y
in
  sum
end // end of [add]

// [add_cloref] is given the closure type [cloref1]; since its body refers to
// nothing but its own arguments, the environment it carries is empty
fn add_cloref (x: int, y: int):<cloref1> int = let
  val sum = x + y
in
  sum
end // end of [add_cloref]

(* ****** ****** *)

// [f] is a function; it returns x + y + 1 by applying its local closure to 1
fn f (x: int, y: int): int = g (1) where {
  // [g] is a closure (carrying a nonempty environment): its body refers
  // to the arguments [x] and [y] of the enclosing function [f]
  fn g (z: int):<cloref1> int = x + y + z
} // end of [f]

(* ****** ****** *)

// [find_zero] takes a closure [f] as its argument and returns the first
// integer i, in the order 0, 1, ~1, 2, ~2, ..., for which f (i) = 0 holds;
// the search stops only when such an integer is found
fn find_zero (f: int -<cloref1> int): int = search (f, 0) where {
  // [search] tests the candidate [cand]; if it is not a zero of [f], the
  // next candidate is ~cand + 1 when [cand] is not positive and ~cand otherwise
  fun search (f: int -<cloref1> int, cand: int): int =
    if f (cand) = 0 then cand
    else let
      val next: int = if cand <= 0 then ~cand + 1 else ~cand
    in
      search (f, next) // a tail-call
    end // end of [if]
  // end of [search]
} // end of [find_zero]

(* ****** ****** *)

// [f1] is a function; it computes the product (x - 10) * (x + 11)
fn f1 (x: int): int = let
  val lhs = x - 10
  val rhs = x + 11
in
  lhs * rhs
end // end of [f1]
// [f1_cloref] computes the same product (x - 10) * (x + 11), but it is
// given the closure type [cloref1]
fn f1_cloref (x: int):<cloref1> int = let
  val lhs = x - 10
  val rhs = x + 11
in
  lhs * rhs
end // end of [f1_cloref]

(*

// this yields a type error
val rt1 = find_zero (f1) // as [f1] is not a closure
val () = printf ("f1 (%i) = 0\n", @(rt1))

*)

// Three ways of supplying [find_zero] with a closure; each result is bound
// to [rt1] (shadowing the previous binding) and then printed.

// (1) a closure declared with [fun ... :<cloref1>]
val rt1 = find_zero (f1_cloref)
val () = printf ("f1_cloref (%i) = 0\n", @(rt1))

// (2) an anonymous function with the same body expression as [f1_cloref]
// NOTE(review): the message still says "f1_cloref" though a [lam] is passed
// here -- confirm that this label is intended
val rt1 = find_zero (lam (x: int): int => (x - 10) * (x + 11))  
val () = printf ("f1_cloref (%i) = 0\n", @(rt1))

// (3) an anonymous function that wraps a call to the function [f1]
// NOTE(review): same "f1_cloref" label as above -- confirm
val rt1 = find_zero (lam (x:int): int => f1 (x))
val () = printf ("f1_cloref (%i) = 0\n", @(rt1))

(* ****** ****** *)

//
// PART II: A programming example: Eratosthenes' sieve
//

(* ****** ****** *)

// Shorthands for the constructors of [list0]:
// [::] and [cons] both stand for [list0_cons]
#define :: list0_cons
#define cons list0_cons
// [nil] stands for [list0_nil]
#define nil list0_nil

(* ****** ****** *)

// [list0_foldl] folds [f] over [xs] from left to right, starting from [init]:
// for xs = x1, ..., xn it returns f (... f (f (init, x1), x2) ..., xn),
// and it returns [init] if [xs] is empty
extern fun{a,b:t@ype}
  list0_foldl (xs: list0 a, f: (b, a) -<cloref1> b, init: b): b
// end of [extern fun]

implement{a,b} list0_foldl (xs, f, init) = let
  // [loop] is tail-recursive: [acc] holds the result of folding the
  // elements that have already been visited
  fun loop (rest: list0 a, f: (b, a) -<cloref1> b, acc: b): b =
    case+ rest of
    | list0_nil () => acc
    | list0_cons (y, ys) => loop (ys, f, f (acc, y)) // a tail-call
  // end of [loop]
in
  loop (xs, f, init)
end // end of [list0_foldl]

(* ****** ****** *)

// [list0_filter] returns the list of those elements of [xs] that satisfy
// [pred], in the order in which they occur in [xs]
extern fun{a:t@ype}
  list0_filter (xs: list0 a, pred: a -<cloref1> bool): list0 a
// end of [extern fun]

implement{a} list0_filter (xs, pred) = let
  // [keep] is a closure, as it refers to [pred]; [pred] is applied to the
  // head of the list before the tail is processed
  fun keep (rest: list0 a):<cloref1> list0 a =
    case+ rest of
    | list0_nil () => list0_nil ()
    | list0_cons (y, ys) =>
        if pred (y) then list0_cons (y, keep (ys)) else keep (ys)
  // end of [keep]
in
  keep (xs)
end // end of [list0_filter]

(* ****** ****** *)

// Eratosthenes' sieve: the head of the list is kept, every element of the
// tail that is divisible by the head is removed, and what remains of the
// tail is then sieved in the same manner
fun sieve (xs: list0 int): list0 int =
  case+ xs of
  | nil () => nil ()
  | p :: rest => let
      // the function passed to [list0_filter] is a closure, as it refers to [p]
      val survivors = list0_filter (rest, lam (k: int) => (k mod p <> 0))
    in
      p :: sieve (survivors)
    end // end of [::]
// end of [sieve]

(* ****** ****** *)

// [list0_upto (m, n)] returns the list m, m+1, ..., n (in ascending order);
// the list returned is empty if m > n holds
fun list0_upto (m: int, n: int): list0 int = let
  // [loop] is tail-recursive: [hi] runs from n down to [lo], and each value
  // of [hi] is put in front of the list [acc] built so far
  fun loop (lo: int, hi: int, acc: list0 int): list0 int =
    if hi < lo then acc else loop (lo, hi - 1, list0_cons (hi, acc))
  // end of [loop]
in
  loop (m, n, list0_nil ())
end // end of [list0_upto]

(* ****** ****** *)

// [main] prints the current GC chunk limit, passes that value shifted left
// by 6 to the setter, sieves the integers from 2 to 100000, and then prints
// each prime on a line of its own followed by the number of primes
implement main () = let
  val nchunk = gc_chunk_count_limit_max_get ()
  val () = print "nchunk = "
  val () = print nchunk
  val () = print_newline ()
  val () = gc_chunk_count_limit_max_set (nchunk << 6) // enough?
  val primes = sieve (list0_upto (2, 100000))
  // the fold is used for its effect only: the accumulator is returned
  // unchanged at each step and the final result is discarded
  val _(*int*) = list0_foldl<int,int> (
    primes
  , lam (acc: int, p: int): int => (print p; print_newline (); acc)
  , 0(*init*)
  ) // end of [list0_foldl]
  val nprimes = list0_length<int> (primes)
  val () = printf ("# of primes = %i\n", @(nprimes))
in
  // nothing
end (* end of [main] *) 

(* ****** ****** *)

(* end of [code-2009-06-02.dats] *)