(*
 * This example shows that the insert operation maps a balanced
 * red-black tree into a balanced red-black tree. It also increases
 * the size of the tree by at most one (note: the inserted key may
 * already exist in the tree, in which case the original tree is
 * returned unchanged and the existing datum is kept).
 *)

(* 8 type annotations, which occupy about 20 lines *)

structure RedBlackTree =
struct

  type key = int
  type answer = key option

  datatype order = LESS | EQUAL | GREATER

  type 'a entry = int * 'a

  datatype 'a dict =
      Empty (* considered black *)
    | Black of 'a entry * 'a dict * 'a dict
    | Red of 'a entry * 'a dict * 'a dict

  typeref 'a dict of nat * nat * nat * nat
  (* The meaning of the four indices is:
   *   (color, black height, red height, size),
   * where color is 0 for black nodes (including Empty) and 1 for red nodes.
   * A balanced tree is one such that
   * (1) for every node in it, both of its sons are of the same black height;
   * (2) the red height of the tree is 0, which means that there exist
   *     no consecutive red nodes.
   *)
  with Empty <| 'a dict(0, 0, 0, 0)
     | Black <| {cl:nat, cr:nat, bh:nat, sl:nat, sr:nat}
                'a entry * 'a dict(cl, bh, 0, sl) * 'a dict(cr, bh, 0, sr)
                -> 'a dict(0, bh+1, 0, sl+sr+1)
     | Red <| {cl:nat, cr:nat, bh:nat, rhl:nat, rhr:nat, sl:nat, sr:nat}
              'a entry * 'a dict(cl, bh, rhl, sl) * 'a dict(cr, bh, rhr, sr)
              -> 'a dict(1, bh, cl+cr+rhl+rhr, sl+sr+1)

  (* Three-way comparison of two integer keys. *)
  fun compare (s1, s2) =
      if s1 > s2 then GREATER
      else if s1 < s2 then LESS
      else EQUAL
  where compare <| int * int -> order

  (* lookup dict key: returns SOME key1 if some node of dict stores a key
   * equal to key (key1 being the stored key), and NONE otherwise.
   * Only the key is returned, not the associated datum. *)
  fun('a) lookup dict key =
      let
        fun lk (Empty) = NONE
          | lk (Red tree) = lk' tree
          | lk (Black tree) = lk' tree
        where lk <| 'a dict -> answer

        and lk' ((key1, _), left, right) =
            (case compare (key, key1) of
                 EQUAL => SOME (key1)
               | LESS => lk left
               | GREATER => lk right)
        where lk' <| 'a entry * 'a dict * 'a dict -> answer
      in
        lk dict
      end
  where lookup <| 'a dict -> key -> answer

  (* restore_right (e, l, r): rebuilds a black node with entry e whose right
   * subtree r may have a red-red violation at its root (red height <= 1),
   * while l is a valid red/black tree. The result has red height 0 and
   * black height bh+1. When both sons are red and r has a red son, the
   * node is re-colored; otherwise a red-red violation is fixed by rotation. *)
  fun('a) restore_right (e, Red lt, Red (rt as (_, Red _, _))) =
        Red (e, Black lt, Black rt)
    | restore_right (e, Red lt, Red (rt as (_, _, Red _))) =
        Red (e, Black lt, Black rt)
    | restore_right (e, l as Empty, Red (re, Red (rle, rll, rlr), rr)) =
        Black (rle, Red (e, l, rll), Red (re, rlr, rr))
    | restore_right (e, l as Black _, Red (re, Red (rle, rll, rlr), rr)) =
        Black (rle, Red (e, l, rll), Red (re, rlr, rr))
    | restore_right (e, l as Empty, Red (re, rl, rr as Red _)) =
        Black (re, Red (e, l, rl), rr)
    | restore_right (e, l as Black _, Red (re, rl, rr as Red _)) =
        Black (re, Red (e, l, rl), rr)
    | restore_right (e, l, r as Red (_, Empty, Empty)) =
        Black (e, l, r) (* r must be a red/black tree *)
    | restore_right (e, l, r as Red (_, Black _, Black _)) =
        Black (e, l, r) (* r must be a red/black tree *)
    | restore_right (e, l, r as Black _) =
        Black (e, l, r) (* r must be a red/black tree *)
  where restore_right <| {cl:nat, cr:nat, bh:nat, rhr:nat, sl:nat, sr:nat | rhr <= 1}
        'a entry * 'a dict(cl, bh, 0, sl) * 'a dict(cr, bh, rhr, sr)
        -> [c:nat | c <= 1] 'a dict(c, bh+1, 0, sl + sr + 1)

  (* restore_left (e, l, r): mirror image of restore_right. The left
   * subtree l may have a red-red violation at its root (red height <= 1),
   * while r is a valid red/black tree. The result has red height 0 and
   * black height bh+1. *)
  fun('a) restore_left (e, Red (lt as (_, Red _, _)), Red rt) =
        Red (e, Black lt, Black rt)
    | restore_left (e, Red (lt as (_, _, Red _)), Red rt) =
        Red (e, Black lt, Black rt)
    | restore_left (e, Red (le, ll as Red _, lr), r as Empty) =
        Black (le, ll, Red (e, lr, r))
    | restore_left (e, Red (le, ll as Red _, lr), r as Black _) =
        Black (le, ll, Red (e, lr, r))
    | restore_left (e, Red (le, ll, Red (lre, lrl, lrr)), r as Empty) =
        Black (lre, Red (le, ll, lrl), Red (e, lrr, r))
    | restore_left (e, Red (le, ll, Red (lre, lrl, lrr)), r as Black _) =
        Black (lre, Red (le, ll, lrl), Red (e, lrr, r))
    | restore_left (e, l as Red (_, Empty, Empty), r) =
        Black (e, l, r) (* l must be a red/black tree *)
    | restore_left (e, l as Red (_, Black _, Black _), r) =
        Black (e, l, r) (* l must be a red/black tree *)
    | restore_left (e, l as Black _, r) =
        Black (e, l, r) (* l must be a red/black tree *)
  where restore_left <| {cl:nat, cr:nat, bh:nat, rhl:nat, sl:nat, sr:nat | rhl <= 1}
        'a entry * 'a dict(cl, bh, rhl, sl) * 'a dict(cr, bh, 0, sr)
        -> [c:nat | c <= 1] 'a dict(c, bh+1, 0, sl + sr + 1)

  (* Raised internally by insertExc when the key is already present. *)
  exception Item_Is_Found

  (* insertExc (dict, entry): inserts entry into the balanced tree dict.
   * If the key of entry already occurs in dict, dict itself is returned
   * (the existing datum is kept). The result is balanced; its black
   * height is bh or bh+1 and its size is s or s+1. *)
  fun('a) insertExc (dict, entry as (key, _)) =
      let
        (* val ins : 'a dict -> 'a dict inserts entry *)
        (* ins (Red _) may violate color invariant at root, having red height 1 *)
        (* ins (Black _) or ins (Empty) will be red/black tree *)
        (* ins preserves black height *)
        (* ins raises Item_Is_Found when key is already in the tree *)
        fun ins (Empty) = Red (entry, Empty, Empty)
          | ins (Red (entry1 as (key1, _), left, right)) =
            (case compare (key, key1) of
                 EQUAL => raise Item_Is_Found
               | LESS => Red (entry1, ins left, right)
               | GREATER => Red (entry1, left, ins right))
          | ins (Black (entry1 as (key1, _), left, right)) =
            (case compare (key, key1) of
                 EQUAL => raise Item_Is_Found
               | LESS => restore_left (entry1, ins left, right)
               | GREATER => restore_right (entry1, left, ins right))
        where ins <| {c:nat, bh:nat, s:nat}
                     'a dict(c, bh, 0, s)
                     -> [nc:nat, nrh:nat | ((c = 0 /\ nrh = 0 /\ nc <= 1) \/
                                            (c = 1 /\ nrh <= 1 /\ nc = 1))]
                        'a dict(nc, bh, nrh, s+1)
      in
        let
          val newDict = ins dict
        in
          (* A red root with a red son is the only possible violation left;
           * re-coloring the root black fixes it and raises the black height. *)
          case newDict of
              Red (t as (_, Red _, _)) => Black t (* re-color *)
            | Red (t as (_, _, Red _)) => Black t (* re-color *)
            | Red (_, Black _, Black _) => newDict
            | Red (_, Empty, Empty) => newDict
            | Black _ => newDict
        end
        (* duplicate key: return the original, unmodified tree *)
        handle Item_Is_Found => dict
      end
  where insertExc <| {c:nat, bh:nat, s:nat}
        'a dict(c, bh, 0, s) * 'a entry
        -> [nc:nat, nbh:nat, ns:nat | (nbh = bh \/ nbh = bh + 1) /\ (ns = s \/ ns = s+1)]
           'a dict(nc, nbh, 0, ns)

end