//
// For parsing context-free grammars
//

//
// Author: Hongwei Xi (hwxi AT cs DOT bu DOT edu)
// Time: February, 2009
//

(* ****** ****** *)

staload "grammar.sats"

(* ****** ****** *)

staload _(*anonymous*) = "prelude/DATS/array.dats"
staload _(*anonymous*) = "prelude/DATS/array0.dats"

(* ****** ****** *)

local

// A grammar symbol: a printable name (non-empty, since the size index
// is [pos], so [name[0]] is always valid) paired with an integer index.
datatype symbol = {n:pos} SYMBOL of (string n, int)

assume symbol_t = symbol

in

implement symbol_make_string_int (name, ind) = SYMBOL (name, ind)

implement symbol_index_get (sym) = let
  val+ SYMBOL (_, ind) = sym in ind
end // end of [symbol_index_get]

// A symbol is a nonterminal iff its name starts with an upper-case
// letter; every other symbol is a terminal.
implement symbol_is_term (sym) = ~(symbol_is_nonterm sym)

implement symbol_is_nonterm (sym) = let
  val+ SYMBOL (name, _) = sym in char_isupper (name[0])
end // end of [symbol_is_nonterm]

// Equality and ordering look only at the index; the name is ignored.
implement eq_symbol_symbol (s1, s2) = let
  val+ SYMBOL (_, i1) = s1 and SYMBOL (_, i2) = s2
in
  i1 = i2
end // end of [eq_symbol_symbol]

implement compare_symbol_symbol (s1, s2) = let
  val+ SYMBOL (_, i1) = s1 and SYMBOL (_, i2) = s2
in
  compare_int_int (i1, i2)
end // end of [compare_symbol_symbol]

implement print_symbol (sym) = let
  val+ SYMBOL (name, _) = sym in print_string (name)
end // end of [print_symbol]

// Print the symbols of [xs] in order, separated by ", ".
implement print_symbol_list (xs) = loop (xs, 0) where {
  fun loop (xs: List symbol_t, i: int): void = case+ xs of
    | list_cons (x, xs) => begin
        if i > 0 then print ", "; print_symbol x; loop (xs, i+1)
      end
    | list_nil () => ()
  // end of [loop]
} // end of [print_symbol_list]

end // end of [local]

(* ****** ****** *)

local

// A symbol set is a duplicate-free list kept in ascending order with
// respect to [compare]; [symbolset_nil], [symbolset_sing] and
// [symbolset_add_flag] all preserve that ordering.
assume symbolset_t = List (symbol_t)

in

implement symbolset_nil = list_nil ()

implement symbolset_sing (x) = list_cons (x, list_nil ())

// Linear membership test using [=] on symbols.
implement symbolset_ismem (xs, x) = f (xs, x) where {
  fun f (xs: List symbol_t, x: symbol_t): bool =
    case+ xs of
    | list_cons (x1, xs1) => if x = x1 then true else f (xs1, x)
    | list_nil () => false
  // end of [f]
} // end of [symbolset_ismem]

// Insert [x0] into the ordered list [xs], keeping it ordered and
// duplicate-free. [flag] is incremented by one exactly when [x0] is
// actually inserted; it is left unchanged when [x0] is already present.
implement symbolset_add_flag (xs, x0, flag) = f (xs, x0, flag) where {
  fun f (xs: List symbol_t, x0: symbol_t, flag: &int): List symbol_t =
    case+ xs of
    | list_cons (x1, xs1) => begin
      case+ compare (x0, x1) of
      | ~1 => (flag := flag + 1; list_cons (x0, xs)) // x0 < x1: insert here
      |  1 => list_cons (x1, f (xs1, x0, flag))     // x0 > x1: keep looking
      |  0 => xs                                     // already present
      end // end of [list_cons]
    | list_nil () => (flag := flag + 1; list_cons (x0, xs))
  // end of [f]
} // end of [symbolset_add_flag]

