-- Binary search trees with Int keys and arbitrary values.
module SearchTree (

  -- The tree type is exported abstractly (without its constructors) so that
  -- client modules can mention it in their own type signatures while still
  -- having to go through the functions below to build or inspect trees.
  SearchTree,

  insert, lookupSearchTree, delete, listToSearchTree, searchTreeToList,
  mapSearchTree, foldSearchTree

  ) where

-- | A binary tree whose branches carry an 'Int' key and a value of type @a@.
--
-- The fields of 'Branch' are, in order: left subtree, key, value, right
-- subtree.  All fields are lazy.  The derived 'Eq' instance is structural,
-- i.e. two trees are equal only if they have the same shape and contents.
data SearchTree a
  = Leaf
  | Branch (SearchTree a) Int a (SearchTree a)
  deriving (Eq, Show)

-- | Insert a key/value pair into a tree.
--
-- Descends to the left for keys smaller than the key at the current branch
-- and to the right for larger ones.  If the key is already present, the
-- stored value is replaced by the new one (the old value is discarded
-- without being evaluated); otherwise a new branch is created in place of
-- the leaf that was reached.
insert :: Int -> a -> SearchTree a -> SearchTree a
insert key val tree = case tree of
  Leaf -> Branch Leaf key val Leaf
  Branch left k v right -> case compare key k of
    EQ -> Branch left key val right
    LT -> Branch (insert key val left) k v right
    GT -> Branch left k v (insert key val right)

-- | 'True' iff all 'insert' checks below succeed.
insertTest :: Bool
insertTest = insertTest1 && insertTest2 && insertTest3

-- | A key larger than the root key ends up in the right subtree.
insertTest1 :: Bool
insertTest1 = insert 1 True (Branch Leaf 0 False Leaf)
           == Branch Leaf 0 False (Branch Leaf 1 True Leaf)

-- | A key smaller than the root key ends up in the left subtree.
insertTest2 :: Bool
insertTest2 = insert 0 True (Branch Leaf 1 False Leaf)
           == Branch (Branch Leaf 0 True Leaf) 1 False Leaf

-- | Inserting an existing key replaces its value; the old value is
-- 'undefined' here, so the check would crash if it were ever evaluated.
insertTest3 :: Bool
insertTest3 = insert 0 True (Branch Leaf 0 undefined Leaf)
           == Branch Leaf 0 True Leaf

-- | Look up the value stored under a key.
--
-- Returns 'Nothing' when the search runs into a 'Leaf', and @'Just' v@ when
-- a branch with an equal key is found.  Only the subtree on the side
-- indicated by the key comparison is visited.
lookupSearchTree :: Int -> SearchTree a -> Maybe a
lookupSearchTree key = go
 where
  go Leaf = Nothing
  go (Branch left k v right) = case compare key k of
    EQ -> Just v
    LT -> go left
    GT -> go right

-- | 'True' iff all 'lookupSearchTree' checks below succeed.
lookupTest :: Bool
lookupTest = lookupTest1 && lookupTest2 && lookupTest3

-- | A key larger than the root is searched for in the right subtree only;
-- the left subtree and the root value are 'undefined' and must stay
-- unevaluated.
lookupTest1 :: Bool
lookupTest1 = lookupSearchTree 1 (Branch undefined 0 undefined Leaf)
           == (Nothing::Maybe Bool)

-- | A key smaller than the root is searched for in the left subtree only;
-- the root value and the right subtree are 'undefined'.
lookupTest2 :: Bool
lookupTest2 = lookupSearchTree 0 (Branch Leaf 1 undefined undefined)
           == (Nothing::Maybe Bool)

-- | A key equal to the root key yields the root value without touching
-- either subtree (both are 'undefined').
lookupTest3 :: Bool
lookupTest3 = lookupSearchTree 0 (Branch undefined 0 True undefined)
           == Just True

-- | Remove the entry with the given key, if there is one.
--
-- A tree that does not contain the key is rebuilt unchanged.  When the key
-- is found at a branch whose left subtree is empty, the branch is replaced
-- by its right subtree.  Otherwise the rightmost entry of the left subtree
-- (obtained with 'exposeMax') takes the place of the deleted entry.
delete :: Int -> SearchTree a -> SearchTree a
delete key tree = case tree of
  Leaf -> Leaf
  Branch left k v right -> case compare key k of
    LT -> Branch (delete key left) k v right
    GT -> Branch left k v (delete key right)
    EQ
      | isLeaf left -> right
      | otherwise   -> Branch prunedLeft predKey predVal right
      where
        -- 'exposeMax' is only demanded in the second guard, i.e. when
        -- 'left' is known to be a branch.
        (prunedLeft, predKey, predVal) = exposeMax left

-- | 'True' iff all 'delete' checks below succeed.
deleteTest :: Bool
deleteTest = deleteTest1 && deleteTest2 && deleteTest3 && deleteTest4

-- | Deleting an absent key larger than the root leaves the tree as it was.
deleteTest1 :: Bool
deleteTest1 = delete 1 (Branch Leaf 0 False Leaf)
           == Branch Leaf 0 False Leaf

-- | Deleting an absent key smaller than the root leaves the tree as it was.
deleteTest2 :: Bool
deleteTest2 = delete 0 (Branch Leaf 1 False Leaf)
           == Branch Leaf 1 False Leaf

-- | Deleting the only entry yields the empty tree; the stored value is
-- 'undefined' and must not be evaluated.
deleteTest3 :: Bool
deleteTest3 = delete 0 (Branch Leaf 0 undefined Leaf)
           == (Leaf::SearchTree Bool)

