module SearchTree ( insert, lookupSearchTree ) where

import Maybe

-- | A binary search tree mapping keys to values.
--
-- 'Leaf' is the empty tree.  @Branch l k v r@ stores the binding of key @k@
-- to value @v@.  All keys in @l@ are smaller than @k@ and all keys in @r@ are
-- greater than @k@.  This is the ordering that 'insert' maintains and that
-- 'lookupSearchTree' relies on.
--
-- NOTE(review): the module's export list exports neither this type nor
-- 'Leaf'.  Client modules therefore cannot obtain an initial tree to pass to
-- 'insert'.  Consider exporting @SearchTree@ together with an @empty@ value.
data SearchTree key val
  = Leaf
  | Branch (SearchTree key val) key val (SearchTree key val)
  deriving (Eq, Show)

-- | A perfectly balanced sample tree of height three.  It holds the odd keys
-- 1, 3, ..., 13, each mapped to itself.  It serves as the fixture for the
-- tests in this module.
--
-- The explicit signature is the type the monomorphism restriction previously
-- defaulted this binding to.  It silences -Wmissing-signatures and
-- -Wtype-defaults.
testTree :: SearchTree Integer Integer
testTree = node (node (single 1) 3 (single 5))
                7
                (node (single 9) 11 (single 13))
 where
  -- every node in the fixture maps its key to itself
  node l k r = Branch l k k r
  single k   = node Leaf k Leaf

-- | Insertion in continuation-passing style, an alternative to 'insert'.
--
-- Walks down from the root and compares the new key with each node using
-- both '==' and '<'.  The continuation records how to rebuild the path back
-- up to the root.  It is applied to the new subtree in one of two cases:
-- either the key is found, and its binding is replaced, or a 'Leaf' is
-- reached, and a fresh node is created there.  Passing @id@ as the
-- continuation yields the updated tree.
--
-- Not exported and not used anywhere else in this module.
insertC :: Ord key
        => key -> val -> SearchTree key val
        -> (SearchTree key val -> SearchTree key val)
        -> SearchTree key val
insertC newKey newVal tree rebuild = case tree of
  Leaf -> rebuild (Branch Leaf newKey newVal Leaf)
  Branch l k v r
    | newKey == k -> rebuild (Branch l newKey newVal r)
    | newKey < k  ->
        insertC newKey newVal l (\l' -> rebuild (Branch l' k v r))
    | otherwise   ->
        insertC newKey newVal r (\r' -> rebuild (Branch l k v r'))

-- | @insert key val tree@ binds @key@ to @val@ in @tree@.  If @key@ is
-- already present, its value is replaced; otherwise a new node is added.
--
-- An empty tree simply becomes a single node.  For a non-empty tree the work
-- is handed to 'insert''.  The root's key is the first candidate for a key
-- equal to @key@.  The root, rebuilt with @key@ and @val@, is the result to
-- return if that candidate matches.
insert :: Ord key => key -> val -> SearchTree key val -> SearchTree key val
insert key val tree = case tree of
  Leaf                 -> Branch Leaf key val Leaf
  Branch l rootKey _ r -> insert' key val tree rootKey (Branch l key val r) id

-- | Worker for 'insert'.  It makes a single '<' comparison per level on the
-- way down and defers the equality test to the end.
--
-- @insert' key val tree cand onMatch rebuild@ descends @tree@ as follows:
--
-- * @cand@ is the key of the deepest node seen so far at which the search
--   went right, i.e. @key@ was not less than it.  Until such a node is seen
--   it is the root's key.  That is harmless: a search that goes left at the
--   root has @key@ smaller than the root key.  @cand@ is the only visited
--   key that can still equal @key@.
--
-- * @onMatch@ is the complete result for the case @key == cand@: the whole
--   tree with @cand@'s node now holding @key@ and @val@.  It is only forced
--   in that case.
--
-- * @rebuild@ reassembles the path above the current subtree into a whole
--   tree.
--
-- When a 'Leaf' is reached, one '==' test against @cand@ decides the
-- outcome.  On a match the existing binding is replaced (@onMatch@).
-- Otherwise a new node is hung at this position.
insert' :: Ord key
        => key -> val -> SearchTree key val 
        -> key -> SearchTree key val 
        -> (SearchTree key val -> SearchTree key val)
        -> SearchTree key val
insert' key val Leaf cand onMatch rebuild
  | key == cand = onMatch
  | otherwise   = rebuild (Branch Leaf key val Leaf)
