module AdaptiveSearch where

import List ( nub, sortBy, group, (\\) )
import Maybe ( (>>-), mapMaybe )

--- A coverage set, represented as a list of Int identifiers.
type Coverage = [Int]

--- A fraction stored as (numerator, denominator).
--- Fractions are never reduced, so numerators and denominators grow
--- with every `plus`/`times`.
type Ratio = (Int,Int)

zero, one :: Ratio
zero = (0,1)
one  = (1,1)

--- Addition and multiplication of unreduced fractions.
plus, times :: Ratio -> Ratio -> Ratio
plus  (n1,d1) (n2,d2) = (n1*d2 + n2*d1, d1*d2)
times (n1,d1) (n2,d2) = (n1*n2, d1*d2)

--- Integer part of a fraction (rounded as `div` rounds).
floor :: Ratio -> Int
floor (n,d) = n `div` d

--- Splits the budget `l` among children in proportion to their weights.
--- With s = sum rs, each weight r receives floor (l * r / s), so the
--- shares may add up to less than `l`.
--- Requires a non-zero sum when `rs` is non-empty.
distribute :: Int -> [Ratio] -> [Int]
distribute l rs = map (\r -> floor (times (l,1) (times (d,n) r))) rs
 where
  -- (n,d) is the sum of all weights; (d,n) is its reciprocal.
  (n,d) = foldr plus zero rs

--- A search tree annotated for adaptive search.
--- `Leaf` holds a found value together with its coverage.
--- `Branch` holds the coverage accumulated so far from values found below
--- it, plus its still-unexplored children, each with a weight.
data CoverTree a = Leaf (a,Coverage)
                 | Branch Coverage [(Ratio,CoverTree a)]

--- Converts a plain search tree into a `CoverTree`.
--- Every `Choice` child starts with weight `one`, and `Fail` becomes a
--- branch without children.
coverTree :: SearchTree (a,Coverage) -> CoverTree a
coverTree (Value x)   = Leaf x
coverTree (Choice ts) = Branch [] (zip (repeat one) (map coverTree ts))
coverTree Fail        = Branch [] []

--- Enumerates all values of a search tree by budget-limited rounds.
--- The first round has budget 1.
adaptiveSearch :: SearchTree (a,Coverage) -> [(a,Coverage)]
adaptiveSearch = iter 1 . coverTree

--- Runs one round with budget `l` and emits the values found.
--- It then continues on the unexplored remainder with twice the budget,
--- stopping once nothing remains.
iter :: Int -> CoverTree a -> [(a,Coverage)]
iter l t = xs ++ maybe [] (iter (2*l)) mt
 where
  (xs,_,mt) = search l t

--- Explores a tree with the given budget.
--- Returns the values found, the unused budget, and the unexplored
--- remainder of the tree (`Nothing` if the tree is exhausted).
--- A leaf costs nothing: it is returned with the full budget as leftover.
--- A branch with budget 0 is postponed unchanged. Otherwise a branch
--- costs one unit and shares the rest among its children by weight,
--- visiting children with larger shares first.
search :: Int -> CoverTree a -> ([(a,Coverage)],Int,Maybe (CoverTree a))
search limit (Leaf x) = ([x],limit,Nothing)
search limit t@(Branch cov ts)
  | limit == 0 = ([],0,Just t)
  | otherwise  = (xs,l,branch cov cts)
 where
  limits     = distribute (limit-1) (map fst ts)
  -- Boolean "less-or-equal" comparator: orders by descending share.
  lts        = sortBy (\ (l1,_) (l2,_) -> l1 >= l2) (zip limits ts)
  (xs,l,cts) = searchAll 0 lts

--- Visits the children of one branch in order.
--- Each child gets its own share plus the budget left over by the
--- previous child.
--- Per child, the result records the coverage of values found in it this
--- round, and its unexplored remainder with its old weight attached.
--- NOTE(review): the base case returns 0 instead of the incoming
--- leftover `n`. So a branch never reports unused budget to its parent;
--- only leaves pass budget on to their next sibling. Confirm this is
--- intended.
searchAll :: Int -> [(Int,(Ratio,CoverTree a))]
          -> ([(a,Coverage)],Int,[(Coverage,Maybe (Ratio,CoverTree a))])
searchAll _ [] = ([],0,[])
searchAll n ((l,(r,t)):lts) =
  (xs ++ ys, ly, (cov, mt >>- (\t' -> Just (r,t'))) : ts)
 where
  (xs,lx,mt) = search (n+l) t
  (ys,ly,ts) = searchAll lx lts
  cov        = nub (concatMap snd xs)

--- Rebuilds a branch after a round.
--- Exhausted children are dropped, and the remaining children are
--- reweighted by `update`.
--- The coverage of the new branch is the old coverage merged with
--- everything found below it this round.
--- Returns `Nothing` when no child is left.
branch :: Coverage -> [(Coverage,Maybe (Ratio,CoverTree a))]
       -> Maybe (CoverTree a)
branch cov cts
  | null rts  = Nothing
  | otherwise = Just (Branch (nub (concat (cov : map fst cts))) rts)
 where
  rts = mapMaybe (\ (c,mt) -> mt >>- update cov c) cts

-- use for pruning later
--- Reweights a child from (new,old) to
--- (new + progress, old + length c - progress).
--- Here `progress` counts the items of the child's round coverage `c`
--- that are not yet in the parent's coverage `cov`.
--- Currently always returns `Just`; the `Maybe` result appears to be
--- reserved for pruning (see comment above).
update :: Coverage -> Coverage -> (Ratio,CoverTree a)
       -> Maybe (Ratio,CoverTree a)
update cov c ((new,old),t) = Just ((new+progress, old+length c-progress), t)
 where
  progress = length (c \\ cov)

--- An infinite example tree.
--- Each node chooses between repeating itself, yielding its head `n`
--- with the current coverage, and descending with `n+1` prepended.
--- Undefined for an empty coverage list.
exampleTree :: Coverage -> SearchTree (Int,Coverage)
exampleTree cov@(n:_) =
  Choice [exampleTree cov, Value (n,cov), exampleTree (n+1:cov)]

--- Takes the first `n` values that the strategy `allValues` produces on
--- `exampleTree [0]`.
--- Prints each distinct value with how often it occurred, in ascending
--- order and tab-separated.
benchmark :: Int -> (SearchTree (Int,Coverage) -> [(Int,Coverage)]) -> IO ()
benchmark n allValues =
  mapIO_ put . group . sortBy (<=) . map fst . take n . allValues $ exampleTree [0]
 where
  -- `group` never yields an empty group, so the pattern is safe here.
  put ms@(m:_) = putStrLn (show m ++ ":\t" ++ show (length ms))