-- | Deleting a root with a non-empty left subtree promotes the rightmost
-- entry of that subtree (here key 1) to the root.
deleteTest4 :: Bool
deleteTest4 = delete 2 (Branch (Branch Leaf 0 True (Branch Leaf 1 False Leaf))
                          2 undefined Leaf)
           == Branch (Branch Leaf 0 True Leaf) 1 False Leaf

-- | Test whether a tree is the empty tree 'Leaf'.
isLeaf :: SearchTree a -> Bool
isLeaf tree = case tree of
  Leaf     -> True
  Branch{} -> False

-- | Split off the rightmost entry of a non-empty tree.
--
-- The result is @(rest, key, val)@ where @key@ and @val@ belong to the
-- rightmost branch and @rest@ is the tree without that branch.  In a tree
-- ordered the way 'insert' builds it, the rightmost branch holds the
-- largest key.
--
-- Only defined on branches: a 'Leaf' has no entry to expose, so applying
-- the function to one is a programming error and aborts with a descriptive
-- message.
exposeMax :: SearchTree a -> (SearchTree a,Int,a)
exposeMax Leaf = error "SearchTree.exposeMax: empty tree has no maximum"
exposeMax (Branch left key val Leaf) = (left,key,val)
exposeMax (Branch left key val right)
  = (Branch left key val newright, maxkey, maxval)
 where
  -- the maximum lives in the non-empty right subtree; this branch is kept
  -- and only its right subtree is replaced by the pruned one
  (newright,maxkey,maxval) = exposeMax right

-- | 'exposeMax' removes the rightmost branch (key 1) and returns its key
-- and value together with the remaining tree.
exposeMaxTest :: Bool
exposeMaxTest = exposeMax (Branch Leaf 0 False (Branch Leaf 1 True Leaf))
             == (Branch Leaf 0 False Leaf,1,True)

-- | Build a tree from a list of key/value pairs.
--
-- The pairs are inserted with 'insert' starting from the end of the list
-- (right fold), so the last pair becomes the root and, for a key that
-- occurs several times, the value of its first occurrence in the list is
-- the one that remains in the tree.
listToSearchTree :: [(Int,a)] -> SearchTree a
listToSearchTree = foldr addEntry Leaf
 where
  -- projections instead of a tuple pattern keep the pair itself lazy
  addEntry entry tree = insert (fst entry) (snd entry) tree

-- | 'True' iff all 'listToSearchTree' checks below succeed.
fromListTest :: Bool
fromListTest = fromListTest1 && fromListTest2 && fromListTest3

-- | For a duplicated key the earlier list entry wins; the later value is
-- 'undefined' and must not be evaluated.
fromListTest1 :: Bool
fromListTest1 = listToSearchTree [(0,False),(0,undefined)]
             == Branch Leaf 0 False Leaf

-- | The last list entry (key 0) becomes the root, the larger key 1 goes
-- to its right.
fromListTest2 :: Bool
fromListTest2 = listToSearchTree [(1,True),(0,False)]
             == Branch Leaf 0 False (Branch Leaf 1 True Leaf)

-- | The last list entry (key 1) becomes the root, the smaller key 0 goes
-- to its left.
fromListTest3 :: Bool
fromListTest3 = listToSearchTree [(0,False),(1,True)]
             == Branch (Branch Leaf 0 False Leaf) 1 True Leaf

-- | List all key/value pairs of a tree in in-order sequence: the entries of
-- the left subtree, then the entry of the branch itself, then the entries
-- of the right subtree.
--
-- The tree is folded into a function that prepends its entries to a given
-- list; that function is finally applied to the empty list.
searchTreeToList :: SearchTree a -> [(Int,a)]
searchTreeToList st = foldSearchTree id inOrder st []
 where
  inOrder prependLeft k v prependRight rest
    = prependLeft ((k,v) : prependRight rest)

-- | A single-entry tree is converted to a single-element list.
toListTest :: Bool
toListTest = searchTreeToList (Branch Leaf 0 False Leaf)
          == [(0,False)]

-- | Structural fold over a tree.
--
-- Every 'Leaf' is replaced by the first argument and every 'Branch' by an
-- application of the second argument to the folded left subtree, the key,
-- the value and the folded right subtree.
foldSearchTree :: b -> (b -> Int -> a -> b -> b) -> SearchTree a -> b
foldSearchTree leaf branch = go
 where
  go Leaf                 = leaf
  go (Branch l key val r) = branch (go l) key val (go r)

-- | Folding with the constructors themselves rebuilds the same tree.
foldTest :: Bool
foldTest = foldSearchTree Leaf Branch (Branch Leaf 0 False Leaf)
        == Branch Leaf 0 False Leaf

-- | Apply a function to every value stored in a tree.  Keys and the shape
-- of the tree are left unchanged.
mapSearchTree :: (a -> b) -> SearchTree a -> SearchTree b
mapSearchTree f tree = foldSearchTree Leaf rebuild tree
 where
  rebuild left key val right = Branch left key (f val) right

-- | Mapping 'not' over a single-entry tree negates the stored value and
-- keeps the key.
mapTest :: Bool
mapTest = mapSearchTree not (Branch Leaf 0 False Leaf)
       == Branch Leaf 0 True Leaf

-- | 'True' iff every self-test in this module succeeds.
testAll :: Bool
testAll = insertTest && lookupTest && deleteTest && exposeMaxTest
       && fromListTest && toListTest && foldTest && mapTest
