```(*
** Course: Concepts of Programming Languages (BU CAS CS 320)
** Semester: Summer I, 2009
** Instructor: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
*)

//
// Author: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
// Time: June 1, 2009
//

(* ****** ****** *)

// some examples of higher-order functions

(* ****** ****** *)

//
// [find_zero] returns an integer n with f(n) = 0. Candidates are tried
// in the order 0, 1, -1, 2, -2, 3, -3, ..., so the root returned is one
// of smallest absolute value, with a positive root preferred over its
// negation.
// NOTE: if [f] has no integer root, the search never terminates.
//
fun find_zero (f: int -> int): int = search (f, 0) where {
  fun search (f: int -> int, n: int): int =
    if f (n) = 0 then n
    else if n > 0 then search (f, ~n) // n  ->  -n
    else search (f, 1 - n) // -n  ->  n+1  (and 0 -> 1)
} // end of [find_zero]

// f1 (x) = x * x + x - 110, whose integer roots are 10 and -11
fun f1 (x: int): int = (x + 11) * (x - 10)
// f1' (x) = 2 * x + 1

// f2 (x) = x * x - x - 132, whose integer roots are 12 and -11
fun f2 (x: int): int = (x + 11) * (x - 12)
// f2' (x) = 2 * x - 1

(* ****** ****** *)

typedef real = double

//
// [deriv] approximates f' (x) by the central difference
//   (f (x + DELTA) - f (x - DELTA)) / (2 * DELTA).
// Its truncation error is O(DELTA^2), and it is exact (up to rounding)
// for quadratics such as [f3] and [f4]. Its rounding error is about
// eps * |f (x)| / DELTA, where eps ~ 2.2E-16 is the double-precision
// machine epsilon. A step near eps^(1/3) ~ 6E-6 balances the two; a
// much smaller step (e.g. 1E-10) lets cancellation between the two
// nearly-equal values of [f] dominate the result.
//
#define DELTA 1E-6
fun deriv (f: real -> real, x: real): real = // f' (x)
  (f (x + DELTA) - f (x - DELTA)) / (2.0 * DELTA)

// real-valued counterparts of [f1] and [f2]: f3 has roots 10 and -11,
// f4 has roots 12 and -11
fun f3 (x: real): real = (x + 11.0) * (x - 10.0)
fun f4 (x: real): real = (x + 11.0) * (x - 12.0)

(* ****** ****** *)

#define N_newton_raphson 10

//
// [newton_raphson] refines the guess [x0] for a root of [f] by taking at
// most N_newton_raphson Newton steps  x := x - f (x) / f' (x),  where f'
// is the numerical derivative computed by [deriv].
// It stops early and returns the current [x] when:
//   - f (x) = 0.0 exactly: [x] is already a root. Another step could only
//     yield 0.0 / 0.0 = NaN if f' (x) vanished as well.
//   - f' (x) = 0.0: the tangent is flat, so the step would divide by zero
//     and turn the result into an infinity or NaN.
// No other convergence test is made. The result is whatever the last
// step produced, so callers needing a guarantee should check f of it.
//
fun newton_raphson
  (f: real -> real, x0: real)
  : real = loop (f, x0, 0) where {
  fun loop (f: real -> real, x: real, n: int): real =
    if n < N_newton_raphson then let
      val fx = f (x)
    in
      if fx = 0.0 then x // exact root: nothing left to refine
      else let
        val dfx = deriv (f, x)
      in
        if dfx = 0.0 then x // flat tangent: a step would divide by zero
        else loop (f, x - fx / dfx, n+1)
      end // end of [if]
    end else begin
      x // loop exits
    end // end of [if]
} // end of [newton_raphson]

(* ****** ****** *)

#define N_binary 40

//
// [binary] finds a root of [f] in [x0, x1] by bisection. The bracketing
// interval is halved N_binary times, and its final midpoint is returned.
// Requirements: x0 <= x1, and [f] changes sign on [x0, x1], i.e. either
//   f (x0) <= 0 <= f (x1)   (f rising across the interval), or
//   f (x0) >  0 >= f (x1)   (f falling across the interval).
// The orientation is detected once, from the sign of f (x0).
//
fun binary (f: real -> real, x0: real, x1: real): real = let
  // [rising] selects which side of the midpoint keeps the sign change.
  // For a rising [f], f (mid) >= 0 puts a root in [x0, mid];
  // for a falling [f], f (mid) <= 0 does.
  fun loop (
    f: real -> real, rising: bool, x0: real, x1: real, n: int
  ) : real =
    if n < N_binary then let
      val x_average = (x0 + x1) / 2.0
      val f_average = f (x_average)
      val keep_left =
        (if rising then f_average >= 0.0 else f_average <= 0.0)
    in
      if keep_left then loop (f, rising, x0, x_average, n+1)
      else loop (f, rising, x_average, x1, n+1)
    end else begin
      (x0 + x1) / 2.0 // loop exits
    end // end of [if]
in
  loop (f, f (x0) <= 0.0, x0, x1, 0)
end // end of [binary]

(* ****** ****** *)

// x^2 - 2, whose positive root is sqrt (2)
fun square_minus_2 (x: real): real = let
  val sq = x * x
in
  sq - 2.0
end // end of [square_minus_2]

(* ****** ****** *)

//
// Exercises the root finders above and prints their results:
// integer search on f1/f2, numerical derivatives of f3/f4 at 10.0,
// Newton-Raphson on f3/f4 starting from 0.0, and two approximations
// of sqrt(2) (Newton-Raphson from 1.0, bisection on [1.0, 2.0]).
//
val () = () where { // test
// integer roots located by exhaustive search
val rt1 = find_zero (f1)
val () = printf ("f1 (%i) = 0\n", @(rt1))
val rt2 = find_zero (f2)
val () = printf ("f2 (%i) = 0\n", @(rt2))
// numerical derivatives; analytically f3' (10) = 2 * 10 + 1 = 21
// and f4' (10) = 2 * 10 - 1 = 19
val f3'_10 = deriv (f3, 10.0)
val () = printf ("f3'(10) = %.10f\n", @(f3'_10))
val f4'_10 = deriv (f4, 10.0)
val () = printf ("f4'(10) = %.10f\n", @(f4'_10))
// Newton-Raphson from the same starting guess for both functions
val rt3 = newton_raphson (f3, 0.0)
val () = printf ("f3 (%.10f) = 0\n", @(rt3))
val rt4 = newton_raphson (f4, 0.0)
val () = printf ("f4 (%.10f) = 0\n", @(rt4))
// sqrt (2) two ways; for bisection, square_minus_2 (1.0) = -1.0 <= 0
// and square_minus_2 (2.0) = 2.0 >= 0, so [1.0, 2.0] brackets the root
val sqrt2_new = newton_raphson (square_minus_2, 1.0)
val () = printf ("sqrt2_new = %.10f\n", @(sqrt2_new))
val sqrt2_bin = binary (square_minus_2, 1.0, 2.0)
val () = printf ("sqrt2_bin = %.10f\n", @(sqrt2_bin))
} // end of [val]

(* ****** ****** *)

// some higher-order list functions

(* ****** ****** *)

#define :: list0_cons
#define cons list0_cons
#define nil list0_nil

(* ****** ****** *)

//
// [list_map] turns [x1, ..., xn] into [f (x1), ..., f (xn)].
// The recursion is not tail-recursive: each cons waits for the mapped
// tail, so stack use grows with the length of [xs].
//
extern fun{a,b:t@ype} list_map (xs: list0 a, f: a -> b): list0 b

implement{a,b}
list_map (xs, f) = let
  fun go (ys: list0 a, g: a -> b): list0 b =
    case+ ys of
    | list0_cons (y, ys1) => list0_cons (g (y), go (ys1, g))
    | list0_nil () => list0_nil ()
in
  go (xs, f)
end // end of [list_map]

//
// [list_foldl] computes f (... f (f (init, x1), x2) ..., xn) for
// xs = [x1, ..., xn], threading the accumulator from left to right.
// Every recursive call is in tail position, so it runs as a loop.
//
extern fun{a,b:t@ype} list_foldl (xs: list0 a, f: (b, a) -> b, init: b): b

implement{a,b}
list_foldl (xs, f, init) = let
  fun go (ys: list0 a, g: (b, a) -> b, acc: b): b =
    case+ ys of
    | list0_nil () => acc
    | list0_cons (y, ys1) => go (ys1, g, g (acc, y)) // tail call
in
  go (xs, f, init)
end // end of [list_foldl]

//
// [list_foldr] computes f (x1, f (x2, ... f (xn, res) ...)) for
// xs = [x1, ..., xn], so [f] is first applied to the LAST element.
// Each call must wait for the fold of the remaining tail, so the
// recursion is *not* tail-recursive and uses stack proportional to
// the length of [xs].
//
extern fun{a,b:t@ype} list_foldr (xs: list0 a, f: (a, b) -> b, res: b): b

implement{a,b}
list_foldr (xs, f, res) = let
  fun go (ys: list0 a, g: (a, b) -> b, seed: b): b =
    case+ ys of
    | list0_nil () => seed
    | list0_cons (y, ys1) => g (y, go (ys1, g, seed)) // *not* a tail call
in
  go (xs, f, res)
end // end of [list_foldr]

(* ****** ****** *)

//
// Exercises the list combinators on the list [1, 2, 3, 4, 5, 6]:
// its length computed with both folds, its sum computed with foldr,
// and the doubled list [2, 4, 6, 8, 10, 12] printed comma-separated.
//
val () = () where { // test
val xs = 1 :: 2 :: 3 :: 4 :: 5 :: 6 :: nil ()
// length: every element adds one to the accumulator, ignoring its value
val n = list_foldl<int,int> (xs, f, 0) where {
fun f (init: int, _: int) = init + 1
}
val () = printf ("length (foldl) = %i\n", @(n))
// the same count, folding from the right
val n = list_foldr<int,int> (xs, f, 0) where {
fun f (_: int, res: int) = res + 1
}
val () = printf ("length (foldr) = %i\n", @(n))
val sum = list_foldr<int,int> (xs, f, 0) where {
fun f (x: int, res: int) = x + res
}
val () = printf ("summation (foldr) = %i\n", @(sum))
val ys = list_map (xs, f) where {
fun f (x: int): int = x + x
}
// printing via foldl: the accumulator counts elements printed so far,
// so a separator is emitted before every element except the first
val _(*int*) = list_foldl<int,int> (ys, f, 0) where {
fun f (init: int, x: int): int = (if init > 0 then print ", "; print x; init+1)
}
val () = print_newline ()
} // end of [val]

(* ****** ****** *)

// All the tests in this file are top-level [val] declarations, which are
// presumably evaluated when the module is loaded; [main] has no work
// of its own.
implement main () = ()

(* ****** ****** *)

(* end of [code-2009-06-01.dats] *)
```