(*
// This file is for Assignment 3, BU CAS CS 520, Fall, 2009
//
// Instructor: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
//
*)

(* ****** ****** *)

//
// How to compile:
//   atscc -o power power.dats power.sats
// How to test:
//   ./power fib_lin 1000000
//   ./power fib_log 1000000 // this one tests your implementation!
//

(* ****** ****** *)

staload "power.sats"

(* ****** ****** *)

//
// [power_istot]: totality of POWER. For every base x and natural exponent
// n there is some r with POWER (x, n, r). The proof is by induction on n
// (termination metric .<n>.):
// - n = 0: one_istot supplies a unit proof, and POWERbas wraps it;
// - n > 0: the recursive call yields x^(n-1) = r1; mul_istot supplies some
//   product x * r1, and POWERind combines the two.
//
prfun power_istot
  {x:num} {n:nat} .<n>. (): [r:num] POWER (x, n, r) =
  sif n == 0 then let
    prval one = one_istot () in POWERbas (one)
  end else let
    prval [r1:num] pf_pow = power_istot {x} {n-1} ()
    prval pf_mul = mul_istot {x,r1} () // some r with x * r1 = r
  in
    POWERind (pf_pow, pf_mul)
  end // end of [sif]
// end of [power_istot]

(* ****** ****** *)

//
// [power_isfun]: POWER is functional in its result. Given proofs of
// POWER (x, n, r1) and POWER (x, n, r2), it proves r1 = r2. The proof is
// by induction on n:
// - base: both results come from unit proofs, equated by one_isfun;
// - step: the inductive hypothesis equates the two predecessors, and then
//   mul_isfun equates the two products built from them.
// The mixed-constructor pairs are not listed. Presumably the typechecker
// rules them out because both proofs share the same index n (confirm).
//
prfun power_isfun
  {x:num} {n:nat} {r1,r2:num} .<n>. (
    pf1: POWER (x, n, r1), pf2: POWER (x, n, r2)
  ) : EQUAL (r1, r2) =
  case+ (pf1, pf2) of
  | (POWERbas pf10, POWERbas pf20) => one_isfun (pf10, pf20)
  | (POWERind (pf11, pf12), POWERind (pf21, pf22)) => let
      prval EQUAL () = power_isfun (pf11, pf21) // equal predecessors
      prval EQUAL () = mul_isfun (pf12, pf22) // hence equal products
    in
      EQUAL ()
    end // end of [POWERind _, POWERind _]
// end of [power_isfun]

(* ****** ****** *)

//
// [power_lemma1 (pf1, pf2)]: adds one more factor of x to a power.
// As implied by the call to power_lemma3 below, pf1 : POWER (x, n, p) and
// pf2 : MUL_num (p, x, q), and the result is a proof of POWER (x, n+1, q).
// The proof works in two steps:
// - build POWER (x, 1, x) from x^0 = 1 and x * 1 = x (via mul_one_unit_r);
// - add the exponents n and 1 with power_lemma3.
//
implement power_lemma1 {x} {n} (pf1, pf2) = let
  prval [_1:num] pf_one = one_istot ()
  prval pf_x0 = POWERbas (pf_one) // x^0 = _1
  prval pf_x_1_x = mul_istot {x,_1} ()
  prval EQUAL () = mul_one_unit_r (pf_one, pf_x_1_x) // x * _1 = x
  prval pf_x1 = POWERind (pf_x0, pf_x_1_x) // x^1 = x
in
  power_lemma3 (pf1, pf_x1, pf2)
end // end of [power_lemma1]

(* ****** ****** *)

//
// [power_lemma2 (pf1, pf2)]: squaring the base doubles the exponent.
// From pf1 : MUL_num (x, x, xx) and pf2 : POWER (xx, n, xxn), it derives
// POWER (x, n+n, xxn). The inner [lemma2] proceeds by induction on n:
// - base: (x*x)^0 and x^0 are both given by the same unit proof;
// - step: pf21 : xx^(n-1) = xxn1 and pf22 : xx * xxn1 = xxn. The
//   inductive hypothesis gives x^(2n-2) = xxn1. Then two more factors of
//   x are attached, and mul_assoc regroups x * (x * xxn1) as
//   (x * x) * xxn1 = xxn.
//
implement power_lemma2 (pf1, pf2) = let
  prfun lemma2 {x:num} {n:nat} {xx,xxn:num} .<n>.
    (pf1: MUL_num (x, x, xx), pf2: POWER (xx, n, xxn))
    : POWER (x, n+n, xxn) = case+ pf2 of
    | POWERbas (pf_one) => POWERbas (pf_one)
    | POWERind (pf21, pf22) => let
        prval pf1_pow = lemma2 (pf1, pf21) // x^{n1+n1} = xxn1
        prval pf_x_xxn1 = mul_istot {..} ()
        prval pf2_pow = POWERind (pf1_pow, pf_x_xxn1)
        prval pf_x_x_xxn1 = mul_assoc (pf1, pf_x_xxn1, pf22)
      in
        POWERind (pf2_pow, pf_x_x_xxn1)
      end // end of [POWERind]
  // end of [lemma2]
