(*
** Course: Concepts of Programming Languages (BU CAS CS 320)
** Semester: Summer I, 2009
** Instructor: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
*)

//
// Author: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
// Time: June 1, 2009
//

(* ****** ****** *)

// some examples of higher-order functions

(* ****** ****** *)

//
// [find_zero] returns the first integer n with f(n) = 0, trying the
// candidates in the order 0, 1, -1, 2, -2, 3, -3, ... ; hence it yields a
// root of smallest absolute value, preferring the positive one on ties.
// It does not return normally if [f] has no integer root.
//
fun find_zero (f: int -> int): int = let
  // [next] gives the candidate that follows [n] in the order above
  fn next (n: int): int = if n > 0 then ~n else 1 - n
  fun search (f: int -> int, n: int): int =
    if f (n) <> 0 then search (f, next (n)) else n
  // end of [search]
in
  search (f, 0)
end // end of [find_zero]

// f1 (x) = (x - 10) * (x + 11), with integer roots 10 and -11
// f1' (x) = 2 * x + 1
fun f1 (x: int): int = let
  val u = x - 10
  val v = x + 11
in
  u * v
end // end of [f1]

// f2 (x) = (x - 12) * (x + 11), with integer roots 12 and -11
// f2' (x) = 2 * x - 1
fun f2 (x: int): int = let
  val u = x - 12
  val v = x + 11
in
  u * v
end // end of [f2]

(* ****** ****** *)

typedef real = double

// [DELTA] is the relative step used by [deriv]. It is close to the cube
// root of the double-precision machine epsilon (about 2.2E-16), the usual
// choice for a central difference: smaller steps lose accuracy to
// cancellation in f(x + h) - f(x - h), larger ones to truncation error.
#define DELTA 6E-6

// [deriv] approximates f' (x) by the central difference
//   (f (x + h) - f (x - h)) / ((x + h) - (x - h)),  h = DELTA * max (1, |x|)
// Scaling h with |x| keeps x + h distinguishable from x when |x| is large.
// Dividing by the step actually represented in floating point, rather
// than the nominal 2h, removes the error introduced by rounding x +/- h.
// The truncation error is O(h^2); for a quadratic [f] the result is exact
// up to rounding.
fun deriv (f: real -> real, x: real): real = let
  val ax = if x >= 0.0 then x else ~x
  val h = DELTA * (if ax > 1.0 then ax else 1.0)
  val x_hi = x + h
  val x_lo = x - h
in
  (f (x_hi) - f (x_lo)) / (x_hi - x_lo)
end // end of [deriv]

// real-valued counterparts of [f1] and [f2]:
//   f3 (x) = (x - 10) * (x + 11), roots 10 and -11, f3' (x) = 2 * x + 1
//   f4 (x) = (x - 12) * (x + 11), roots 12 and -11, f4' (x) = 2 * x - 1
fun f3 (x: real): real = let
  val u = x - 10.0
  val v = x + 11.0
in
  u * v
end // end of [f3]

fun f4 (x: real): real = let
  val u = x - 12.0
  val v = x + 11.0
in
  u * v
end // end of [f4]

(* ****** ****** *)

#define N_newton_raphson 10

// [newton_raphson] refines the initial guess [x0] for a root of [f] by
// performing up to N_newton_raphson Newton steps
//   x := x - f (x) / f' (x)
// where f' is estimated numerically by [deriv]. There is no convergence
// test other than stopping as soon as f (x) is exactly 0.0: the Newton
// step would be zero at that point anyway, and stopping avoids computing
// 0 / 0 (= NaN) when the derivative estimate vanishes there too (e.g. for
// f = 0 everywhere, or a sufficiently flat root).
// If the derivative estimate is 0 while f (x) <> 0, the division yields
// an infinity and the result is not meaningful.
fun newton_raphson
  (f: real -> real, x0: real)
  : real = loop (f, x0, 0) where {
  fun loop (f: real -> real, x: real, n: int): real =
    if n >= N_newton_raphson then x // iteration budget used up
    else let
      val fx = f (x)
    in
      if fx = 0.0 then x // exact root found; stop before a possible 0/0
      else loop (f, x - fx / deriv (f, x), n+1)
    end // end of [if]
} // end of [newton_raphson]

(* ****** ****** *)

#define N_binary 40

// [binary] approximates a root of [f] by bisection (binary search).
// The initial bracket [x0, x1] must straddle a sign change of [f], in
// either direction:
//   f (x0) <= 0 <= f (x1)   (increasing case), or
//   f (x0) >  0 >= f (x1)   (decreasing case).
// Only the sign of f (x0) is inspected to select the case; the condition
// on f (x1) is not checked. The bracket is halved N_binary times and the
// midpoint of the final bracket is returned; for continuous [f] meeting
// the precondition it is within |x1 - x0| / 2^(N_binary + 1) of a root.
fun binary (f: real -> real, x0: real, x1: real): real = let
  // [up] selects the case: the half kept is the one whose right end has
  // f >= 0 (increasing case) or f <= 0 (decreasing case), which preserves
  // the sign change inside the bracket
  fun loop (f: real -> real, up: bool, x0: real, x1: real, n: int): real =
    if n < N_binary then let
      val x_average = (x0 + x1) / 2.0
      val f_average = f (x_average)
      val keep_left = if up then f_average >= 0.0 else f_average <= 0.0
    in
      if keep_left then loop (f, up, x0, x_average, n+1)
                   else loop (f, up, x_average, x1, n+1)
    end else begin
      (x0 + x1) / 2.0 // loop exits
    end // end of [if]
