//
//
// This file is for Assignment 4, BU CAS CS 520, Fall, 2008
//
// Instructor: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
//
//

(* ****** ****** *)

// Author: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
// Time: 2007 (or even earlier)


(* ****** ****** *)

// An implementation of random-access lists based on nested datatypes

(* ****** ****** *)

//
// How to compile:
//
// atscc -O3 -o ralist_test ralist.sats ralist.dats
//
// How to test:
//
// ./ralist_test

(* ****** ****** *)

staload "ralist.sats"

(* ****** ****** *)

macdef P x y = '(,(x), ,(y))

(* ****** ****** *)

//
// [ralist_length] computes the number of elements in a random-access list.
// Below an RAevn/RAodd node the nested list stores pairs, so its length
// is doubled; RAodd contributes one extra standalone element.
//
implement{a} ralist_length (xs) = len (xs) where {
  fun len {n:nat} .<n>. (rs: ralist (a, n)):<> int n =
    case+ rs of
    | RAnil () => 0
    | RAone _ => 1
    | RAodd (_, prs) => 1 + 2 * ralist_length<P a a> (prs)
    | RAevn prs => 2 * ralist_length<P a a> (prs)
  // end of [len]
} // end of [ralist_length]

(* ****** ****** *)

//
// [ralist_cons] prepends [x0] to [xs]. When two standalone elements meet,
// they are combined into a pair [P y z] (new element on the left) and
// pushed one level down into the nested list of pairs.
//
implement{a} ralist_cons (x0, xs) = push (x0, xs) where {
  fun push {n:nat} .<n>. (y: a, ys: ralist (a, n)):<> ralist (a, n+1) =
    case+ ys of
    | RAnil _ => RAone y
    | RAone z => RAevn (RAone (P y z))
    | RAevn zzs => RAodd (y, zzs)
    | RAodd (z, zzs) => RAevn (ralist_cons<P a a> (P y z, zzs))
  // end of [push]
} // end of [ralist_cons]

(* ****** ****** *)

//
// [ralist_head] returns the first element of a nonempty list. It is
// stored either directly (RAone/RAodd) or as the left component of the
// first pair in the nested list (RAevn).
//
implement{a} ralist_head (xs) = hd (xs) where {
  fun hd {n:pos} .<n>. (ys: ralist (a, n)):<> a =
    case+ ys of
    | RAone y => y
    | RAodd (y, _) => y
    | RAevn yys => let
        val pr = ralist_head<P a a> (yys)
      in
        pr.0
      end // end of [RAevn]
  // end of [hd]
} // end of [ralist_head]

(* ****** ****** *)

//
// [ralist_uncons] splits a nonempty list: its head is written into the
// out-parameter [x_r] (uninitialized on entry, initialized on exit) and
// the remaining n-1 elements are returned.
//
implement{a} ralist_uncons (xs, x_r) = uncons (xs, x_r) where {
  fun uncons {n:pos} .<n>.
    (xs: ralist (a, n), x_r: &a? >> a):<> ralist (a, n-1) =
    case+ xs of
    | RAone x => (x_r := x; RAnil ())
    | RAevn xxs => let
        // the head is the left component of the first pair: pop that pair
        // from the nested list, return its left half, and keep its right
        // half as the new standalone first element
        var xx_r: P a a // uninitialized
        val xxs = ralist_uncons<P a a> (xxs, xx_r)
      in
        case+ xxs of
        // no pairs remain: only the right component is left
        | RAnil () => (x_r := xx_r.0; RAone xx_r.1)
        // [=>>]: this clause is checked assuming the [RAnil] clause did
        // not match, so [xxs] is known to be nonempty here
        | _ =>> (x_r := xx_r.0; RAodd (xx_r.1, xxs))
      end // end of [RAevn]
    // an odd-sized list stores its head standalone; the rest is all pairs
    | RAodd (x, xxs) => (x_r :=x; RAevn xxs)
  // end of [uncons]
} // end of [ralist_uncons]

(* ****** ****** *)