// Add every element of [ys] to [xs], one at a time, via
// [symbolset_add_flag]; [flag] grows by the number of new elements.
implement symbolset_union_flag (xs, ys, flag) = f (xs, ys, flag) where {
  fun f (xs: List symbol_t, ys: List symbol_t, flag: &int): List symbol_t =
    case+ ys of
    | list_cons (y1, ys1) => let
        val xs = symbolset_add_flag (xs, y1, flag) in f (xs, ys1, flag)
      end // end of [list_cons]
    | list_nil () => xs
  // end of [f]
} // end of [symbolset_union_flag]

// Print a set as "{ a, b, ... }".
implement print_symbolset (xs) = begin
  print "{ "; print_symbol_list (xs); print " }"
end // end of [print_symbolset]

end // end of [local]

(* ****** ****** *)

//
// Print a production rule as "S\t->\ta, b, ...\t(name)".
//
implement print_rule (r) = let
  val+ RULE (name, S, alpha) = r
  val n = int_of_size (array0_size alpha)
  // print alpha[k], ..., alpha[n-1], comma-separated
  fun pr_rhs (k: int):<cloref1> void =
    if k < n then begin
      if k > 0 then print ", "; print_symbol alpha[k]; pr_rhs (k+1)
    end // end of [if]
  // end of [pr_rhs]
in
  print_symbol S; print "\t->\t"; pr_rhs (0); printf ("\t(%s)", @(name))
end // end of [print_rule]

(* ****** ****** *)

//
// Print [G]: its terminals, its nonterminals, and then each production
// rule on a line of its own.
//
implement print_grammar (G) = begin
  print "terminals: "; print_symbol_list G.termlst; print_newline ();
  print "nonterminals: "; print_symbol_list G.nontermlst; print_newline ();
  print "the production rules:\n"; print_rules (G.rules)
end where {
  fun print_rules (rs: List rule): void = case+ rs of
    | list_nil () => ()
    | list_cons (r, rest) => begin
        print_rule r; print_newline (); print_rules (rest)
      end // end of [list_cons]
  // end of [print_rules]
} // end of [print_grammar]

(* ****** ****** *)

// Size of the global symbol tables below: every symbol index must lie in
// [0, NSYMBOL_MAX). Intended layout: terminals in [0, 256), nonterminals
// in [256, NSYMBOL_MAX). NOTE(review): this layout is not checked here.
#define NSYMBOL_MAX 2048

(* ****** ****** *)

// Global tables indexed by symbol index, updated through the
// symbol_*_set functions:
//   NULLABLEarr  -- nullability flag, initially false
//   FIRSTSETarr  -- FIRST set, initially empty
//   FOLLOWSETarr -- FOLLOW set, initially empty
val NULLABLEarr
  : array (bool, NSYMBOL_MAX) = array_make_elt (NSYMBOL_MAX, false)

val FIRSTSETarr
  : array (symbolset_t, NSYMBOL_MAX) = array_make_elt (NSYMBOL_MAX, symbolset_nil)

val FOLLOWSETarr
  : array (symbolset_t, NSYMBOL_MAX) = array_make_elt (NSYMBOL_MAX, symbolset_nil)

(* ****** ****** *)

local

// Translate [sym] into an index into the NULLABLE/FIRST/FOLLOW tables.
// [assert] fails unless the index lies in [0, NSYMBOL_MAX); the refined
// return type lets every array subscript below be checked statically.
fn symtbl_index_of (sym: symbol_t): natLt (NSYMBOL_MAX) = let
  val ind = int1_of_int (symbol_index_get sym)
  val () = assert (ind >= 0)
  val () = assert (ind < NSYMBOL_MAX)
in
  ind
end // end of [symtbl_index_of]

in

// NULLABLE table: read / update the nullability flag of [sym].
implement symbol_is_nullable (sym) = NULLABLEarr[symtbl_index_of sym]

implement symbol_isnot_nullable (sym) = ~(symbol_is_nullable sym)

implement symbol_nullable_set (sym, v) = NULLABLEarr[symtbl_index_of sym] := v

