(*
** Course: Concepts of Programming Languages (BU CAS CS 320)
** Semester: Summer I, 2009
** Instructor: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
*)

//
// Author: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
// Time: June 8, 2009
//

(* ****** ****** *)

//
// [array2 a]: an abstract two-dimensional array with elements of type [a].
// The column count is NOT stored in the array: every accessor below takes
// [ncol] explicitly, and callers must pass the same [ncol] that was used
// when the array was created.
//
abstype array2 (a:t@ype)

//
// Creates an array with [nrow] rows and [ncol] columns, every cell
// initialized to [ini].
//
extern fun{a:t@ype}
  array2_make_elt (nrow: int, ncol: int, ini: a): array2 a
// end of [extern]

//
// Returns the element at row [i], column [j] of [A] (which has [ncol]
// columns).
//
extern fun{a:t@ype}
  array2_get_elt_at (A: array2 a, ncol: int, i: int, j: int): a
// end of [extern]

//
// Overwrites the element at row [i], column [j] of [A] (which has [ncol]
// columns) with [x].
//
extern fun{a:t@ype}
  array2_set_elt_at (A: array2 a, ncol: int, i: int, j: int, x: a): void
// end of [extern]

(* ****** ****** *)

staload _(*anonymous*) = "prelude/DATS/array0.dats"

(* ****** ****** *)

local

//
// Row-major representation: cell (i, j) of an array with [ncol] columns
// lives at offset i*ncol+j of a flat [array0].
//
assume array2 (a:t@ype) = array0 (a) // using row major representation

in // in of [local]

//
// Allocates nrow*ncol cells, all set to [ini].
//
implement{a:t@ype} array2_make_elt (nrow, ncol, ini) = let
  // Check each dimension on its own: two negative dimensions would
  // multiply to a positive size and slip past the product check below.
  val () = assert (nrow >= 0)
  val () = assert (ncol >= 0)
  val asz = nrow * ncol
  val asz = int1_of_int asz
  val () = assert (asz >= 0) // establishes asz >= 0 for [size_of_int1]
  val asz = size_of_int1 (asz)
in
  array0_make_elt (asz, ini)
end // end of [array2_make_elt]

//
// Reads cell (i, j). The per-index checks are needed because the flat
// offset alone cannot detect a column index outside [0, ncol): such an
// index would silently address a cell of an adjacent row.
//
implement{a}
  array2_get_elt_at (A, ncol, i, j) = let
  val () = assert (i >= 0)
  val () = assert (j >= 0)
  val () = assert (j < ncol)
  val n = i * ncol + j
  val n = int1_of_int n
  val () = assert (n >= 0) // establishes n >= 0 for [size_of_int1]
  val n = size_of_int1 (n)
(*
  val () = (print "get: n = "; print n; print "\n")
  val () = printf ("get: i = %i and j = %i\n", @(i, j))
*)
in
  // NOTE(review): presumably [array0_get_elt_at] bounds-checks [n], which
  // would catch a row index i >= nrow -- confirm against array0.sats
  array0_get_elt_at<a> (A, n)
end // end of [array2_get_elt_at]

//
// Writes [x] into cell (i, j); index checks mirror [array2_get_elt_at].
//
implement{a}
  array2_set_elt_at (A, ncol, i, j, x) = let
  val () = assert (i >= 0)
  val () = assert (j >= 0)
  val () = assert (j < ncol)
  val n = i * ncol + j
  val n = int1_of_int n
  val () = assert (n >= 0) // establishes n >= 0 for [size_of_int1]
  val n = size_of_int1 (n)
(*
  val () = (print "set: n = "; print n; print "\n")
  val () = printf ("set: i = %i and j = %i\n", @(i, j))
*)
in
  array0_set_elt_at<a> (A, n, x)
end // end of [array2_set_elt_at]

end // end of [local]

(* ****** ****** *)

staload Math = "libc/SATS/math.sats"

(* ****** ****** *)

//
// Choose the floating-point type behind [real]: set exactly one of the two
// flags below to 1. Each branch defines [real] together with matching
// conversion macros ([real_of_double], [int_of_real]) and a square root
// ([sqrt_real]). With both flags at 0, [real] is left undefined.
//
#define FLOAT_FOR_REAL 1
#define DOUBLE_FOR_REAL 0

#if FLOAT_FOR_REAL #then
typedef real = float
macdef real_of_double(x) = float_of_double ,(x)
macdef int_of_real(x) = int_of_float ,(x)
macdef sqrt_real(x) = $Math.sqrtf ,(x)
#endif

#if DOUBLE_FOR_REAL #then
typedef real = double
macdef real_of_double(x) = ,(x)
macdef int_of_real(x) = int_of_double ,(x)
macdef sqrt_real(x) = $Math.sqrt ,(x)
#endif

(* ****** ****** *)

//
// A point in the plane, kept as a flat pair of [real] coordinates.
//
typedef point = @(real, real) // flat representation

// [square x] = x * x
fun square (x: real): real = x * x

//
// Euclidean distance between [pt1] and [pt2].
//
fun dist_pt_pt
  (pt1: point, pt2: point): real =
  sqrt_real (square (pt1.0 - pt2.0) + square (pt1.1 - pt2.1))
// end of [dist_pt_pt]

(* ****** ****** *)

//
// Simulation parameters: the unit square is split into an N x N grid of
// cells, NN = N*N points are thrown per batch (see [do_all_rounds]), and
// two points count as close when their distance is below epsilon = 1/N,
// i.e. one grid-cell side.
//

#define N 1000

(* ****** ****** *)