in
  lemma2 (pf1, pf2)
end // end of [power_lemma2]

(* ****** ****** *)

//
// [power_lemma3 (pf1, pf2, pf3)]: exponents add under multiplication.
// From x^n1 = xn1, x^n2 = xn2 and xn1 * xn2 = xn, it derives
// POWER (x, n1+n2, xn). The inner [lemma3] proceeds by induction on n1:
// - base (n1 = 0): xn1 is the unit, so mul_one_unit_l gives xn = xn2,
//   and pf2 itself is the required proof;
// - step: pf1 splits into x^(n1-1) = p and x * p = xn1. The proof forms
//   p * xn2, recurses to get x^(n1-1+n2), and then uses mul_assoc to
//   reattach the leading factor x.
//
implement power_lemma3 (pf1, pf2, pf3) = let
  prfun lemma3 {x:num}
    {n1:nat} {n2:nat} {xn1,xn2,xn:num} .<n1>. (
    pf1: POWER (x, n1, xn1)
  , pf2: POWER (x, n2, xn2)
  , pf3: MUL_num (xn1, xn2, xn)
  ) :<> POWER (x, n1+n2, xn) = case+ pf1 of
  | POWERbas (pf_one) => let
      prval EQUAL () = mul_one_unit_l (pf_one, pf3) in pf2
    end // end of [POWERbas]
  | POWERind {..}{..} (pf11, pf12) => let
      prval pf4 = mul_istot () // pf4: x^{n1-1} * x^n2 = ?1
      prval pf_res = lemma3 {x} {n1-1} {n2} (pf11, pf2, pf4)
      prval pf5 = mul_assoc (pf12, pf4, pf3) // pf5: x * ?1 = xn
    in
      POWERind (pf_res, pf5)
    end // end of [POWERind]
in
  lemma3 (pf1, pf2, pf3)
end // end of [power_lemma3]

(* ****** ****** *)

(*

//
// this is a reference implementation of O(n)-time complexity
// please uncomment the code if you would like to compile it
//
implement power (x, n) = power_rec (x, n) where {
  fun power_rec {x:num} {n:nat} .<n>.
    (x: N x, n: int n):<> [r:num] (POWER (x, n, r) | N r) =
    if n > 0 then let
      val (pf_pow | r) = power_rec (x, n-1); val (pf_mul | r1) = x * r
    in
      (POWERind (pf_pow, pf_mul) | r1)
    end else let
      val (pf_one | _1) = one in (POWERbas pf_one | _1)
    end // end of [if]
  // end of [power_rec]
} // end of [power]

*)

(* ****** ****** *)

//
// please put your code here
//