(* ****** ****** *)

// FIRST table: read / replace the FIRST set of [sym].
implement symbol_FIRSTSET_get (sym) = FIRSTSETarr[symtbl_index_of sym]

implement symbol_FIRSTSET_set (sym, v) = FIRSTSETarr[symtbl_index_of sym] := v

(* ****** ****** *)

// FOLLOW table: read / replace the FOLLOW set of [sym].
implement symbol_FOLLOWSET_get (sym) = FOLLOWSETarr[symtbl_index_of sym]

implement symbol_FOLLOWSET_set (sym, v) = FOLLOWSETarr[symtbl_index_of sym] := v

end // end of [local]

(* ****** ****** *)

//
// [test_nullability (alpha, i0, n)] is true iff each of the n symbols
// alpha[i0], ..., alpha[i0+n-1] is nullable (vacuously true when n <= 0).
//
fn test_nullability
  (alpha: array0 symbol_t, i0: int, n: int): bool = let
  val stop = i0 + n
  fun scan (j: int):<cloref1> bool =
    if j >= stop then true // ran off the end: every symbol was nullable
    else if symbol_is_nullable (alpha[j]) then scan (j+1)
    else false // found a non-nullable symbol
  // end of [scan]
in
  scan (i0)
end // end of [test_nullability]

//
// One update step of the iterative NULLABLE/FIRST/FOLLOW computation for
// the single rule [r] = x0 -> alpha[0] ... alpha[k-1]:
//   part 1:   if alpha[0..k) are all nullable, mark x0 nullable;
//   for each i in [0, k), with yi = alpha[i]:
//   part 2.1: if alpha[0..i) are all nullable, FIRST(x0) += FIRST(yi);
//   part 2.2: if alpha[i+1..k) are all nullable, FOLLOW(yi) += FOLLOW(x0);
//   part 2.3: for each j in (i, k) such that alpha[i+1..j) are all
//             nullable, FOLLOW(yi) += FIRST(alpha[j]).
// [flag] is increased whenever a table entry actually changes (by 1 for a
// new nullable mark, by the number of added symbols for a set union), so
// a caller can repeat passes until [flag] is left untouched.
//
fun process_rule (r: rule, flag: &int): void = let
(*
  val () = begin
    print "process_rule: r = "; print_rule r; print_newline ()
  end // end of [val]
*)
  val+ RULE (name, x0, alpha) = r
  val k = int_of_size (array0_size alpha)
  // part 1: x0 is nullable if the whole right-hand side is nullable
  val () = if symbol_isnot_nullable (x0) then
    if test_nullability (alpha, 0, k) then let
      val () = flag := flag + 1 in symbol_nullable_set (x0, true)
    end // end of [if]
  // end of [if]
  // part 2: walk over each position i of the right-hand side
  val () = loop1 (0, flag) where {
    fun loop1 (i: int, flag: &int):<cloref1> void = let
      // empty
    in
      if i < k then let
        val yi = alpha[i]
        // is the prefix alpha[0..i) nullable?
        val test = test_nullability (alpha, 0, i)
        // part 2.1: FIRST(x0) += FIRST(yi); the table is only written
        // back when the union actually added something (flag grew)
        val () = if test then let
          val flag0 = flag
          val fstset_x0 = symbol_FIRSTSET_get (x0);
          val fstset_yi = symbol_FIRSTSET_get (yi)
          val fstset_x0_new = symbolset_union_flag (fstset_x0, fstset_yi, flag)
        in
          if flag > flag0 then symbol_FIRSTSET_set (x0, fstset_x0_new)
        end // end of [if]
        // part 2.2: if the suffix alpha[i+1..k) is nullable,
        // FOLLOW(yi) += FOLLOW(x0)
        val test = test_nullability (alpha, i+1, k-i-1)
        val () = if test then let
          val flag0 = flag
          val folset_x0 = symbol_FOLLOWSET_get (x0)
          val folset_yi = symbol_FOLLOWSET_get (yi)
          val folset_yi_new = symbolset_union_flag (folset_yi, folset_x0, flag)
        in
          if flag > flag0 then symbol_FOLLOWSET_set (yi, folset_yi_new)
        end
        // part 2.3: for each later position j whose gap alpha[i+1..j)
        // is nullable, FOLLOW(yi) += FIRST(alpha[j])
        val () = loop2 (i, i+1, flag) where {
          fun loop2
            (i: int, j: int, flag: &int):<cloref1> void = let
          in
            if j < k then let
              val test = test_nullability (alpha, i+1, j-i-1)
              val () = if test then let
                val flag0 = flag
                val yj = alpha[j]
                // FOLLOW(yi) is re-read here, so updates made for an
                // earlier j are included
                val folset_yi = symbol_FOLLOWSET_get (yi)
                val fstset_yj = symbol_FIRSTSET_get (yj)
                val folset_yi_new = symbolset_union_flag (folset_yi, fstset_yj, flag)
              in
                if flag > flag0 then symbol_FOLLOWSET_set (yi, folset_yi_new)
              end // end of [val]
            in
              loop2 (i, j+1, flag)
            end // end of [if]
          end // end of [loop2]
        } // end of [val]
      in
        loop1 (i+1, flag)
      end else begin
        () // loop1 exits
      end // end of [if]
    end // end of [loop1]
  } // end of [val]
