-- | Binary search trees mapping ordered keys to values.
--
-- 'SearchTree' is exported abstractly (its constructors stay hidden), so
-- client code can name the type in signatures while trees can still only be
-- built through the exported operations, e.g. @listToSearchTree []@ for an
-- empty tree.
module SearchTree (

  SearchTree,
  insert, lookupSearchTree, delete, listToSearchTree, searchTreeToList,
  mapSearchTree, foldSearchTree

  ) where

-- | A binary search tree from @key@ to @val@.  'insert' and 'delete' keep
-- every key in a branch's left subtree smaller than the branch key and every
-- key in its right subtree larger.  The fields carry no strictness
-- annotations, so positions that are never inspected may hold 'undefined'.
data SearchTree key val
  = Leaf
    -- ^ The empty tree.
  | Branch (SearchTree key val) key val (SearchTree key val)
    -- ^ Left subtree, key, value, right subtree.
  deriving (Eq)

-- | Renders a tree as the indented outline produced by 'showsSearchTree',
-- starting at indentation 0; the precedence argument is ignored.
instance (Show key, Show val) => Show (SearchTree key val) where
  showsPrec _precedence tree = showsSearchTree 0 tree

-- | Insert a key/value pair.  If an equal key is already present, both the
-- stored key and its value are replaced by the new ones; the old value is
-- discarded without being evaluated.
insert :: Ord key => key -> val -> SearchTree key val -> SearchTree key val
insert key val = go
 where
  go Leaf = Branch Leaf key val Leaf
  go (Branch left k v right) = case compare key k of
    EQ -> Branch left key val right
    LT -> Branch (go left) k v right
    GT -> Branch left k v (go right)

-- | Unit checks for 'insert': insertion to the right, insertion to the left,
-- and replacement of an existing key.  'insertTest3' also checks that the
-- replaced value (here 'undefined') is never evaluated.
insertTest, insertTest1, insertTest2, insertTest3 :: Bool
insertTest = insertTest1 && insertTest2 && insertTest3
insertTest1 = insert 1 True (Branch Leaf 0 False Leaf)
           == Branch Leaf 0 False (Branch Leaf 1 True Leaf)
insertTest2 = insert 0 True (Branch Leaf 1 False Leaf)
           == Branch (Branch Leaf 0 True Leaf) 1 False Leaf
insertTest3 = insert 0 True (Branch Leaf 0 undefined Leaf)
           == Branch Leaf 0 True Leaf

-- | Find the value stored under a key, or 'Nothing' if the key is absent.
-- Only the subtrees along the search path are inspected.
lookupSearchTree :: Ord key => key -> SearchTree key val -> Maybe val
lookupSearchTree needle = go
 where
  go Leaf = Nothing
  go (Branch left k v right) = case compare needle k of
    LT -> go left
    EQ -> Just v
    GT -> go right

-- | Unit checks for 'lookupSearchTree'.  The subtrees and values set to
-- 'undefined' lie off the search path, so these also check that a lookup
-- touches only the nodes it follows.  The @Maybe Bool@ annotations fix the
-- otherwise ambiguous value type.
lookupTest, lookupTest1, lookupTest2, lookupTest3 :: Bool
lookupTest = lookupTest1 && lookupTest2 && lookupTest3
lookupTest1 = lookupSearchTree 1 (Branch undefined 0 undefined Leaf)
           == (Nothing::Maybe Bool)
lookupTest2 = lookupSearchTree 0 (Branch Leaf 1 undefined undefined)
           == (Nothing::Maybe Bool)
lookupTest3 = lookupSearchTree 0 (Branch undefined 0 True undefined)
           == Just True