in
  loop (f, f (x0) <= 0.0, x0, x1, 0)
end // end of [binary]

(* ****** ****** *)

// x^2 - 2, whose positive root is sqrt (2)
fun square_minus_2 (x: real): real = let
  val sq = x * x
in
  sq - 2.0
end // end of [square_minus_2]

(* ****** ****** *)

// exercise the root finders and the numeric derivative, printing results
val () = () where { // test
  // integer root search
  val () = printf ("f1 (%i) = 0\n", @(find_zero (f1)))
  val () = printf ("f2 (%i) = 0\n", @(find_zero (f2)))
  // numeric derivatives at x = 10
  val () = printf ("f3'(10) = %.10f\n", @(deriv (f3, 10.0)))
  val () = printf ("f4'(10) = %.10f\n", @(deriv (f4, 10.0)))
  // Newton-Raphson starting from x = 0
  val () = printf ("f3 (%.10f) = 0\n", @(newton_raphson (f3, 0.0)))
  val () = printf ("f4 (%.10f) = 0\n", @(newton_raphson (f4, 0.0)))
  // sqrt (2) computed two ways
  val () = printf ("sqrt2_new = %.10f\n", @(newton_raphson (square_minus_2, 1.0)))
  val () = printf ("sqrt2_bin = %.10f\n", @(binary (square_minus_2, 1.0, 2.0)))
} // end of [val]

(* ****** ****** *)

// some higher-order list functions

(* ****** ****** *)

// shorthands for the [list0] constructors: [x :: xs] and [cons (x, xs)]
// stand for [list0_cons (x, xs)], and [nil ()] for [list0_nil ()];
// they are used both to build lists and as patterns in [case+] below
#define :: list0_cons
#define cons list0_cons
#define nil list0_nil

(* ****** ****** *)

// [list_map] applies [f] to each element of [xs] and returns the results
// in the original order. It is not tail-recursive: the recursion depth
// equals the length of [xs].
extern fun{a,b:t@ype} list_map (xs: list0 a, f: a -> b): list0 b

implement{a,b}
  list_map (xs, f) = aux (xs, f) where {
  fun aux (ys: list0 a, g: a -> b): list0 b =
    case+ ys of
    | nil () => nil ()
    | y :: ys1 => g (y) :: aux (ys1, g)
  // end of [aux]
} // end of [list_map]

// [list_foldl] combines the elements of [xs] from left to right:
//   list_foldl ([x1, ..., xn], f, init) = f (... f (f (init, x1), x2) ..., xn)
// and returns [init] itself for an empty list.
extern fun{a,b:t@ype} list_foldl (xs: list0 a, f: (b, a) -> b, init: b): b

implement{a,b} // a tail-recursive function
  list_foldl (xs, f, init) = loop (xs, f, init) where {
  fun loop (ys: list0 a, g: (b, a) -> b, acc: b): b =
    case+ ys of
    | nil () => acc
    | y :: ys1 => loop (ys1, g, g (acc, y)) // a tail-call
  // end of [loop]
} // end of [list_foldl]

// [list_foldr] combines the elements of [xs] from right to left:
//   list_foldr ([x1, ..., xn], f, res) = f (x1, f (x2, ... f (xn, res) ...))
// and returns [res] itself for an empty list.
extern fun{a,b:t@ype} list_foldr (xs: list0 a, f: (a, b) -> b, res: b): b

implement{a,b}
  list_foldr (xs, f, res) = aux (xs, f, res) where {
  fun aux (ys: list0 a, g: (a, b) -> b, base: b): b =
    case+ ys of
    | nil () => base
    | y :: ys1 => let
        val acc = aux (ys1, g, base) // fold the tail first ...
      in
        g (y, acc) // ... so this call is *not* a tail call
      end // end of [let]
  // end of [aux]
} // end of [list_foldr]

(* ****** ****** *)

// exercise the list combinators on the list [1, 2, 3, 4, 5, 6]
val () = () where { // test
  val xs = 1 :: 2 :: 3 :: 4 :: 5 :: 6 :: nil ()
  // length via foldl: ignore each element, bump the accumulator
  val len_l = list_foldl<int,int> (xs, incr, 0) where {
    fun incr (acc: int, _: int): int = acc + 1
  }
  val () = printf ("length (foldl) = %i\n", @(len_l))
  // length via foldr
  val len_r = list_foldr<int,int> (xs, incr, 0) where {
    fun incr (_: int, acc: int): int = acc + 1
  }
  val () = printf ("length (foldr) = %i\n", @(len_r))
  // sum via foldr
  val total = list_foldr<int,int> (xs, add, 0) where {
    fun add (x: int, acc: int): int = x + acc
  }
  val () = printf ("summation (foldr) = %i\n", @(total))
  // double every element, then print the result comma-separated;
  // the fold's accumulator counts the elements printed so far
  val doubled = list_map (xs, twice) where {
    fun twice (x: int): int = x + x
  }
  val _(*count*) = list_foldl<int,int> (doubled, show, 0) where {
    fun show (k: int, x: int): int = (if k > 0 then print ", "; print x; k + 1)
  }
  val () = print_newline ()
} // end of [val]

(* ****** ****** *)

// [main] itself does nothing; the demonstration output comes from the
// top-level [val () = ... where { ... }] bindings in this file
// (NOTE(review): relies on those top-level vals being evaluated at
// module initialization, before/independently of [main] — confirm).
implement main () = ()

(* ****** ****** *)

(* end of [code-2009-06-01.dats] *)