in
  // empty
end // end of [process_rule]

// Apply [process_rule] to each rule of [rs] in list order, threading [flag].
fun process_rules (rs: List rule, flag: &int): void =
  case+ rs of
  | list_nil () => ()
  | list_cons (r, rest) => let
      val () = process_rule (r, flag)
    in
      process_rules (rest, flag)
    end // end of [list_cons]
// end of [process_rules]

(* ****** ****** *)
 
//
// Fill the NULLABLE, FIRST and FOLLOW tables for grammar [G]. First,
// FIRST(t) is seeded with { t } for every terminal t in [G.termlst]. Then
// [process_rules] is re-run over all of [G.rules] until a complete pass
// leaves [flag] at 0, i.e. no table entry changed (a fixpoint).
// NOTE(review): the tables are global and are not cleared here, so calling
// this for a second grammar starts from the previous contents.
//
implement compute_nullable_first_follow_tables (G) = let
  var flag: int? // uninitialized; reset to 0 at the start of each pass
  val () = seed_terminals (G.termlst) where {
    fun seed_terminals (xs: List symbol_t): void = case+ xs of
      | list_cons (x, xs1) => let
          val () = symbol_FIRSTSET_set (x, symbolset_sing x)
        in
          seed_terminals (xs1)
        end // end of [list_cons]
      | list_nil () => ()
    // end of [seed_terminals]
  } // end of [val]
in
  while (true) let
    val () = flag := 0
    val () = process_rules (G.rules, flag)
  in
    if flag = 0 then break // nothing changed in this pass: fixpoint
  end // end of [while]
end // end of [compute_nullable_first_follow_tables]

(* ****** ****** *)

//
// For every nonterminal X of [G], print one line of the form
//   X: <nullable>; { FIRST(X) }; { FOLLOW(X) }
//
implement print_nullable_first_follow_tables (G) = loop (G.nontermlst) where {
  fun loop (xs: List symbol_t): void = case+ xs of
    | list_cons (x, xs1) => let
        val nullable = symbol_is_nullable x
        val fstset = symbol_FIRSTSET_get x
        val folset = symbol_FOLLOWSET_get x
        val () = print_symbol x
        val () = (print_string ": "; print_bool nullable)
        val () = (print "; "; print_symbolset fstset)
        val () = (print "; "; print_symbolset folset)
        val () = print_newline ()
      in
        loop (xs1)
      end // end of [list_cons]
    | list_nil () => ()
  // end of [loop]
} // end of [print_nullable_first_follow_tables]

(* ****** ****** *)

(* end of [grammar.dats] *)