//
// Course: BU CAS CS 520, Fall 2009
// Instructor: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
//

(* ****** ****** *)

//
// SUM (n, r): proof that r = 0 + 1 + ... + n.
// This is the recursive specification of summation; the [sum]
// functions in this file return a SUM proof alongside their result.
//
dataprop SUM (int, int) =
  | SUMbas (0, 0) of () // SUMbas (): SUM (0, 0)
  | {n:nat} {r:int} SUMind (n+1, n+1+r) of SUM (n, r)
    // SUMind (pf): SUM (n+1, n+1+r) given pf: SUM (n, r)
// end of [SUM]

//
// sum_istot: totality of SUM -- for every natural n there is some r
// with SUM (n, r).  Proved by induction on n (termination metric <n>):
// the base case is SUMbas; the step extends the proof for n-1 by SUMind.
//
prfun sum_istot
  {n:nat} .<n>. (): [r:int] SUM (n, r) =
  sif n <= 0 then SUMbas () else let
    prval pf_prev = sum_istot {n-1} () // pf_prev: SUM (n-1, r')
  in
    SUMind (pf_prev)
  end
// end of [sum_istot]

//
// sum_isfun: functionality of SUM -- two SUM proofs for the same n
// carry the same result index, i.e. SUM (n, r1) and SUM (n, r2)
// imply r1 == r2.  By induction on n: both proofs are peeled in
// lockstep.  Only the two matching constructor pairs are listed;
// the indices of SUMbas (n = 0) and SUMind (n > 0) rule out the rest.
//
prfun sum_isfun
  {n:nat} {r1,r2:int} .<n>.
  (pf1: SUM (n, r1), pf2: SUM (n, r2)): [r1 == r2] void =
  case+ (pf1, pf2) of
  | (SUMbas (), SUMbas ()) => () // base: both results are 0
  | (SUMind pf1_prev, SUMind pf2_prev) =>
      sum_isfun (pf1_prev, pf2_prev) // IH: equal sums for n-1
// end of [sum_isfun]

(* ****** ****** *)

//
// SUM2LOOP (n, r, s): proof that an accumulator-style loop started at
// counter n with accumulator r produces result s.  Each step adds the
// counter to the accumulator: SUM2LOOP (n+1, r, s) follows from
// SUM2LOOP (n, n+1+r, s).  At counter 0 the result is the accumulator.
//
dataprop SUM2LOOP (int, int, int) =
  | {r:int} SUM2LOOPbas (0, r, r)
  | {n:nat;r:int;s:int} SUM2LOOPind (n+1, r, s) of SUM2LOOP (n, n+1+r, s)
// end of [SUM2LOOP]

// SUM2 (n, s): loop-based sum specification, i.e. SUM2LOOP started
// with accumulator 0.
propdef SUM2 (n: int, s: int) = SUM2LOOP (n, 0, s)

//
// lemma_sum_sum2loop: relates the recursive and the loop-based
// specifications.  If SUM (n, s1) and SUM2LOOP (n, r1, s2), then
// s1 + r1 == s2: running the loop from counter n with accumulator r1
// adds exactly 0 + 1 + ... + n to r1.  By induction on n.
//
prfun lemma_sum_sum2loop {n:nat} {s1,r1,s2:int} .<n>.
  (pf1: SUM (n, s1), pf2: SUM2LOOP (n, r1, s2)): [s1+r1==s2] void =
  case+ pf1 of
  | SUMind pf1 => let
      // n > 0 here, so [pf2] can only have been built by SUM2LOOPind
      prval SUM2LOOPind (pf2) = pf2
      prval () = lemma_sum_sum2loop (pf1, pf2)
    in
      // nothing
    end // end of [SUMind]
  | SUMbas () => let
      // n = 0 here, so [pf2] must be SUM2LOOPbas, which forces s2 = r1
      prval SUM2LOOPbas () = pf2 in ()
    end // end of [SUMbas]
// end of [lemma_sum_sum2loop]

//
// lemma_sum_sum2: the two specifications agree -- SUM (n, s1) and
// SUM2 (n, s2) imply s1 == s2.  This is [lemma_sum_sum2loop] with
// initial accumulator 0, since SUM2 (n, s2) is SUM2LOOP (n, 0, s2).
//
prfun lemma_sum_sum2 {n:nat} {s1,s2:int} .<>.
  (pf1: SUM (n, s1), pf2: SUM2 (n, s2)): [s1==s2] void =
  lemma_sum_sum2loop (pf1, pf2)
// end of [lemma_sum_sum2]

(* ****** ****** *)

(*

dataprop MUL (int, int, int) = ...
// [MUL] is defined in $ATSHOME/prelude/basics_dyn.sats
// various lemmas on MUL can be found in $ATSHOME/prelude/SATS/arith.sats

*)