insert' key val (Branch l k v r) cand onMatch rebuild
  | key < k   =
      -- k > key, so k cannot match; keep the current candidate
      insert' key val l cand onMatch (\l' -> rebuild (Branch l' k v r))
  | otherwise =
      -- key >= k: k becomes the candidate; prepare its replacement lazily
      insert' key val r k
              (rebuild (Branch l key val r))
              (\r' -> rebuild (Branch l k v r'))


-- | All insertion tests.
--
-- This runs 'insertTest1' and 'insertTest2' and additionally checks
-- overwriting.  Inserting an existing key with a /different/ value must
-- replace the stored value without changing the set of keys.
-- 'insertTest2' re-inserts identical values, so on its own it cannot tell a
-- replacement from a no-op.
insertTest :: Bool
insertTest = insertTest1 && insertTest2 && overwrites
 where
  overwrites = and [ replaces n | n <- [1,3..13] ]
  replaces n =
    let updated = insert n (negate n) testTree
    in  lookupSearchTree n updated == Just (negate n)
          && map fst (searchTreeToList updated)
               == map fst (searchTreeToList testTree)

-- | Inserting the even keys 0..14 into 'testTree', which holds the odd keys
-- 1..13, must yield every key from 0 to 14 in ascending order, each mapped
-- to itself.
insertTest1 :: Bool
insertTest1 = searchTreeToList extended == [ (x, x) | x <- [0..14] ]
 where
  extended = foldr (\n -> insert n n) testTree [0,2..14]
-- | Re-inserting every key already in 'testTree' with its existing value
-- must leave the tree structurally unchanged.
insertTest2 :: Bool
insertTest2 = foldr (\n -> insert n n) testTree [1,3..13] == testTree


-- | @lookupSearchTree key tree@ returns the value bound to @key@ in @tree@,
-- or 'Nothing' if the key is absent.
--
-- An empty tree has no bindings.  Otherwise the search is delegated to
-- 'lookup'', which starts with the root's binding as the candidate match.
lookupSearchTree :: Ord key => key -> SearchTree key val -> Maybe val
lookupSearchTree key tree = case tree of
  Leaf           -> Nothing
  Branch _ k v _ -> lookup' key tree k v

-- | Worker for 'lookupSearchTree'.  It makes a single '<' comparison per
-- level.
--
-- In @lookup' key tree cand candVal@, @cand@ and @candVal@ are the binding
-- of the deepest node so far at which the search went right.  Initially
-- this is the root's binding.  Only that key can still equal @key@, so one
-- '==' test at the 'Leaf' that ends the search decides the result.
lookup' :: Ord key => key -> SearchTree key val -> key -> val -> Maybe val
lookup' key Leaf cand candVal
  | key == cand = Just candVal
  | otherwise   = Nothing
lookup' key (Branch l k v r) cand candVal
  | key < k   = lookup' key l cand candVal
  | otherwise = lookup' key r k v


-- | Lookups in 'testTree'.  The even keys 0..14 must be absent.  Each odd
-- key 1..13 must be found /and/ map to itself.  Checking only 'isJust'
-- would accept a lookup that returned the wrong value.
lookupTest :: Bool
lookupTest = all (\k -> isNothing (lookupSearchTree k testTree)) [0,2..14]
          && all (\k -> lookupSearchTree k testTree == Just k) [1,3..13]

-- | Lists all bindings of a tree in order (left subtree, node, right
-- subtree).  For a valid search tree the keys therefore come out ascending.
--
-- The fold builds a function that prepends the bindings (a difference
-- list) instead of repeatedly appending with @(++)@.  That function is
-- finally applied to @[]@.
searchTreeToList :: SearchTree key val -> [(key,val)]
searchTreeToList tree = foldSearchTree id prependNode tree []
 where
  prependNode before k v after = before . ((k, v) :) . after

-- | Structural fold over a tree.  Every 'Leaf' is replaced by @leaf@.  Every
-- @Branch l k v r@ is replaced by @branch@ applied to the folded left
-- subtree, the node's key and value, and the folded right subtree.
foldSearchTree :: b -> (b -> key -> val -> b -> b) -> SearchTree key val -> b
foldSearchTree leaf branch = go
 where
  go Leaf             = leaf
  go (Branch l k v r) = branch (go l) k v (go r)

-- | Runs the insertion and lookup test groups.  It is 'True' exactly when
-- both pass.
testAll :: Bool
testAll = insertTest && lookupTest