//
// Initially written in SML by Likai Liu with minor modifications by Hongwei Xi
//
// Translated from SML to ATS by Hongwei Xi, September 2008
//

// BraunTree: a payload-free binary tree shape.
//   E        -- the empty tree
//   B (l, r) -- an internal node with left subtree l and right subtree r
// Trees grown from [E] by repeated [addOneNode] satisfy the Braun balance
// invariant size(r) <= size(l) <= size(r) + 1 at every node.
datatype BraunTree = E | B of (BraunTree, BraunTree)
typedef BT = BraunTree // short alias used throughout this file

extern fun print_BraunTree (t: BraunTree): void
overload print with print_BraunTree

// print_BraunTree (t): writes [t] in constructor notation, e.g. the
// two-node tree prints as "B(B(E, E), E)". Each empty subtree prints as "E".
implement print_BraunTree (t) =
  case+ t of
  | B (lft, rgt) => let
      val () = print "B("
      val () = print_BraunTree lft
      val () = print ", "
      val () = print_BraunTree rgt
    in
      print ")"
    end // end of [B]
  | E () => print "E"
// end of [print_BraunTree]

// addOneNode (t): returns a tree with exactly one more node than [t].
// The new node is pushed into the right subtree and the two subtrees then
// trade places, which keeps size(r) <= size(l) <= size(r) + 1 at every node
// whenever [t] already satisfies it (the Braun invariant).
fun addOneNode (t: BT): BT =
  case+ t of
  | B (lft, rgt) => let
      val grown = addOneNode (rgt) // right side gains the new node
    in
      B (grown, lft)               // ...and becomes the new left side
    end
  | E () => B (E, E)
// end of [addOneNode]

// makeBraunTreeFrom (t, n): grows [t] by [n] nodes, applying [addOneNode]
// once per step (tail-recursively). Starting from [E], this yields the
// unique Braun tree with [n] nodes.
// A non-positive [n] returns [t] unchanged. Testing [n > 0] instead of
// matching [n] against 0 guarantees termination when [n] is negative.
fun makeBraunTreeFrom (t: BT, n: int): BT =
  if n > 0 then makeBraunTreeFrom (addOneNode t, n - 1) else t
// end of [makeBraunTreeFrom]

// pow2 (n): computes 2^n by repeated doubling. Any n <= 0 yields 1.
// There is no overflow check, so results are only meaningful while 2^n fits
// in an int (n <= 30, assuming a 32-bit int).
fun pow2 (n: int): int = let
  fun loop (k: int, acc: int): int =
    if k > 0 then loop (k - 1, acc + acc) else acc
in
  loop (n, 1)
end // end of [pow2]

// list0_reverse (xs): returns the elements of [xs] in reverse order.
// Each head of the remaining input is moved onto an accumulator, so the
// loop is tail-recursive and runs in time linear in the length of [xs].
fn list0_reverse {a:type} (xs: list0 a): list0 a = let
  fun revapp (src: list0 a, acc: list0 a): list0 a =
    case+ src of
    | list0_nil () => acc
    | list0_cons (hd, tl) => revapp (tl, list0_cons (hd, acc))
  // end of [revapp]
in
  revapp (xs, list0_nil ())
end // end of [list0_reverse]

// list0_length (xs): returns the number of elements in [xs].
// Counts with an accumulator so the recursion is a tail call. Stack usage
// is constant, not proportional to the list length as with the naive
// [1 + list0_length (tail)] formulation.
fun list0_length {a:type} (xs: list0 a): int = let
  fun loop (xs: list0 a, n: int): int =
    case+ xs of
    | list0_cons (_, xs) => loop (xs, n + 1)
    | list0_nil () => n
  // end of [loop]
in
  loop (xs, 0)
end // end of [list0_length]

extern fun listBraunTreesOfGivenHeight (height: int) : list0 BT

// listBraunTreesOfGivenHeight (h): returns every Braun tree of height
// exactly [h], in increasing order of size.
//
// A Braun tree is determined by its size. A non-empty one with n nodes has
// height floor(log2 n) + 1, so for h >= 1 the wanted trees are those with
// sizes 2^(h-1) .. 2^h - 1, which gives 2^(h-1) trees in total.
// The only tree of height 0 is [E], and no tree has a negative height.
// NOTE(review): assumes [h] is small enough that [pow2] does not overflow
// (h <= 31 for a 32-bit int).
implement listBraunTreesOfGivenHeight (height) =
  if height < 0 then list0_nil ()
  else if height = 0 then list0_cons (E, list0_nil ())
  else let
    val base = pow2 (height - 1) // smallest size having this height
    // [loop] records [t], then grows it by one node only if more trees are
    // still wanted. This avoids building (and discarding) a tree one past
    // the last required size.
    fun loop (remaining: int, t: BT, acc: list0 BT): list0 BT = let
      val acc = list0_cons (t, acc)
    in
      if remaining > 1 then loop (remaining - 1, addOneNode t, acc)
      else list0_reverse (acc)
    end // end of [loop]
  in
    loop (base, makeBraunTreeFrom (E, base), list0_nil ())
  end // end of [if]
// end of [listBraunTreesOfGivenHeight]

// main: counts the Braun trees of heights 10 and 12 and prints each count
// next to the expected value 2^(h-1).
implement main () = let
  val len10 = list0_length (listBraunTreesOfGivenHeight 10)
  val len12 = list0_length (listBraunTreesOfGivenHeight 12)
  val () = printf ("len10(%i) = %i\n", @(pow2 9, len10))
in
  printf ("len12(%i) = %i\n", @(pow2 11, len12))
end // end of [main]

(* end of [listBraunTreesOfGivenHeight.dats] *)