//
// Course: BU CAS CS 520
// Instructor: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
//

(* ****** ****** *)

// Spawn a new thread that runs the linear closure [f].
// The thread is detached: nothing is returned, so the caller has no
// handle to wait on (thread_create_join below builds a joinable
// variant on top of this function).
extern fun
thread_create (f: () -<lin,cloptr> void): void // detached

// [uplock (i, v)]: an abstract linear lock associated with a view [v].
// The index [i] goes from 0 (no ticket issued) to 1 (ticket issued)
// via upticket_create, so at most one ticket exists per lock.
// [upticket (v)]: an abstract linear ticket for view [v]; the only
// visible way to consume it is upticket_destroy, which demands a
// proof of [v].
absviewtype uplock (int, view)
absviewtype upticket (view)

// Create a fresh lock for view [v] with no ticket issued yet.
extern fun
uplock_create {v:view} ():<> uplock (0, v)

// Destroy a lock whose ticket has been issued, yielding a proof of [v].
// NOTE(review): presumably this waits until the matching ticket has been
// passed to upticket_destroy -- the implementation is not visible here.
extern fun
uplock_destroy {v:view} (x: uplock (1, v)):<> (v | void)

// Issue the ticket for a lock; the lock's type changes in place from
// uplock (0, v) to uplock (1, v).
extern fun
upticket_create {v:view} (x: !uplock (0, v) >> uplock (1, v)):<> upticket v

// Consume a ticket together with a proof of [v].
// NOTE(review): presumably this hands the proof over to the ticket's lock
// so that uplock_destroy can return it -- confirm against the implementation.
extern fun
upticket_destroy {v:view} (pf: v | x: upticket v):<> void

(* ****** ****** *)

// [tid (v)]: an abstract linear handle for a joinable thread whose work
// produces a proof of view [v].
absviewtype tid (view)

// Spawn a thread running [f] and return a handle for joining it.
// The proof of [v] that [f] returns is meant to be recovered by
// thread_join on the returned handle.
extern fun
thread_create_join {v:view} (f: () -<lin,cloptr> (v | void)): tid (v)

// Consume a thread handle and obtain the proof of [v] produced by the
// thread's closure (presumably waiting for that thread to finish).
extern fun
thread_join {v:view} (id: tid v): (v | void)

// The thread handle is implemented privately as a lock whose ticket has
// already been issued; outside this [local] block [tid] stays abstract.
local

assume tid (v:view) = uplock (1, v)

in // in of [local]

// Create a lock for [v], issue its ticket, and spawn a detached thread
// that runs [f] and then surrenders [f]'s proof of [v] along with the
// ticket. The lock itself (now of type uplock (1, v)) is returned as the
// thread handle.
implement thread_create_join {v} (f) = let
  val lock = uplock_create {v} () // lock: uplock (0, v)
  val tick = upticket_create {v} (lock) // lock: uplock (1, v)
  // worker: run the user closure, then give its proof up via the ticket
  val f1 = lam () =<lin,cloptr> let
    val (pf | ()) = f () in upticket_destroy (pf | tick)
  end
  val () = thread_create (f1)
in
  lock
end // end of [thread_create_join]

// Joining means destroying the lock, which yields the proof of [v].
// NOTE(review): this relies on uplock_destroy waiting for the worker's
// upticket_destroy -- confirm against the uplock implementation.
implement thread_join (tid) = uplock_destroy (tid)

end // end of [local]

// fib (n): the n-th Fibonacci number via the naive doubly-recursive
// definition. Any n below 2 is returned unchanged, so fib (0) = 0,
// fib (1) = 1, and negative arguments map to themselves.
fun fib (n: int): int =
  if n < 2 then n else let
    val a = fib (n - 1)
    val b = fib (n - 2)
  in
    a + b
  end
// end of [fib]

// Threshold: for arguments below N, fib_mt falls back to the sequential fib.
#define N 10

// fib_mt (n): the n-th Fibonacci number, computed with threads.
// For n >= N, the subproblems fib_mt (n-1) and fib_mt (n-2) each run in
// their own joinable thread (thread_create_join). Each worker writes its
// result into a stack variable (res1 / res2). The view of that variable
// moves into the worker closure and comes back as the proof returned by
// thread_join. For n < N it simply calls the sequential fib.
fun fib_mt (n: int): int =
  if n >= N then let
    var res1: int // uninitialized; filled in by the first worker
    viewdef V1 = int @ res1
    var res2: int // uninitialized; filled in by the second worker
    viewdef V2 = int @ res2
    val f1 = lam ()
      : (V1 | void) =<lin,cloptr> let
      val () = res1 := $effmask_all (fib_mt (n-1))
    in
      (view@ res1 | ()) // hand the now-initialized view back to the joiner
    end
    val f2 = lam ()
      : (V2 | void) =<lin,cloptr> let
      val () = res2 := $effmask_all (fib_mt (n-2))
    in
      (view@ res2 | ()) // hand the now-initialized view back to the joiner
    end
    val tid1 = thread_create_join {V1} (f1)
    val tid2 = thread_create_join {V2} (f2)
    val (pf1 | ()) = thread_join (tid1)
    val (pf2 | ()) = thread_join (tid2)
    // Put the views back BEFORE reading res1/res2: each view was moved into
    // its worker closure and is unavailable until the proof from
    // thread_join is restored. Reading the variables earlier is ill-typed.
    prval () = view@ res1 := pf1
    prval () = view@ res2 := pf2
  in
    res1 + res2
  end else begin
    fib (n)
  end // end of [if]
// end of [fib_mt]

(* ****** ****** *)

(* end of [multithread.dats] *)