//
// lemma_sum_mul: closed form of the sum -- SUM (n, r1) together with
// MUL (n, n+1, r2) implies r1 + r1 == r2, i.e. r1 = n*(n+1)/2.
// By induction on n (metric <n>):
//   - n > 0: write r1 = r11 + n; the IH gives r11 + r11 == r12 for
//     MUL (n-1, n, r12); expanding that product yields a proof of
//     MUL (n, n+1, r12 + 2*n), and [mul_isfun] equates it with r2,
//     so r2 = 2*(r11 + n) = 2*r1;
//   - n = 0: MUL (0, 1, r2) can only be MULbas, so r2 = 0 = r1 + r1.
//
prfun lemma_sum_mul {n:nat} {r1,r2:int} .<n>.
  (pf1: SUM (n, r1), pf2: MUL (n, n+1, r2)): [r1+r1==r2] void =
  case+ pf1 of
  | SUMind (pf_prev) => let // pf_prev: SUM (n-1, r11), r1 = r11+n
      prval pf_mul_prev = mul_istot {n-1,n} () // MUL (n-1, n, r12)
      prval () = lemma_sum_mul (pf_prev, pf_mul_prev) // r11+r11 == r12
      prval pf_mul_next = mul_expand_linear {1,1}{1,1} (pf_mul_prev)
      prval () = mul_isfun (pf2, pf_mul_next) // r2 == r12 + 2*n
    in
      // nothing
    end // end of [SUMind]
  | SUMbas () => let
      prval MULbas () = pf2 // n = 0: r2 = 0
    in
      ()
    end // end of [SUMbas]
// end of [lemma_sum_mul]

(* ****** ****** *)

//
// sum (version 1): direct recursion on x.  Returns r = 0 + 1 + ... + x
// together with a proof of SUM (n, r) certifying the result.
// The addition happens after the recursive call returns, so the
// recursion depth grows with x.
//
fun sum {n:nat}
  (x :int n): [r:int] (SUM (n, r) | int r) =
  if x <= 0 then (SUMbas () | 0) else let
    val (pf_prev | r_prev) = sum (x - 1) // pf_prev: SUM (n-1, r_prev)
  in
    (SUMind (pf_prev) | x + r_prev)
  end // end of [if]
// end of [sum]

//
// sum (version 2): accumulator-passing loop.
// [loop (i, j)] returns (pf | v) with pf: SUM (i, r) and v = r + j,
// so starting it with j = 0 yields exactly 0 + 1 + ... + x.
//
fun sum {n:nat}
  (x :int n): [r:int] (SUM (n, r) | int r) = let
  fun loop {i,j:nat}
    (i: int i, j: int j): [r:int] (SUM (i, r) | int (r+j)) =
    if i <= 0 then (SUMbas () | j) else let
      // fold the counter into the accumulator and continue with i-1
      val (pf_rest | res) = loop (i-1, i+j) // pf_rest: SUM (i-1, r1)
    in
      (SUMind (pf_rest) | res)
    end
  // end of [loop]
in
  loop (x, 0)
end // end of [sum]

//
// sum (version 3): closed form n*(n+1)/2, non-recursive (metric <>).
// The SUM (n, r) proof comes from [sum_istot]; [lemma_sum_mul] shows
// that the product returned by [imul2] equals r + r, which justifies
// returning product / 2 as the witness for r.
// NOTE(review): n*(n+1) is formed before halving, so with a fixed-width
// int it overflows for smaller n than a direct summation would --
// confirm whether callers need the larger range.
//
fun sum {n:nat} .<>.
  (n: int n): [r:int] (SUM (n, r) | int r) = let
  val (pf_prod | prod) = n imul2 (n+1) // pf_prod: MUL (n, n+1, prod)
  prval pf_res = sum_istot {n} () // pf_res: SUM (n, r) for some r
  prval () = lemma_sum_mul (pf_res, pf_prod) // r + r == prod
in
  (pf_res | prod / 2)
end // end of [sum]

(* ****** ****** *)

//
// main: reads an optional integer argument N (default 100), checks that
// it is non-negative, computes sum(N) and prints "sum(N) = r".
// NOTE(review): [sum] here refers to the last definition of [sum] above
// (later definitions shadow earlier ones) -- confirm.
//
implement main (argc, argv) = let
  val N0 = 100 // default input when no argument is given
  val N = (
    if argc >= 2 then int_of_string (argv.[1]) else N0
  ) : int
  // refine to an indexed int so that a constraint on N can be tracked
  val N = int1_of_int (N)
  // bool1 assertion: presumably prints the usage message and aborts when
  // N < 0, letting the typechecker assume N >= 0 as [sum] requires
  val () = assert_prerrf_bool1 (N >= 0, "Usage: %s <natural>\n", @(argv.[0]))
  val (pf | r) = sum N
in
  printf ("sum(%i) = %i\n", @(N, r))
end // end of [main]

(* ****** ****** *)

(* end of [2009-09-15.dats] *)