val NN: int = N * N
val NN_1: double = 1.0 / NN // 1/NN; shrinks the range of [rand_real]
val epsilon = real_of_double (1.0 / N) // closeness radius (one cell side)

(* ****** ****** *)

//
// The simulation grid: an N x N array whose cell (i, j) holds the list of
// points already placed in that cell.
//
typedef pointlst = list0 (point)
val thePointlstArray = array2_make_elt<pointlst> (N, N, list0_nil)

//
// Resets every cell of [thePointlstArray] to the empty list, visiting the
// cells in row-major order. The old lists are just dropped; no explicit
// freeing is performed here.
//
fun thePointlstArray_clear .<(*nonrec*)>. (): void = () where {
  var row: int = 0 and col: int = 0
  val () = for (row := 0; row < N; row := row+1) let
    val () = for (col := 0; col < N; col := col+1)
      array2_set_elt_at (thePointlstArray, N, row, col, list0_nil)
  in
    // nothing else to do per row
  end (* end of [val] *)
} (* end of [thePointlstArray_clear] *)

(* ****** ****** *)

staload Rand = "libc/SATS/random.sats"

//
// [rand_real] returns a pseudo-random [real] derived from [$Rand.drand48].
//
extern fun rand_real (): real (* btw 0 and 1 *)

//
// The drand48 sample is scaled by (1 - 1/NN), keeping the result at most
// 1 - 1/NN. This is intended to keep N * x below N, so that the truncated
// grid index computed from it stays within [0, N-1].
//
implement rand_real () = let
  val u = $Rand.drand48 ()
  val shrink = 1.0 - NN_1
in
  real_of_double (u * shrink)
end // end of [rand_real]

(* ****** ****** *)

//
// Counts the points stored in grid cell (i, j) whose distance to [pt] is
// less than [epsilon].
//
extern fun do_one_square
  (pt: point, i: int, j: int ): int

implement do_one_square (pt, i, j) = let
  fun count (center: point, xs: pointlst, acc: int): int =
    case+ xs of
    | list0_nil () => acc
    | list0_cons (x, rest) => let
        val hit = if dist_pt_pt (x, center) >= epsilon then 0 else 1
      in
        count (center, rest, acc + hit)
      end // end of [list0_cons]
in
  count (pt, array2_get_elt_at (thePointlstArray, N, i, j), 0)
end (* end of [do_one_square] *)

(* ****** ****** *)

//
// Sums [do_one_square] over the (up to) 3 x 3 block of grid cells centred
// on cell (Nx, Ny), visiting rows then columns in increasing order. Rows
// and columns that fall outside [0, N) are skipped.
//
fun do_all_squares .<(*nonrec*)>.
  (pt: point, Nx: int, Ny: int): int = let
  // clip the neighbourhood to the grid once, rather than testing every index
  val rlo = if Nx > 0 then Nx-1 else 0
  val rhi = if Nx+1 < N then Nx+1 else N-1
  val clo = if Ny > 0 then Ny-1 else 0
  val chi = if Ny+1 < N then Ny+1 else N-1
  var total: int = 0
  var r: int = 0 and c: int = 0
  val () = for (r := rlo; r <= rhi; r := r+1)
    for (c := clo; c <= chi; c := c+1)
      total := total + do_one_square (pt, r, c)
in
  total
end (* end of [do_all_squares] *)

(* ****** ****** *)

//
// Throws one random point into the unit square and counts the previously
// thrown points that are close to it. It then prepends the new point to
// the list of its own grid cell. Returns the count.
//
fun do_one_round (): int = let
  val px = rand_real ()
  val py = rand_real ()
  val pt = @(px, py)
  // index of the grid cell (side 1/N) containing [pt]
  val Nx: int = int_of_real (N * px)
  val Ny: int = int_of_real (N * py)
  // count BEFORE inserting, so [pt] is never paired with itself
  val found = do_all_squares (pt, Nx, Ny)
  val cell = array2_get_elt_at (thePointlstArray, N, Nx, Ny)
  val () = array2_set_elt_at (thePointlstArray, N, Nx, Ny, list0_cons (pt, cell))
in
  found
end (* end of [do_one_round] *)

//
// Empties the grid, then plays NN rounds and returns the total number of
// close pairs found across them.
//
fun do_all_rounds (): int = let
  val () = thePointlstArray_clear () // every batch starts from an empty grid
  fun accumulate (k: int, total: int): int =
    if k >= NN then total else accumulate (k+1, total + do_one_round ())
  // end of [accumulate]
in
  accumulate (0, 0)
end // end of [do_all_rounds]

(* ****** ****** *)

//
// Entry point. Seeds the generator from the clock and runs TIMES batches
// of NN rounds. It then prints PI estimated as
//   2 * (close pairs) / (TIMES * (NN - 1)).
// Rationale: epsilon^2 = 1/NN, so each of the NN*(NN-1)/2 pairs in a batch
// is close with probability about PI/NN. That gives about PI*(NN-1)/2
// close pairs per batch.
// NOTE(review): points near the border of the unit square have fewer
// possible neighbours, so the estimate is presumably biased slightly low.
//
implement main () = let
  val () = $Rand.srand48_with_time ()
  #define TIMES 1
  fun repeat (n: int, acc: int): int =
    if n <= 0 then acc else repeat (n-1, acc + do_all_rounds ())
  // end of [repeat]
  val hits = repeat (TIMES, 0)
  val PI = 2.0 * hits / (TIMES * (NN - 1))
in
  printf ("PI = %.6f\n", @(PI))
end (* end of [main] *)

(* ****** ****** *)

(* end of [montecarlo.dats] *)