//
// [power (x0, n)]: computes x0^n together with a proof POWER (x0, n, r).
// It uses binary exponentiation (repeated squaring). The exponent is
// halved on every recursive call of [power_tail], so O(log n) calls are
// made, each doing at most two multiplications.
//
implement power (x0, n) = let 
  //
  // [power_tail (x, n, a)]: returns r = x^n * a, witnessed by the proofs
  // POWER (x, n, xn) and MUL_num (xn, a, r). The accumulator [a] collects
  // the factors contributed by the odd bits of the exponent. The top-level
  // call below starts with a = 1, so its result is x0^n.
  //
  fun power_tail {x:num}
    {n:nat} {a:num} .<n>. (
      x: N x, n: int n, a: N a // the result r satisfies r = x^n * a
    ) :<> [xn,r:num] (
      POWER (x, n, xn), MUL_num (xn, a, r) | N r
    ) =
    if n > 0 then let // n > 0: square the base and halve the exponent
      // NOTE(review): when n = 1 the square xx is computed but not used
      // (the recursive call gets n2 = 0). Harmless, but one wasted multiply.
      val (pf_xx | xx) = x * x
      val n2 = n / 2; val i = n - (2 * n2) // i = 0/1 (even/odd)
    in
      if i > 0 then let // n is odd: x^n * a = (x*x)^n2 * (x*a)
        val (pf_xa | xa) = x * a	
        val (pf1_res, pf2_res | r) = power_tail (xx, n2, xa)
        prval pf1 = power_lemma2 (pf_xx, pf1_res) // : POWER (x, n2+n2, xx_n2)
        prval pf2 = mul_istot () // xx_n2 * x = ?1
        prval pf3 = power_lemma1 (pf1, pf2) // x^n = ?1
        prval pf4 = mul_istot () // ?1 * a = ?2
        prval pf5 = mul_assoc (pf2, pf_xa, pf4) // xx_n2 * xa = ?2
        prval EQUAL () = mul_isfun (pf2_res, pf5) // r = ?2
      in
        (pf3, pf4 | r)
      end else let // n is even: x^n * a = (x*x)^n2 * a
        val (pf1_res, pf2_res | r) = power_tail (xx, n2, a)
        prval pf1 = power_lemma2 (pf_xx, pf1_res) // : POWER (x, n2+n2, xx_n2)
      in
        // n = n2 + n2 here, so the recursive MUL proof can be reused as is
        (pf1, pf2_res | r)
      end // end of [if]
    end else let
      // n = 0: x^0 is the unit, and unit * a = a, so return a unchanged
      prval [_1:num] pf_one = one_istot ()
      prval pf1 = POWERbas pf_one
      prval pf2 = mul_istot () // _1 * a = ?1
      prval EQUAL () = mul_one_unit_l (pf_one, pf2) // ?1 = a
    in
      (pf1, pf2 | a)
    end // end of [if]
  // end of [power_tail]
  // Start with accumulator 1, so r = x0^n * 1, and mul_one_unit_r gives r = x0^n
  val (pf_one | _1) = one
  val (pf1_res, pf2_res | r) = power_tail (x0, n, _1)
  prval EQUAL () = mul_one_unit_r (pf_one, pf2_res)
in
  (pf1_res | r)
end // end of [power]

(* ****** ****** *)

//
// please do not change any of the following code; it is to
// be used for testing when your implementation is finished.
//
//
// How to compile:
//
// atscc -O3 -o power_test power.dats power.sats
//
// How to test:
//
// ./power_test "fib_lin" <integer>
// ./power_test "fib_log" <integer>
//

(* ****** ****** *)

// In this test driver, the proof-level predicates ONE_num and MUL_num are
// implemented by the trivial proof [unit_p]. This is what lets [one] and
// [mul_num_num] below return [unit_p ()] as their proofs.
assume ONE_num (_:num) = unit_p
assume MUL_num (_: num, _: num, _: num) = unit_p

// From here on, [int] names the 64-bit integer type, and [intof] converts
// an integer literal to int64.
typedef int = int64
#define intof int64_of_int

(* ****** ****** *)

//
// [fib_lin n]: tail-recursive, O(n)-time computation of the n-th Fibonacci
// number (F(0) = 0, F(1) = 1). After k steps the pair (cur, nxt) holds
// (F(k), F(k+1)). The loop stops once k reaches n and returns cur.
// NOTE(review): [int] is int64 in this file, and F(93) already exceeds the
// int64 range, so results for n >= 93 are not the true Fibonacci numbers.
// Presumably fib_log overflows the same way (confirm).
//
fn fib_lin (n: Nat): int = let
  fun step {n,k:nat | k <= n} .<n-k>.
    (n: int n, k: int k, cur: int, nxt: int): int =
    if k >= n then cur else step (n, k+1, nxt, cur+nxt)
  // end of [step]
in
  step (n, 0, intof 0, intof 1)
end // end of [fib_lin]

(* ****** ****** *)

// A value of type [N x] is represented as a 2x2 matrix of int64 entries,
// stored as a flat 4-tuple. [int4_t] is the C name given to that tuple type.
typedef int4 = @(int, int, int, int)
extern typedef "int4_t" = int4

// assume N (_:num) = @(int, int, int, int)
// Conversions between the abstract [N x] and its concrete 4-tuple. Both
// are implemented in C as the identity function (see the %{$ ... %} block
// at the end of this file).
extern fun int4_of_num {x:num} (x: N x):<> int4 = "int4_of_num"
extern fun num_of_int4 (i4: int4):<> [x:num] N x = "num_of_int4"

// [num_make (x1, x2, x3, x4)]: packs four int64 entries into an [N] value
// whose static index x is left unspecified (existentially quantified).
fn num_make
  (x1: int, x2: int, x3: int, x4: int):<> [x:num] N x =
  num_of_int4 @(x1, x2, x3, x4)