-- | Remove a key and its value; deleting an absent key yields an equal tree.
-- A matched node whose left subtree is empty is replaced by its right
-- subtree.  Otherwise the node takes over the largest entry of its left
-- subtree, which preserves the ordering of keys.
delete :: Ord key => key -> SearchTree key val -> SearchTree key val
delete key = go
 where
  go Leaf = Leaf
  go (Branch left k v right) = case compare key k of
    LT -> Branch (go left) k v right
    GT -> Branch left k v (go right)
    EQ -> case left of
      Leaf -> right
      _    -> let (left', k', v') = exposeMax left
              in Branch left' k' v' right

-- | Unit checks for 'delete': deleting absent keys on either side of the
-- root, deleting a lone root, and deleting a root whose left subtree must
-- supply its maximum entry as the replacement.  The deleted values are
-- 'undefined', which checks that 'delete' never evaluates them.
deleteTest, deleteTest1, deleteTest2, deleteTest3, deleteTest4 :: Bool
deleteTest = deleteTest1 && deleteTest2 && deleteTest3 && deleteTest4
deleteTest1 = delete 1 (Branch Leaf 0 False Leaf)
           == Branch Leaf 0 False Leaf
deleteTest2 = delete 0 (Branch Leaf 1 False Leaf)
           == Branch Leaf 1 False Leaf
deleteTest3 = delete 0 (Branch Leaf 0 undefined Leaf)
           == (Leaf::SearchTree Int Bool)
deleteTest4 = delete 2 (Branch (Branch Leaf 0 True (Branch Leaf 1 False Leaf))
                          2 undefined Leaf)
           == Branch (Branch Leaf 0 True Leaf) 1 False Leaf

-- | 'True' exactly for the empty tree 'Leaf'.
isLeaf :: SearchTree key val -> Bool
isLeaf tree = case tree of
  Leaf     -> True
  Branch{} -> False

-- | Split off the entry with the largest key.  Returns the tree without that
-- entry, together with its key and value.  The largest key sits at the end
-- of the rightmost spine, so the function follows right children until the
-- right child is a 'Leaf'.
--
-- Defined only on non-empty trees; an empty tree is a caller error and is
-- reported with an explicit message instead of a pattern-match failure.
exposeMax :: SearchTree key val -> (SearchTree key val,key,val)
exposeMax Leaf = error "SearchTree.exposeMax: empty tree has no maximum"
exposeMax (Branch left key val Leaf) = (left,key,val)
exposeMax (Branch left key val right)
  = (Branch left key val newright, maxkey, maxval)
 where
  (newright,maxkey,maxval) = exposeMax right

-- | Unit check for 'exposeMax': the maximum is taken from the right spine.
exposeMaxTest :: Bool
exposeMaxTest = exposeMax (Branch Leaf 0 False (Branch Leaf 1 True Leaf))
             == (Branch Leaf 0 False Leaf,1,True)

-- | Build a tree from key/value pairs.  The list is inserted from its last
-- element to its first, so for duplicate keys the earliest entry wins.
listToSearchTree :: Ord key => [(key,val)] -> SearchTree key val
listToSearchTree = build
 where
  build [] = Leaf
  build (entry : rest) = uncurry insert entry (build rest)

-- | Unit checks for 'listToSearchTree'.  'fromListTest1' checks that the
-- earliest duplicate wins (the later value is 'undefined' and must be
-- discarded).  The other two check that insertion order (last element
-- first) determines the tree shape.
fromListTest, fromListTest1, fromListTest2, fromListTest3 :: Bool
fromListTest = fromListTest1 && fromListTest2 && fromListTest3
fromListTest1 = listToSearchTree [(0,False),(0,undefined)]
             == Branch Leaf 0 False Leaf
fromListTest2 = listToSearchTree [(1,True),(0,False)]
             == Branch Leaf 0 False (Branch Leaf 1 True Leaf)
fromListTest3 = listToSearchTree [(0,False),(1,True)] 
             == Branch (Branch Leaf 0 False Leaf) 1 True Leaf

-- | Flatten a tree into its key/value pairs in in-order (left subtree, node,
-- right subtree) order, which is ascending key order for trees built with
-- 'insert'.  The fold produces a difference list, so every node is prepended
-- in constant time and the traversal is linear.  No 'Ord' constraint is
-- needed because keys are never compared.
searchTreeToList :: SearchTree key val -> [(key,val)]
searchTreeToList st = foldSearchTree id (\l k v r -> l . ((k,v):) . r) st []

-- | Unit checks for 'searchTreeToList': a single node, and a three-node tree
-- whose output must follow in-order (ascending key) order.
toListTest, toListTest1, toListTest2 :: Bool
toListTest = toListTest1 && toListTest2
toListTest1 = searchTreeToList (Branch Leaf 0 False Leaf)
           == [(0,False)]
toListTest2 = searchTreeToList (Branch (Branch Leaf 0 False Leaf) 1 True
                                       (Branch Leaf 2 False Leaf))
           == [(0,False),(1,True),(2,False)]

-- | Structural fold over a tree: every 'Leaf' becomes @leaf@ and every
-- 'Branch' becomes @branch@ applied to the folded left subtree, the key, the
-- value and the folded right subtree.
foldSearchTree :: res -> (res -> key -> val -> res -> res)
               -> SearchTree key val -> res
foldSearchTree onLeaf onBranch = go
 where
  go Leaf = onLeaf
  go (Branch left key val right) = onBranch (go left) key val (go right)

-- | Unit check for 'foldSearchTree': folding with the constructors
-- themselves rebuilds the same tree.
foldTest :: Bool
foldTest = foldSearchTree Leaf Branch (Branch Leaf 0 False Leaf)
        == Branch Leaf 0 False Leaf

-- | Apply a function to every value, keeping keys and tree shape unchanged.
mapSearchTree :: (val1 -> val2) -> SearchTree key val1 -> SearchTree key val2
mapSearchTree _ Leaf = Leaf
mapSearchTree f (Branch left key val right)
  = Branch (mapSearchTree f left) key (f val) (mapSearchTree f right)

-- | Unit check for 'mapSearchTree': values are transformed, keys are kept.
mapTest :: Bool
mapTest = mapSearchTree not (Branch Leaf 0 False Leaf)
       == Branch Leaf 0 True Leaf

-- | 'True' iff every unit check in this module passes.
testAll :: Bool
testAll = insertTest && lookupTest && deleteTest && exposeMaxTest
       && fromListTest && toListTest && foldTest && mapTest

-- | Render a tree as an indented outline, one line per node, starting at the
-- given indentation.  A branch prints as @- key->value@ and an empty subtree
-- prints as @-@.  Children are indented two further spaces, left before
-- right.  They are omitted entirely when both subtrees of a node are empty.
showsSearchTree :: (Show key, Show val)
                => Int -> SearchTree key val -> String -> String
showsSearchTree depth Leaf = indent depth . showString "-\n"
showsSearchTree depth (Branch left key val right)
  | isLeaf left && isLeaf right = node
  | otherwise                   = node . children
 where
  node = indent depth
       . showString ("- " ++ show key ++ "->" ++ show val ++ "\n")
  children = showsSearchTree (depth + 2) left
           . showsSearchTree (depth + 2) right

-- | Prefix a string with the given number of spaces.
indent :: Int -> String -> String
indent width rest = replicate width ' ' ++ rest

