(* Typechecker: semantic analysis of Tiger abstract syntax (Absyn).
 *
 * Walks the Absyn tree, checks types, and translates it into the
 * Expressions (EX) intermediate form.  Type environments ([tenv]) map
 * type identifiers to Types.ty; value environments ([venv]) map
 * identifiers to Env.enventry (VarEntry / FunEntry).  All errors are
 * reported through Error.err / Error.errPos. *)
structure Typechecker =
struct
  structure A = Absyn
  structure E = Env
  structure EX = Expressions
  structure S = Symbol
  structure TY = Types

  val err = Error.err
  val errPos = Error.errPos

  type ty = TY.ty
  type tenv = ty S.table

  (* Look up the type named [tyid] in [tenv]; report an error at [pos] if
   * the name is unbound. *)
  fun findTy (tenv: tenv, tyid: S.symbol, pos: Pos.T): ty =
    case E.tlook (tenv, tyid) of
        NONE => errPos pos ("unknown type: " ^ S.name tyid)
      | SOME ty => ty

  (* Follow NAME links until a non-NAME type is reached.  Every link on the
   * way is overwritten with the final type (path compression).  A link
   * that is still NONE is reported as an internal error.  Cyclic alias
   * chains are rejected by transDec before they can reach this function. *)
  fun tyLinkRemove (ty: ty): ty =
    case ty of
        TY.NAME (tyid, rot) =>
          (case !rot of
               SOME ty =>
                 let val ty = tyLinkRemove ty
                 in rot := SOME ty; ty end
             | NONE => err ("impossible: tyLinkRemove: no definition for " ^ S.name tyid))
      | ty => ty

  (* Given the label [id] of the previous field and the remaining fields of
   * a label-sorted record type, report the first adjacent duplicate label. *)
  fun labelCheck (id, []) = ()
    | labelCheck (id, {name= id', typ= _, pos= pos} :: tfs) =
        if S.eq (id, id')
        then errPos pos ("The label occurs repeatedly: " ^ S.name id)
        else labelCheck (id', tfs)

  (* Translate a syntactic type into a Types.ty.  Record fields are sorted
   * by label name, and duplicate labels are rejected.  The [ref ()] gives
   * each record/array type a fresh unit ref. *)
  fun transTy (tenv: tenv, t: A.ty): ty =
    case t of
        A.NameTy (tyid, pos) => findTy (tenv, tyid, pos)
      | A.RecordTy tfs =>
          let
            val tfs = ListMergeSort.sort
                        (fn (tf, tf') => String.> (S.name (#name tf), S.name (#name tf')))
                        tfs
            val _ = case tfs of
                        [] => ()
                      | tf :: rest => labelCheck (#name tf, rest)
            val idtys = List.map (transTyField tenv) tfs
          in
            TY.RECORD (idtys, ref ())
          end
      | A.ArrayTy (tyid, pos) => TY.ARRAY (findTy (tenv, tyid, pos), ref ())
  (* Resolve one record-type field to a (label, type) pair. *)
  and transTyField tenv {name= name, typ= tyid, pos= pos}: S.symbol * ty =
    (name, findTy (tenv, tyid, pos))

  (* Map a non-equality Absyn operator to its EX counterpart.  = and <>
   * are handled in transExp, because their EX operator depends on the
   * operand types. *)
  fun transOp (A.PlusOp) = EX.PlusOp
    | transOp (A.MinusOp) = EX.MinusOp
    | transOp (A.TimesOp) = EX.TimesOp
    | transOp (A.DivideOp) = EX.DivideOp
    | transOp (A.GtOp) = EX.GtOp
    | transOp (A.GeOp) = EX.GeOp
    | transOp (A.LtOp) = EX.LtOp
    | transOp (A.LeOp) = EX.LeOp
    | transOp (A.AndOp) = EX.AndOp
    | transOp (A.OrOp) = EX.OrOp
    | transOp _ = Error.err "Expressions: transOp: equality operator"

  type venv = E.enventry S.table
  type var = EX.var
  type exp = EX.exp
  type dec = EX.dec

  (* Report a type mismatch at [pos] unless TY.equal accepts (ty, ty'). *)
  fun checkTy (ty, ty', pos): unit =
    if TY.equal (ty, ty') then ()
    else errPos pos ("type mismatch: " ^ TY.ty2str ty ^ " <> " ^ TY.ty2str ty')

  (* When true (the default), arrays are represented as the tuple
   * [elements, size], and subscripts are guarded by runtime bound checks. *)
  val arrayBoundCheckFlag: bool ref = ref true
  fun setArrayBoundCheckFlag (b: bool): unit = (arrayBoundCheckFlag := b)

  (* Translate an l-value.  Returns (ctx, v, ty):
   *   ctx  wraps an expression that uses [v] in whatever code must run
   *        first (bound checks and let-bound index temporaries);
   *   v    is the translated variable;
   *   ty   is its type.
   * [lp] tells whether a break is currently allowed (inside a loop). *)
  fun transVar (tenv: tenv, venv: venv, lp: bool, v: A.var): (exp -> exp) * var * ty =
    case v of
        A.SimpleVar (id, pos) =>
          (case E.vlook (venv, id) of
               NONE => errPos pos ("undeclared variable: " ^ S.name id)
             | SOME (E.VarEntry {ty= ty}) => (fn hole => hole, EX.SimpleVar id, ty)
             | SOME _ => errPos pos ("a function type for the variable: " ^ S.name id))
      | A.FieldVar (v, id, pos) =>
          let val (ctx, v, ty) = transVar (tenv, venv, lp, v)
          in
            case tyLinkRemove ty of
                TY.RECORD (idtys, _) =>
                  (* presumably listAssoc' yields (field index, field type) — confirm in Util *)
                  (case Util.listAssoc' idtys id of
                       NONE => errPos pos ("transVar: FieldVar: nonexistent label: " ^ S.name id)
                     | SOME (n, ty) => (ctx, EX.SelectVar (v, n), ty))
              | _ => errPos pos ("transVar: FieldVar: not a record type: " ^ TY.ty2str ty)
          end
      | A.SubscriptVar (v, e, pos) =>
          let
            val (ctx, v, ty) = transVar (tenv, venv, lp, v)
            val e = transExpCkInt (tenv, venv, lp, e, pos)
          in
            case tyLinkRemove ty of
                TY.ARRAY (ty, _) =>
                  if !arrayBoundCheckFlag then
                    let
                      val ind = S.newSymbol ()
                      val indExp = EX.VarExp (EX.SimpleVar ind)
                      (* With bound checking on, an array value is the tuple
                       * [elements, size] (see ArrayExp in transExp), so field 1
                       * is the size and field 0 is the element storage. *)
                      val size = EX.VarExp (EX.SelectVar (v, 1))
                      (* Run [hole] only when 0 <= idx < size; otherwise run
                       * EX.subscriptError. *)
                      fun guard (idx, hole) =
                        EX.IfThenElseExp (EX.OpExp (idx, EX.LtOp, EX.IntExp 0),
                                          EX.subscriptError,
                                          EX.IfThenElseExp (EX.OpExp (idx, EX.GeOp, size),
                                                            EX.subscriptError,
                                                            hole))
                    in
                      case e of
                          (* A plain variable index can be reused directly. *)
                          EX.VarExp (EX.SimpleVar _) =>
                            (fn hole => ctx (guard (e, hole)),
                             EX.SubscriptVar (EX.SelectVar (v, 0), e),
                             ty)
                          (* Otherwise evaluate the index once into a fresh
                           * temporary that stays in scope around [hole]. *)
                        | _ =>
                            (fn hole => ctx (EX.LetExp ([EX.VarDec (ind, e, false)],
                                                        guard (indExp, hole))),
                             EX.SubscriptVar (EX.SelectVar (v, 0), indExp),
                             ty)
                    end
                  else (ctx, EX.SubscriptVar (v, e), ty)
              | _ => errPos pos ("transVar: SubscriptVar: not an array type: " ^ TY.ty2str ty)
          end

  (* transVar, then require the variable's type to match [ty]. *)
  and transVarCk (tenv: tenv, venv: venv, lp: bool, v: A.var, ty: TY.ty, pos: Pos.T)
      : (exp -> exp) * var =
    let val (ctx, v, ty') = transVar (tenv, venv, lp, v)
    in checkTy (ty', ty, pos); (ctx, v) end

  and transVarCkInt (tenv: tenv, venv: venv, lp: bool, v: A.var, pos: Pos.T)
      : (exp -> exp) * var =
    transVarCk (tenv, venv, lp, v, TY.INT, pos)

  (* Typecheck and translate an expression, returning the EX expression and
   * its type.  [lp] is true when a break is legal (inside a loop body). *)
  and transExp (tenv: tenv, venv: venv, lp: bool, e: A.exp): exp * ty =
    case e of
        A.VarExp v =>
          let val (ctx, v, ty) = transVar (tenv, venv, lp, v)
          in (ctx (EX.VarExp v), ty) end
      | A.NilExp => (EX.NilExp, TY.NIL)
      | A.IntExp i => (EX.IntExp i, TY.INT)
      | A.StringExp (s, _) => (EX.StringExp s, TY.STRING)
      | A.CallExp {func= fid, args= es, pos= pos} =>
          (case E.vlook (venv, fid) of
               NONE => errPos pos ("undeclared function symbol: " ^ S.name fid)
             | SOME (E.FunEntry {formals= tys, result= ty}) =>
                 (EX.CallExp (fid, transExpListCk (tenv, venv, lp, es, tys, pos)), ty)
             | SOME _ => errPos pos ("transExp: CallExp: not a function type"))
      | A.OpExp {left= left, oper= oper, right= right, pos= pos} =>
          (case oper of
               (A.EqOp | A.NeqOp) =>
                 (* = and <> pick an int, string or tuple comparison from the
                  * operand types; a record may only be compared with nil. *)
                 let
                   val (left, ty1) = transExp (tenv, venv, lp, left)
                   val (right, ty2) = transExp (tenv, venv, lp, right)
                   val ty1 = tyLinkRemove ty1
                   val ty2 = tyLinkRemove ty2
                   fun pick (eqOp, neqOp) =
                     case oper of A.EqOp => eqOp | _ => neqOp
                   fun result exOp = (EX.OpExp (left, exOp, right), TY.INT)
                 in
                   case (ty1, ty2) of
                       (TY.INT, TY.INT) => result (pick (EX.IntEq, EX.IntNeq))
                     | (TY.STRING, TY.STRING) => result (pick (EX.StrEq, EX.StrNeq))
                     | (TY.RECORD _, TY.NIL) => result (pick (EX.TupEq, EX.TupNeq))
                     | (TY.NIL, TY.RECORD _) => result (pick (EX.TupEq, EX.TupNeq))
                     | (_, _) => errPos pos ("no equality is supported on types: "
                                             ^ TY.ty2str ty1 ^ " and " ^ TY.ty2str ty2)
                 end
             | _ =>
                 (* Arithmetic, ordering and logical operators: both operands
                  * and the result are int. *)
                 let
                   val left = transExpCkInt (tenv, venv, lp, left, pos)
                   val oper = transOp oper
                   val right = transExpCkInt (tenv, venv, lp, right, pos)
                 in
                   (EX.OpExp (left, oper, right), TY.INT)
                 end)
      | A.ArrayExp {typ= tyid, size= size, init= init, pos= pos} =>
          let val ty = findTy (tenv, tyid, pos)
          in
            case tyLinkRemove ty of
                arrayTy as TY.ARRAY (elemTy, _) =>
                  let
                    val size = transExpCkInt (tenv, venv, lp, size, pos)
                    val init = transExpCk (tenv, venv, lp, init, elemTy, pos)
                  in
                    if !arrayBoundCheckFlag then
                      (* Build the tuple [elements, size].  The size expression
                       * is evaluated once, into a temporary unless it is
                       * already a plain variable. *)
                      (case size of
                           EX.VarExp (EX.SimpleVar _) =>
                             (EX.TupleExp [EX.ArrayExp (size, init), size], arrayTy)
                         | _ =>
                             let
                               val sz = S.newSymbol ()
                               val szExp = EX.VarExp (EX.SimpleVar sz)
                             in
                               (EX.LetExp ([EX.VarDec (sz, size, false)],
                                           EX.TupleExp [EX.ArrayExp (szExp, init), szExp]),
                                arrayTy)
                             end)
                    else (EX.ArrayExp (size, init), arrayTy)
                  end
              | _ => errPos pos ("transExp: ArrayExp: not an array type: " ^ TY.ty2str ty)
          end
      | A.RecordExp {fields= flds, typ= tyid, pos= pos} =>
          let val ty = findTy (tenv, tyid, pos)
          in
            case tyLinkRemove ty of
                recTy as TY.RECORD (idtys, _) =>
                  (EX.TupleExp (transFieldListCk (tenv, venv, lp, flds, idtys, pos)), recTy)
              | _ => errPos pos ("transExp: RecordExp: not a record type: " ^ S.name tyid)
          end
      | A.AssignExp {var= v, exp= e, pos= pos} =>
          let
            val (ctx, v, ty) = transVar (tenv, venv, lp, v)
            val e = transExpCk (tenv, venv, lp, e, ty, pos)
          in
            (ctx (EX.AssignExp (v, e)), TY.UNIT)
          end
      | A.SeqExp eps =>
          (* The type of a sequence is the type of its last expression, or
           * unit when the sequence is empty. *)
          let
            fun aux ([], es, ty) = (List.rev es, ty)
              | aux ((e, _) :: eps, es, _) =
                  let val (e, ty) = transExp (tenv, venv, lp, e)
                  in aux (eps, e :: es, ty) end
            val (es, ty) = aux (eps, [], TY.UNIT)
          in
            (EX.SeqExp es, ty)
          end
      | A.IfExp {test= test, then'= e1, else'= oe2, pos= pos} =>
          let val test = transExpCkInt (tenv, venv, lp, test, pos)
          in
            case oe2 of
                (* if-then without else must produce unit *)
                NONE => (EX.IfThenExp (test, transExpCkUnit (tenv, venv, lp, e1, pos)), TY.UNIT)
              | SOME e2 =>
                  let
                    val (e1, ty) = transExp (tenv, venv, lp, e1)
                    val e2 = transExpCk (tenv, venv, lp, e2, ty, pos)
                  in
                    (EX.IfThenElseExp (test, e1, e2), ty)
                  end
          end
      | A.BreakExp pos =>
          if lp then (EX.BreakExp, TY.UNIT)
          else errPos pos "break statement is not within a loop"
      | A.WhileExp {test= test, body= body, pos= pos} =>
          let
            val test = transExpCkInt (tenv, venv, lp, test, pos)
            val body = transExpCkUnit (tenv, venv, true, body, pos)
          in
            (EX.WhileExp (test, body), TY.UNIT)
          end
      | A.ForExp {var= id, escape= esc, lo= lo, hi= hi, body= body, pos= pos} =>
          let
            val lo = transExpCkInt (tenv, venv, lp, lo, pos)
            val hi = transExpCkInt (tenv, venv, lp, hi, pos)
            (* The loop variable is an int visible only in the body. *)
            val body = transExpCkUnit (tenv, S.enter (venv, id, E.VarEntry {ty= TY.INT}),
                                       true, body, pos)
          in
            (EX.for2while (id, esc, lo, hi, body), TY.UNIT)
          end
      | A.LetExp {decs= ds, body= body, pos= _} =>
          let
            val (tenv, venv, ds) = transDecs (tenv, venv, ds)
            val (body, ty) = transExp (tenv, venv, lp, body)
          in
            (EX.LetExp (ds, body), ty)
          end

  (* transExp, then require the result type to match [ty]. *)
  and transExpCk (tenv: tenv, venv: venv, lp: bool, e: A.exp, ty: TY.ty, pos: Pos.T): exp =
    let val (e, ty') = transExp (tenv, venv, lp, e)
    in checkTy (ty', ty, pos); e end

  (* Check call arguments pairwise against the formal types [tys]; the
   * argument count must match exactly. *)
  and transExpListCk (tenv: tenv, venv: venv, lp: bool, es: A.exp list, tys: TY.ty list,
                      pos: Pos.T) : exp list =
    let
      fun aux ([], [], res) = List.rev res
        | aux (e :: es, ty :: tys, res) =
            let val e = transExpCk (tenv, venv, lp, e, ty, pos)
            in aux (es, tys, e :: res) end
        | aux ([], _, _) = errPos pos ("too few function arguments")
        | aux (_, [], _) = errPos pos ("too many function arguments")
    in
      aux (es, tys, [])
    end

  and transExpCkInt (tenv: tenv, venv: venv, lp: bool, e: A.exp, pos: Pos.T): exp =
    transExpCk (tenv, venv, lp, e, TY.INT, pos)

  and transExpCkUnit (tenv: tenv, venv: venv, lp: bool, e: A.exp, pos: Pos.T): exp =
    transExpCk (tenv, venv, lp, e, TY.UNIT, pos)

  (* Check the fields of a record expression against the record type's
   * (label, type) list.  The supplied fields are sorted by label, the same
   * way transTy sorts record types, then matched pairwise.  Assumes
   * [idtys] is in that sorted order. *)
  and transFieldListCk (tenv: tenv, venv: venv, lp: bool,
                        flds: (S.symbol * A.exp * Pos.T) list,
                        idtys: (S.symbol * TY.ty) list, pos: Pos.T): exp list =
    let
      val flds = ListMergeSort.sort
                   (fn ((id, _, _), (id', _, _)) => String.> (S.name id, S.name id'))
                   flds
      fun aux ([], [], res) = List.rev res
        | aux ((id, ty) :: idtys, (id', e', pos') :: flds, res) =
            if S.eq (id, id') then
              let val e' = transExpCk (tenv, venv, lp, e', ty, pos')
              in aux (idtys, flds, e' :: res) end
            else errPos pos ("transFieldListCk: no field with label: " ^ S.name id)
        | aux ([], _, _) = errPos pos ("transFieldListCk: too many fields")
        | aux (_, [], _) = errPos pos ("transFieldListCk: too few fields")
    in
      aux (idtys, flds, [])
    end

  (* Process one declaration and return the extended environments.  Type
   * declarations produce no EX code (NONE).  Variable initialisers are
   * checked with break disallowed. *)
  and transDec (tenv: tenv, venv: venv, d: A.dec): tenv * venv * dec option =
    case d of
        A.FunctionDec fds => transFunctionDecList (tenv, venv, fds)
      | A.VarDec {name= id, escape= esc, typ= otp, init= init, pos= pos} =>
          (case otp of
               NONE =>
                 let
                   val (init, ty) = transExp (tenv, venv, false, init)
                   (* Without an annotation the variable takes the type of
                    * its initialiser, and nil does not determine a record
                    * type. *)
                   val _ = case ty of
                               TY.NIL => errPos pos ("nil initializer requires a type annotation: "
                                                     ^ S.name id)
                             | _ => ()
                   val venv = S.enter (venv, id, E.VarEntry {ty= ty})
                 in
                   (tenv, venv, SOME (EX.VarDec (id, init, !esc)))
                 end
             | SOME (tyid, typos) =>
                 let
                   val ty = findTy (tenv, tyid, typos)
                   val init = transExpCk (tenv, venv, false, init, ty, pos)
                   val venv = S.enter (venv, id, E.VarEntry {ty= ty})
                 in
                   (tenv, venv, SOME (EX.VarDec (id, init, !esc)))
                 end)
      | A.TypeDec idtps =>
          let
            (* Pass 1: bind every name in the group to a NAME type with an
             * empty link, so the definitions can refer to each other. *)
            fun bind ([], pending, tenv) = (pending, tenv)
              | bind ({name= id, ty= t, pos= p} :: idtps, pending, tenv) =
                  let
                    val rot = ref NONE
                    val tenv = S.enter (tenv, id, TY.NAME (id, rot))
                  in
                    bind (idtps, (id, rot, t, p) :: pending, tenv)
                  end
            val (pending, tenv) = bind (idtps, [], tenv)
            (* Pass 2: translate each right-hand side in the extended tenv
             * and fill in its link. *)
            val _ = List.app (fn (_, rot, t, _) => rot := SOME (transTy (tenv, t))) pending
            (* Pass 3: reject alias chains that loop back on themselves
             * (e.g. type a = b  type b = a); tyLinkRemove would otherwise
             * recurse forever on them.  Links are compared by ref
             * identity. *)
            fun checkCycle (id, rot, _, p) =
              let
                fun walk (r, seen) =
                  if List.exists (fn r' => r' = r) seen
                  then errPos p ("illegal cycle in type declaration: " ^ S.name id)
                  else
                    case !r of
                        SOME (TY.NAME (_, r')) => walk (r', r :: seen)
                      | _ => ()
              in
                walk (rot, [])
              end
            val _ = List.app checkCycle pending
          in
            (tenv, venv, NONE)
          end

  (* Process a group of mutually recursive function declarations.  First
   * every signature is entered into venv; then every body is checked
   * against its declared result type with its parameters in scope. *)
  and transFunctionDecList (tenv: tenv, venv: venv, fds: A.fundec list)
      : tenv * venv * dec option =
    let
      (* Resolve a parameter list into its formal types, its (id, type)
       * bindings and its (id, escape) flags, all in declaration order. *)
      fun transParams params =
        let
          fun loop ([], tys, idtys, idescs) =
                (List.rev tys, List.rev idtys, List.rev idescs)
            | loop (fld :: flds, tys, idtys, idescs) =
                let
                  val {name= id, escape= esc, typ= tyid, pos= pos} = fld
                  val ty = findTy (tenv, tyid, pos)
                in
                  loop (flds, ty :: tys, (id, ty) :: idtys, (id, !esc) :: idescs)
                end
        in
          loop (params, [], [], [])
        end
      fun header ([], bats, venv) = (List.rev bats, venv)
        | header (fd :: fds, bats, venv) =
            let
              val {name= fid, params= params, result= result, body= body, pos= pos} = fd
              val (tys, idtys, idescs) = transParams params
              val result = case result of
                               NONE => TY.UNIT
                             | SOME (tyid, typos) => findTy (tenv, tyid, typos)
              val venv = S.enter (venv, fid, E.FunEntry {formals= tys, result= result})
            in
              header (fds, (fid, idtys, idescs, body, result, pos) :: bats, venv)
            end
      val (bats, venv) = header (fds, [], venv)
      (* Bodies are checked with break disallowed.  Results are prepended,
       * so the FunDec list comes out in reverse declaration order. *)
      fun bodies ([], res) = res
        | bodies ((fid, idtys, idescs, body, result, pos) :: bats, res) =
            let
              val bodyEnv = List.foldr (fn ((id, ty), acc) => S.enter (acc, id, E.VarEntry {ty= ty}))
                                       venv idtys
              val body = transExpCk (tenv, bodyEnv, false, body, result, pos)
            in
              bodies (bats, (fid, idescs, body) :: res)
            end
      val fds = bodies (bats, [])
    in
      (tenv, venv, SOME (EX.FunDec fds))
    end

  (* Process declarations left to right, threading the environments
   * through and keeping only the declarations that produce EX code. *)
  and transDecs (tenv: tenv, venv: venv, ds: A.dec list): tenv * venv * dec list =
    let
      fun aux (tenv, venv, [], res) = (tenv, venv, List.rev res)
        | aux (tenv, venv, d :: ds, res) =
            let val (tenv, venv, od) = transDec (tenv, venv, d)
            in
              case od of
                  NONE => aux (tenv, venv, ds, res)
                | SOME d => aux (tenv, venv, ds, d :: res)
            end
    in
      aux (tenv, venv, ds, [])
    end

  (* Typecheck a whole program, starting from empty environments and with
   * break disallowed at top level. *)
  fun transProg (e: A.exp): exp * ty = transExp (S.empty, S.empty, false, e)
end