// [one]: the multiplicative unit, i.e. the 2x2 identity matrix stored as
// (1, 0, 0, 1). ONE_num is assumed to be unit_p in this file, so the
// accompanying proof is just [unit_p ()].
implement one = let
  val eye = num_make (intof 1, intof 0, intof 0, intof 1)
in
  (unit_p () | eye)
end // end of [one]

// [mul_num_num (u, v)]: multiplies two numbers represented as 2x2 int64
// matrices. The entries are read row-major as (m00, m01, m10, m11), and
// the result is the matrix product u * v. MUL_num is assumed to be unit_p
// in this file, so the accompanying proof is just [unit_p ()].
implement mul_num_num (u, v) = let
  val a = int4_of_num (u)
  val b = int4_of_num (v)
  // entry (i, j) of the product is (row i of a) . (column j of b)
  val c00 = a.0 * b.0 + a.1 * b.2
  val c01 = a.0 * b.1 + a.1 * b.3
  val c10 = a.2 * b.0 + a.3 * b.2
  val c11 = a.2 * b.1 + a.3 * b.3
  val prod = num_make (c00, c01, c10, c11)
in
  (unit_p () | prod)
end // end of [mul_num_num]

// [fib_num]: the Fibonacci matrix Q = [[0, 1], [1, 1]], stored row-major.
val fib_num = num_make (intof 0, intof 1, intof 1, intof 1)

// [fib_log n]: returns F(n) by raising Q to the n-th power with [power].
// For n >= 1, Q^n = [[F(n-1), F(n)], [F(n), F(n+1)]]. For n = 0, Q^0 is
// the identity. In both cases the entry at index 1 is F(n). The running
// time is logarithmic in n exactly when [power] is.
fn fib_log (n: Nat): int = let
  val (_(*pf_pow*) | qn) = power (fib_num, n)
  val entries = int4_of_num (qn)
in
  entries.1
end // end of [fib_log]

// [prerr_usage cmd]: prints the command-line usage message for [cmd].
// As the prerr_ prefix promises, the message goes to standard error (the
// original wrote it to stdout with printf), so stdout carries only results.
// The text is byte-for-byte the same as before.
fn prerr_usage (cmd: string): void = begin
  prerr_string cmd; prerr_string " <flag> <natural>"; prerr_newline ();
  prerr_string "  where flag is either 'fib_lin' or 'fib_log'."; prerr_newline ();
end // end of [prerr_usage]

//
// [main]: entry point of the test harness.
// Usage: <cmd> <flag> <natural>, where <flag> is "fib_lin" or "fib_log".
// On a bad argument count or a negative n, the usage text is printed first
// and then the following [assert] fails. Presumably each assert also gives
// the typechecker the facts it needs (argc >= 3 for the argv.[1] and
// argv.[2] accesses, n >= 0 for the Nat argument of fib_lin / fib_log).
// An unrecognized flag is reported on stdout, and main returns normally.
// NOTE(review): behavior on a non-numeric third argument depends on
// [int1_of] (confirm).
//
implement main (argc, argv) = let
  val cmd = argv.[0]
  val () =
    if argc < 3 then prerr_usage (cmd)
  val () = assert (argc >= 3)
  val flag = argv.[1]
  val n = int1_of argv.[2]
  val () = if n < 0 then prerr_usage (cmd)
  val () = assert (n >= 0)
  // dispatch on the flag and print "<flag> (n) = <result>"
  val () = (case+ flag of
    | "fib_lin" => begin
        printf ("fib_lin (%i) = ", @(n)); print (fib_lin n); print_newline ()
      end
    | "fib_log" => begin
        printf ("fib_log (%i) = ", @(n)); print (fib_log n); print_newline ()
      end
    | _ => begin
        printf ("The flag [%s] is unrecognized.\n", @(flag));
        print_string ("The only supported flags are [fib_lin] and [fib_log].");
        print_newline ()
      end // end of [_]
  ) : void // end of [val]
in
  // empty
end // end of [main]

(* ****** ****** *)

// Representation of the abstract type [N]: a value of type N x is the C
// type [int4_t], the 4-tuple of int64 entries introduced earlier in this
// file via [extern typedef "int4_t" = int4].
assume N (_:num): t@ype = $extype "int4_t"

%{$

/*
** [num_t] is the C-side name for [N]. Because N is represented directly
** as int4_t, both conversions below are identity functions. Their
** parameter and return types look swapped relative to the function names,
** but num_t and int4_t are the same type, so this is harmless.
*/
typedef int4_t num_t ;
int4_t num_of_int4 (num_t x) { return x ; }
num_t int4_of_num (int4_t x) { return x ; }

%}

(* ****** ****** *)

(* end of [power.dats] *)