(*
** Course: Concepts of Programming Languages (BU CAS CS 320)
** Semester: Summer I, 2009
** Instructor: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
*)

//
// Author: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
// Time: Tuesday, June 23, 2009
//

(* ****** ****** *)

#include "BUCASCS320.hats"

(* ****** ****** *)

// fact_main (x): the factorial x! computed by plain (non-tail) recursion.
// The static constraint {i >= 0} means callers can only pass a nonnegative
// [x], and the termination metric .<i>. decreases by one on each call.
fun fact_main {i:int | i >= 0} .<i>.
  (x: int i): int =
  if x <= 0 then 1 else x * fact_main (x - 1)
// end of [fact_main]

// fact (x) returns x! for any nonnegative [x] (with fact (0) = 1) and
// raises [Domain] when [x] is negative.
// NOTE(review): the result overflows for x >= 13, assuming a 32-bit [int].
extern fun fact (x: int): int

// int1_of_int (x: int): [i:int] int i
// Converting to an indexed integer lets the sign test below refine the
// index, which is what [fact_main]'s {i >= 0} constraint requires.

implement fact (x) = let
  val x = int1_of_int (x) in
  // Accept 0 as well: fact_main is valid for every x >= 0 and 0! = 1.
  if x >= 0 then fact_main (x) else $raise Domain ()
end // end of [fact]

// Smoke test: compute fact (10) and print it.
val () = let
  val res = fact (10)
in
  printf ("fact(10) = %i\n", @(res))
end // end of [val]

(* ****** ****** *)

(*

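// this specification is imprecise: it constrains neither [n] nor the range
// of the returned index (compare the refined [bsearch_cloref] below)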
extern fun{elt:t@ype} bsearch
  (f: int -> elt, n: int, v: elt): int 
  
*)

// [cmp elt]: a pure (-<fun>) comparison on [elt]. [bsearch_cloref] only
// looks at the sign of its result; callers in this file pass [compare].
typedef cmp (elt: t@ype) = (elt, elt) -<fun> int
// [natLt n]: an integer [i] statically known to satisfy 0 <= i < n.
typedef natLt (n: int) = [i:int | 0 <= i; i < n] int i

// bsearch_cloref (f, n, v, cmp): binary search over the indices 0 .. n-1,
// where [f] maps an index to its element and [cmp (v, f (i)) >= 0] is read
// as "v is at or after f (i)".  Assuming [f] is sorted ascending with
// respect to [cmp] (not checked here -- TODO confirm at call sites), the
// result is the largest index [i] with cmp (v, f (i)) >= 0, or ~1 if no
// index qualifies.  The [:<>] annotation declares the search effect-free.
extern fun{elt:t@ype} bsearch_cloref {n:nat}
  (f: natLt n -<cloref> elt, n: int n, v: elt, cmp: cmp elt)
  :<> [i:int | ~1 <= i; i < n] int i

implement{elt} bsearch_cloref {n}
  (f, n, v, cmp) = loop (0, n-1) where {
  typedef res_t = [i:int | ~1 <= i; i < n] int i
  // loop (lb, ub) searches the closed interval [lb, ub].  Given sorted [f],
  // every index below [lb] has cmp (v, f i) >= 0 and every index above
  // [ub] has cmp (v, f i) < 0.  The metric .<j-i+1>. (the interval size)
  // strictly decreases on each call, which proves termination.
  fun loop
    {i,j:int | 0 <= i; i <= j+1; j+1 <= n} .<j-i+1>.
    (lb: int i, ub: int j):<cloref> res_t =
    if lb <= ub then let
      // NOTE(review): [lb + ub] can overflow when [n] is close to the
      // maximum [int]; [lb + (ub - lb) / 2] avoids this if it type-checks.
      val m = (lb + ub) / 2
      val sgn = cmp (v, f (m))
    in
      if sgn >= 0 then loop (m+1, ub) else loop (lb, m-1) 
    end else ub // empty interval (lb = ub+1): [ub] is the last qualifying index
  // end of [loop]  
} // end of [bsearch_cloref]

// bsearch_arr (A, n, v, cmp): [bsearch_cloref] over an array [A] of size
// [n], using A[i] as the element at index [i].  It returns the same index
// as [bsearch_cloref]: the largest [i] with cmp (v, A[i]) >= 0, or ~1.
extern fun{elt:t@ype} bsearch_arr {n:nat}
  (A: array (elt, n), n: int n, v: elt, cmp: cmp elt)
  : [i:int | ~1 <= i; i < n] int i

// Reading A[i] carries a [ref] effect.  [$effmask_ref] masks that effect so
// the accessor fits the effect-free closure type [bsearch_cloref] expects.
implement{elt} bsearch_arr {n} (A, n, v, cmp) =
  bsearch_cloref<elt> {n} (lam i => $effmask_ref (A[i]), n, v, cmp)
// end of [bsearch_arr]

// bsearch_str (s, n, v): binary search for character [v] in a string [s] of
// length [n], ordering characters with [compare].  Assuming [s] is sorted
// ascending (TODO: confirm at call sites), it returns the largest index [i]
// with compare (v, s[i]) >= 0, or ~1 if no such index exists.
extern fun bsearch_str {n:nat}
  (s: string n, n: int n, v: char)
  : [i:int | ~1 <= i; i < n] int i

implement bsearch_str {n} (s, n, v) =
  bsearch_cloref<char> {n} (lam i => s[i], n, v, lam (c1, c2) => compare (c1, c2))
// end of [bsearch_str]

(* ****** ****** *)

// Size of the search range for [intsqrt]: the candidates are 0 .. MAX2-1.
// 46340 is the largest [r] whose square fits a 32-bit signed int
// (46340 * 46340 = 2147395600 <= 2^31-1).  With a 32-bit [int], this bound
// therefore covers every nonnegative input without overflowing [x * x].
// The previous bound, 2^15, capped the result at 32767, which is wrong for
// every n >= 2^30.
#define MAX2 46341

// intsqrt (n): the integer square root floor(sqrt(n)) of a nonnegative [n],
// i.e. the largest [r] in [0, MAX2) with r * r <= n.  [x * x] increases on
// naturals, so the binary search applies; the result is never ~1 because
// 0 * 0 <= n.
extern fun intsqrt {n:nat} (n: int n): int

implement intsqrt (n) = bsearch_cloref<int> (
  lam x => x * x, MAX2, n, lam (x1, x2) => compare (x1, x2)
) // end of [intsqrt]

(* ****** ****** *)

// Size of the search range for [intcbrt]: the candidates are 0 .. MAX3-1.
// 1290 is the largest [r] whose cube fits a 32-bit signed int
// (1290^3 = 2146689000 <= 2^31-1).  With a 32-bit [int], this bound
// therefore covers every nonnegative input without overflowing [x * x * x].
// The previous bound, 2^10, capped the result at 1023, which is wrong for
// every n >= 2^30.
#define MAX3 1291

// intcbrt (n): the integer cube root floor(cbrt(n)) of a nonnegative [n],
// i.e. the largest [r] in [0, MAX3) with r * r * r <= n.  The result is
// never ~1 because 0 * 0 * 0 <= n.
extern fun intcbrt {n:nat} (n: int n): int

implement intcbrt (n) = bsearch_cloref<int> (
  lam x => x * x * x, MAX3, n, lam (x1, x2) => compare (x1, x2)
) // end of [intcbrt]

(* ****** ****** *)

// Demo driver.  It runs [bsearch_str] on the sorted lowercase alphabet,
// including 'A', which presumably sorts before every entry and so should
// give ~1.  It then runs [intsqrt] and [intcbrt] just below, at, and just
// above a perfect square (1024) and a perfect cube (1000).  Each result is
// printed right after it is computed, in the original order.
val () = let
  #define N 26
  val letters = "abcdefghijklmnopqrstuvwxyz"
in
  printf ("ans_a = %i\n", @(bsearch_str (letters, N, 'a')));
  printf ("ans_m = %i\n", @(bsearch_str (letters, N, 'm')));
  printf ("ans_z = %i\n", @(bsearch_str (letters, N, 'z')));
  printf ("ans_A = %i\n", @(bsearch_str (letters, N, 'A')));
  printf ("intsqrt(1023) = %i\n", @(intsqrt (1023)));
  printf ("intsqrt(1024) = %i\n", @(intsqrt (1024)));
  printf ("intsqrt(1025) = %i\n", @(intsqrt (1025)));
  printf ("intcbrt(999) = %i\n", @(intcbrt (999)));
  printf ("intcbrt(1000) = %i\n", @(intcbrt (1000)));
  printf ("intcbrt(1001) = %i\n", @(intcbrt (1001)))
end // end of [val]

(* ****** ****** *)

// All of the program's output comes from the top-level [val] bindings
// above, which call [printf].  [main] itself does nothing.
implement main () = ()

(* ****** ****** *)

(* end of [code-2009-06-23.dats] *)