//
// [ralist_tail] returns a nonempty list without its first element.
// It follows the same case analysis as [ralist_uncons] but discards the
// head instead of handing it back.
//
implement{a} ralist_tail (xs) = tail (xs) where {
  fun tail {n:pos} .<n>. (xs: ralist (a, n)):<> ralist (a, n-1) = case+ xs of
    | RAone x => RAnil ()
    | RAevn xxs => let
        // pop the first pair: its left component (the old head) is
        // dropped and its right component becomes the standalone element
        var xx: P a a
        val xxs = ralist_uncons<P a a> (xxs, xx)
      in
        case+ xxs of
        | RAnil () => RAone xx.1 | _ =>> RAodd (xx.1, xxs)
      end // end of [RAevn]
    // dropping the standalone head leaves exactly the nested pairs
    | RAodd (_, xxs) => RAevn xxs
  // end of [tail]
} // end of [ralist_tail]

(* ****** ****** *)

//
// [ralist_lookup] returns the element at index [i] (0-based, i < n).
// The helper [fetch] is a local template so that it can recurse at the
// pair type [P a a]: element [k] of an even-sized level lives in pair
// [k/2], and the parity of [k] selects the component.
//
implement{a} ralist_lookup (xs, i) = fetch<a> (xs, i) where {
  fun{a:t@ype} fetch {n,i:nat | i < n} .<n>.
    (ys: ralist (a, n), k: int i):<> a =
    case+ ys of
    | RAone y => y
    | RAodd (y, yys) =>
        if k = 0 then y else let
          // skip the standalone head; index [k-1] is relative to the pairs
          val pr = fetch<P a a> (yys, nhalf (k-1))
        in
          if k nmod 2 = 0 then pr.1 else pr.0
        end // end of [if]
    | RAevn yys => let
        val pr = fetch<P a a> (yys, nhalf k)
      in
        if k nmod 2 = 0 then pr.0 else pr.1
      end // end of [RAevn]
  // end of [fetch]
} // end of [ralist_lookup]

(* ****** ****** *)

//
// [ralist_update] returns a list equal to [xs] except that the element at
// index [i] is [x0]. No node of [xs] is assigned; new nodes are built along
// the path to index [i].
//
// The work is done by [fupdate], which applies an update function [f] to
// the element at index [i]. When descending into the nested list of pairs,
// [f] is lifted to a closure that rewrites only the left (.0) or the right
// (.1) component of the selected pair.
//
implement{a} ralist_update(xs, i, x0) = let
  typedef id (a:t@ype) = a -<cloref> a
  // local template so that the recursion can happen at the pair type [P a a]
  fun{a:t@ype} fupdate {n,i:nat | i < n} .<n>.
    (xs: ralist (a, n), i: int i, f: id a):<> ralist (a, n) =
    case+ xs of
    | RAone x => RAone (f(x))
    | RAevn xxs => let
        // element [i] lives in pair [i/2]; the parity of [i] picks the side
        val i2 = i/2
      in
        if i = i2 + i2 then // [i] is even
          RAevn (fupdate<P a a> (xxs, i2, lam xx => P (f xx.0) (xx.1)))
        else
          RAevn (fupdate<P a a> (xxs, i2, lam xx => P (xx.0) (f xx.1)))
        // end of [if]
      end // end of [RAevn]
    | RAodd (x, xxs) =>
        if i = 0 then RAodd (f x, xxs)
        else let
          // skip the standalone head: [i1] indexes into the pairs
          val i1 = i - 1; val i2 = i1 / 2
        in
          if i1 = i2 + i2 then // [i1] is even
            RAodd (x, fupdate<P a a> (xxs, i2, lam xx => P (f xx.0) (xx.1)))
          else
            RAodd (x, fupdate<P a a> (xxs, i2, lam xx => P (xx.0) (f xx.1)))
          // end of [if]
        end // end of [if]
      // end of [RAodd]
  // end of [fupdate]
in
  // the initial update function ignores the old element and yields [x0]
  fupdate<a> (xs, i, lam _ => x0)
end // end of [ralist_update]

(* ****** ****** *)

(*

//
// this is a more advanced implementation where closures are explicitly
// constructed
//

dataviewtype closure_ (a:t@ype) =
  {param: viewtype} CLOSURE_ (a) of (param, (param, a) -<fun> a)

fn{a:t@ype} cloapp (c: closure_ a, x: a):<> a = let
  val+ ~CLOSURE_ {..} {param} (p, f) = c; val f = f: (param, a) -<fun> a
in
  f (p, x)
end // end of [cloapp]
  
fun{a:t@ype} fupdate {n,i:nat | i < n} .<n>.
  (xs: ralist (a, n), i: int i, c: closure_ a):<> ralist (a, n) = let
  fn f0 (c: closure_ a, xx: P a a):<> P a a = '(cloapp<a> (c, xx.0), xx.1)
  fn f1 (c: closure_ a, xx: P a a):<> P a a = '(xx.0, cloapp<a> (c, xx.1))
in
  case+ xs of
  | RAone x => RAone (cloapp<a> (c, x))
  | RAevn xxs => let
      val i2 = i / 2; val parity = i - (i2 + i2)
    in
      if parity = 0 then begin
        RAevn (fupdate<P a a> (xxs, i2, CLOSURE_ {P a a} (c, f0)))
      end else begin
        RAevn (fupdate<P a a> (xxs, i2, CLOSURE_ {P a a} (c, f1)))
      end // end of [if]
    end // end of [RAevn]
  | RAodd (x, xxs) => begin
      if i = 0 then RAodd (cloapp<a> (c, x), xxs) else let
        val i1 = i - 1; val i2 = i1 / 2; val parity = i1 - (i2 + i2)
      in
        if parity = 0 then begin
          RAodd (x, fupdate<P a a> (xxs, i2, CLOSURE_ {P a a} (c, f0)))
        end else begin
          RAodd (x, fupdate<P a a> (xxs, i2, CLOSURE_ {P a a} (c, f1)))
        end // end of [if]
      end // end of [if]
    end // end of [RAodd]
end // end of [fupdate]

implement{a} ralist_update (xs, i, x) = let
  dataviewtype box_a = box_a of a // a local dataviewtype
  val f0 = lam (x_box: box_a, _: a): a =<fun> let val+ ~box_a (x) = x_box in x end
in
  fupdate<a> (xs, i, CLOSURE_ (box_a x, f0): closure_ a)
end // end of [ralist_update]

*)

(* ****** ****** *)

//
// [ralist_gen] builds the list [1, 2, ..., n] by consing n, n-1, ..., 1
// onto the empty list.
//
fun ralist_gen {n:nat} (n: int n): ralist (int, n) = let
  // [i] counts the elements still to be consed; [xs] holds [i+1, ..., n].
  // The termination metric .<i>. is checked by the typechecker: [i]
  // strictly decreases on every recursive call and stays a natural number.
  fun loop {i,j:nat | i+j == n} .<i>.
    (i: int i, xs: ralist (int, j)): ralist (int, n) =
    if i > 0 then loop (i - 1, ralist_cons (i, xs)) else xs
  // end of [loop]
in
  loop (n, RAnil ())
end // end of [ralist_gen]

(* ****** ****** *)

//
// [ralist_foreach] applies [f] to every element of [xs], in order starting
// from the head. Elements are removed one at a time with [ralist_uncons],
// which writes each head into the single local cell [x] (accessed through
// the pointer [p]); [f] is then called on that cell's contents.
//
fn{a:t@ype} ralist_foreach {n:nat}
  (xs: ralist (a, n), f: a -<cloref1> void): void = let
  var x: a // uninitialized
  // [pf] is the view proof for the cell at address [l]; it is threaded
  // through every iteration so that [!p] can be written by [ralist_uncons]
  fun loop {n:nat} {l:addr}
    (pf: !a? @ l | xs: ralist (a, n), p: ptr l, f: a -<cloref1> void): void =
    case+ xs of
    | RAnil () => ()
    | _ =>> let
        val xs = ralist_uncons<a> (xs, !p); val () = f (!p)
      in
        loop (pf | xs, p, f)
      end
  // end of [loop]
in
  loop (view@ x | xs, &x, f)
end // end of [ralist_foreach]

(* ****** ****** *)

//
// Test driver: build [1, ..., 100], print every element, then report the
// length, the element at index 50, and that element again after updating
// it to -51. Each label shows the value the test expects to see.
//
implement main () = let
  val rs = ralist_gen (100)
  val () = ralist_foreach (rs, lam x => (print x; print_newline ()))
  val cnt = ralist_length<int> (rs)
  val () = (print "n(100) = "; print cnt; print_newline ())
  val v = ralist_lookup<int> (rs, 50)
  val () = (print "x(51) = "; print v; print_newline ())
  val rs1 = ralist_update<int> (rs, 50, ~51)
  val v1 = ralist_lookup<int> (rs1, 50)
  val () = (print "x(-51) = "; print v1; print_newline ())
in
  // all work is done in the bindings above
end // end [main]

(* ****** ****** *)

(* end of [ralist.dats] *)