1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
module CheckSMT
  ( checkNonFailFunc, checkPreCon, checkPostCon
  ) where

import Control.Monad               ( unless )
import Control.Monad.IO.Class      ( liftIO )
import Control.Monad.Trans.Class   ( lift )
import Control.Monad.Trans.State   ( gets )
import Data.List                   ( find, intercalate, union, partition )
import Data.Maybe                  ( catMaybes, fromJust, mapMaybe )
import Solver.SMTLIB.Z3
import System.Directory            ( removeFile )
import System.IO                   ( hClose, hGetContents )
import System.IOExts               ( execCmd )

import Contract.Names              ( toPreCondName, toPreCondQName
                                   , toPostCondName, toPostCondQName )
import FlatCurry.Annotated.Goodies
import FlatCurry.Annotated.Pretty  ( ppTypeExp, ppQName )
import FlatCurry.Types             ( showQName )
import Language.SMTLIB.Goodies
import Language.SMTLIB.Pretty
import qualified Language.SMTLIB.Types as SMT
import Text.Pretty                 ( hsep, pPrint, pretty )

import Curry2SMT
import ESMT
import FlatCurry.Typed.Build
import FlatCurry.Typed.Names
import FlatCurry.Typed.Read        ( stripForall )
import FlatCurry.Typed.Types
import ToolOptions
import TransState
import VerifierState

---------------------------------------------------------------------------
-- Checks the satisfiability of the given assertion and checks the fail
-- condition if the assertion is satisfiable.
checkNonFailFunc :: String -> SMT.Term -> SMT.Term -> SMT.Term
                 -> TransStateM (Maybe Bool)
checkNonFailFunc scripttitle assertion impbindings imp =
  generateSMT scripttitle assertion impbindings imp
    >>= checkSMT evalNonFailFunc []
      (\name args _ -> "Call" ++ printCall name args ++ "fails")

-- Checks the satisfiability of the given assertion and checks the pre
-- condition if the assertion is satisfiable.
checkPreCon :: String -> SMT.Term -> SMT.Term -> SMT.Term -> String -> [Int]
            -> TransStateM (Maybe Bool)
checkPreCon scripttitle assertion impbindings imp mname mvars = do
  optcontract <- lift $ getOption optContract
  if optcontract > 1
    then
      generateSMT scripttitle assertion impbindings imp
        >>= checkSMT evalPreCon mvars (\name args margs -> "Call"
          ++ printCall mname margs ++ "violates " ++ toPreCondName name
          ++ " due to calling" ++ printCall name args)
    else return $ Just False

-- Checks the satisfiability of the given assertion and checks the post
-- condition if the assertion is satisfiable.
-- If it is satisfiable, just the script to prove it is returned,
-- other `Nothing` is returned.
checkPostCon :: String -> SMT.Term -> SMT.Term -> SMT.Term
             -> TransStateM (Maybe String)
checkPostCon scripttitle assertion impbindings imp = do
  optcontract <- lift $ getOption optContract
  if optcontract > 1
    then do
      smt <- generateSMT scripttitle assertion impbindings imp
      mbbool <- checkSMT evalPostCon []
                  (\name args _ -> "Call" ++ printCall name args ++
                                   "violates " ++ toPostCondName name)
                  smt
      return $
        maybe Nothing
              (\b -> if b
                       then Just $ "; proved by: z3 -smt2 <SMTFILE>\n\n" ++
                                   showSMT (smt ++ checkSatCommands)
                       else Nothing)
              mbbool
    else return Nothing

-- Generates a string representation of a function call.
printCall :: String -> [String] -> String
printCall name args = " '" ++ name ++ [' ' | not $ null args]
                          ++ unwords args ++ "' "

---------------------------------------------------------------------------
-- Generates a SMT script from the given assertion.
generateSMT :: String -> SMT.Term -> SMT.Term -> SMT.Term
            -> TransStateM ([SMT.Command])
generateSMT scripttitle assertion impbindings imp = do
  vartypes <- getVarTypes
  let (pretypes,usertypes) =
         partition ((== "Prelude") . fst)
                  (foldr union []
                    (map (tconsOfTypeExpr . (\(_,x,_) -> x)) vartypes))
      allsyms = catMaybes
                  (map (\n -> maybe Nothing Just (untransOpName n))
                       (map unqual
                         (allQIdsOfTerm (tand [assertion, impbindings, imp]))))
  unless (null allsyms) $ lift $ printWhenIntermediate $
    "Translating operations into SMT: " ++
    unwords (map showQName allsyms)
  smtfuncs <- lift $ funcs2SMT allsyms
  tdecls   <- mapM (lift . tdeclOf) usertypes
  let decls = map (maybe (error "Internal error: some datatype not found!") id)
                  tdecls
      smt1  = [ SMT.Comment scripttitle ] ++
              prelude ++
              concatMap preludeType2SMT (map snd pretypes) ++
              (if null decls
                 then []
                 else [ SMT.Comment "User-defined datatypes:" ] ++
                      map tdecl2SMT decls)
      smt2  = [ SMT.Comment "Free variables:" ] ++
              concatMap typedVar2SMT vartypes ++
              [ SMT.Comment "Boolean formula of assertion (known properties):"
              , sAssert assertion
              , SMT.Comment "Bindings of implication:"
              , sAssert impbindings
              , SMT.Comment "Assert negated implication:"
              , sAssert (tnot imp)
              ]
      smt   = unpoly $ (map Right smt1) ++ [Left smtfuncs] ++ (map Right smt2)
  lift $ printWhenAll $
    "SMT SCRIPT:\n" ++ (showWithLineNums $ showSMT $ smt ++ checkSatCommands)
  return smt

checkSatCommands :: [SMT.Command]
checkSatCommands =
  [ SMT.Comment "check satisfiability:"
  , SMT.CheckSat
  , SMT.Comment "if unsat, the implication is valid"
  ]

-- Checks the given SMT script for satisfiability and evaluates the returned
-- values with the given evaluation function if the solver returns satisfiable.
checkSMT :: (QName -> TypeExpr -> [[String]] -> IO [Bool]) -> [Int]
         -> (String -> [String] -> [String] -> String) -> [SMT.Command]
         -> TransStateM (Maybe Bool)
checkSMT eval mvars output smt = do
  vartypes <- getVarTypes
  let getvaluevars =
        mapMaybe
          (\(i, vartype, mn) -> case mn of
            Just (name, pos, funcindex) | pos /= 0 && funcindex /= 0 ->
              Just (tvar i, ((name, funcindex), vartype))
            _ -> Nothing
          )
          vartypes
      argtypes = groupPairs $ map snd getvaluevars
      mvars' = map tvar mvars
  timeout <- lift (getOption optTimeout >>= \to -> return $ "-T:" ++ show to)
  lift $ printWhenAll $ "CALLING Z3 (with options: -smt2 " ++ timeout ++ ")..."
  exNum <- lift $ getOption optExamples
  -- store SMT script, might be useful to check for errors
  lift $ evalOption optVerb (> 2) $ liftIO $
    writeFile "LATEST_SMT.smt" (showSMT smt)
  answer <- liftIO $
    evalSessions z3 { flags = ["-smt2", "-in", timeout] } defSMTOpts $
      solveAllSMTVars (union mvars' $ map fst getvaluevars) smt exNum
  case answer of
    Left  errs -> (lift $ printWhenAll $ pPrint $ hsep $ map pretty errs) >>
                  return Nothing
    Right vpss -> if null vpss
      then (lift $ printWhenAll "RESULT:\nunsat") >> return (Just True)
      else do
        lift $ printWhenAll "RESULT:\nsat"
        let mres = map (map (decodeTerm . snd) . filter ((`elem` mvars') . fst))
                   vpss
        allfuncs <- lift $ gets $ allFuncs . trInfo
        counterExamples <- liftIO $ checkFuncs eval allfuncs argtypes
          $ map (\(f, args) -> (f, zip args mres))
          $ groupPairs $ concatMap
              (groupPairs . (mapMaybe (\(var, val) ->
                (,) <$> (fst <$> lookup var getvaluevars)
                    <*> Just (decodeTerm val))))
              vpss
        lift $ printWhenStatus $ unlines $ map
          (\(name, args) -> unlines $ map
            (\(arg, marg) -> output name arg marg) args) counterExamples
        return $ Just False

-- SMT script for declaring Prelude datatypes.
prelude :: [SMT.Command]
prelude =
  [ SMT.Comment "disable model-based quantifier instantiation (avoid loops)"
  , SMT.SetOption $ SMT.OptAttr $ SMT.AVal (SMT.KW "smt.mbqi")
                                           (SMT.AVSym "false")
  , SMT.Comment "For polymorphic types:"
  , SMT.DeclareSort "TVar" 0
  , SMT.Comment "Unit type:"
  , SMT.DeclareDatatype "Unit" $ SMT.MT [SMT.Cons "unit" []]
  , SMT.Comment "Pair type:"
  , SMT.DeclareDatatype "Pair" $
      SMT.PT ["T1", "T2"]
             [ SMT.Cons "mk-pair"
                        [ SMT.SV "first"  $ SMT.SComb "T1" []
                        , SMT.SV "second" $ SMT.SComb "T2" []
                        ]
             ]
  , SMT.Comment "For functional types:"
  , SMT.DeclareDatatype "Func" $
      SMT.PT ["T1", "T2"]
             [ SMT.Cons "mk-func"
                        [ SMT.SV "argument" $ SMT.SComb "T1" []
                        , SMT.SV "result"   $ SMT.SComb "T2" []
                        ]
             ]
  , SMT.Comment "Maybe type:"
  , SMT.DeclareDatatype "Maybe" $
      SMT.PT ["T"]
             [ SMT.Cons "Nothing" []
             , SMT.Cons "Just" [SMT.SV "just" $ SMT.SComb "T" []]
             ]
  , SMT.Comment "Either type:"
  , SMT.DeclareDatatype "Either" $
      SMT.PT ["T1", "T2"]
             [ SMT.Cons "Left"  [SMT.SV "left"  $ SMT.SComb "T1" []]
             , SMT.Cons "Right" [SMT.SV "right" $ SMT.SComb "T2" []]
             ]
  , SMT.Comment "Ordering type:"
  , SMT.DeclareDatatype "Ordering" $
      SMT.MT [SMT.Cons "LT" [], SMT.Cons "EQ" [], SMT.Cons "GT" []]
  , SMT.Comment "Dict type (to represent dictionary variables):"
  , SMT.DeclareDatatype "Dict" $
      SMT.PT ["T"] [SMT.Cons "Dict" [SMT.SV "dict" $ SMT.SComb "T" []]]
  ]

---------------------------------------------------------------------------
-- Decodes a SMT term into a string representation of a Curry term.
decodeTerm :: SMT.Term -> String
decodeTerm (SMT.TConst tconst) = case tconst of
                                   SMT.Num num -> negParen num
                                   SMT.Dec dec -> negParen dec
                                   SMT.Str str -> "'" ++ str ++ "'"
  where
    negParen n = if n < 0 then "(" ++ show n ++ ")" else show n
decodeTerm (SMT.TComb qIdent terms) =
  paren $ decodeIdent qIdent ++ [' ' | not $ null terms]
          ++ (unwords $ map decodeTerm terms)
  where
    paren s = if null terms then s else "(" ++ s ++ ")"
    decodeIdent (SMT.Id ident)   = decodeIdent' ident
    decodeIdent (SMT.As ident _) = decodeIdent' ident
    decodeIdent' ident           =
      case lookup ident $ map (\(x,y) -> (y,x)) transPrimCons of
        Just ":"    -> "(:)"
        Just ident' -> ident'
        Nothing     -> let (modname, ident') = break (== '_') ident in
                         modname ++ (if null modname then "" else ".")
                         ++ dropWhile (== '_') ident'
decodeTerm (SMT.Let decls term) =
  "(let {"
  ++ intercalate "; " (map (\(sym, t) -> sym ++ " = " ++ decodeTerm t) decls)
  ++ "} in " ++ decodeTerm term ++ ")"
decodeTerm (SMT.Forall _ term) = decodeTerm term
decodeTerm (SMT.Exists _ term) = decodeTerm term
decodeTerm (SMT.Match term pats) =
  "case " ++ decodeTerm term ++ " of {"
  ++ intercalate "; " (map decodeBranch pats) ++ "}"
  where
    decodeBranch ((SMT.PComb cons args), branch) =
      cons ++ " " ++ unwords args ++ " -> " ++ decodeTerm branch
decodeTerm (SMT.Annot term _) = decodeTerm term

-- Check if the given functions fail with their given arguments and return only
-- those that do.
checkFuncs :: (QName -> TypeExpr -> [[String]] -> IO [Bool])
           -> [TAFuncDecl] -> [((QName, Int), [TypeExpr])]
           -> [((QName, Int), [([String], [String])])]
           -> IO [(String, [([String], [String])])]
checkFuncs _ _     _           []                        = return []
checkFuncs eval funcs argtypes ((fn@(qn, _), args) : fs) = do
  let ftype = funcType $ fromJust $ find (\f -> funcName f == qn) funcs
      sub = concatMap (uncurry matchTypeVars)
            $ zip (argTypes ftype) $ fromJust $ lookup fn argtypes
      ftype' = updTVars (const $ TCons ("Prelude", "()") [])
               $ updTVars (\i -> fromJust $ lookup i sub) $ resultType ftype
  results <- eval qn ftype' $ map fst args
  let args' = mapMaybe
                (\(r, arg) -> if not r then Just arg else Nothing)
                $ zip results args
  if null args'
    then checkFuncs eval funcs argtypes fs
    else checkFuncs eval funcs argtypes fs
           >>= \fs' -> return $ (snd qn, args') : fs'

-- Matches the first argument with the second and returns the corresponding
-- type expression for each type variable in the first argument.
matchTypeVars :: TypeExpr -> TypeExpr -> [(TVarIndex, TypeExpr)]
matchTypeVars (TVar i) texpr = [(i, stripForall texpr)]
matchTypeVars (FuncType arg res) texpr = case stripForall texpr of
  FuncType arg' res' -> union (matchTypeVars arg arg') (matchTypeVars res res')
  _                  -> error "CheckSMT.matchTypeVars: unexpected case"
matchTypeVars (TCons qn args) texpr = case stripForall texpr of
  TCons qn' args' | qn == qn' -> foldr union [] $
                                   map (uncurry matchTypeVars) $ zip args args'
  _                           -> error "CheckSMT.matchTypeVars: unexpected case"
matchTypeVars (ForallType _ texp) texp' = matchTypeVars texp texp'

---------------------------------------------------------------------------
-- Evaluate the given function with `getAllValues` and returns `False` for each
-- argument for which the function fails.
evalNonFailFunc :: QName -> TypeExpr -> [[String]] -> IO [Bool]
evalNonFailFunc funcname functype args =
  evalFunc False funcname functype args >>= return . map (/= "[]")

-- Evaluate the precondition for the given function with `getAllValues` and
-- return if it holds for the given arguments.
evalPreCon :: QName -> TypeExpr -> [[String]] -> IO [Bool]
evalPreCon funcname _ args =
  evalFunc False (toPreCondQName funcname) boolType args
    >>= return . map (((&&) <$> and <*> not . null) . read)

-- Evaluate the given function with `getAllValues` and then evaluate if the
-- postcondition holds for all possible results.
evalPostCon :: QName -> TypeExpr -> [[String]] -> IO [Bool]
evalPostCon funcname functype args = do
  evalFunc True funcname functype args
    >>= return . map (((&&) <$> and <*> not . null) . read)

-- Evaluate a function with `getAllValues` and return the output of the
-- compiler.
-- The first argument controls if functions for checking post conditons should
-- be generated.
evalFunc :: Bool -> QName -> TypeExpr -> [[String]] -> IO [String]
evalFunc post funcname functype args = do
  let functype' = ioType $ listType functype
  writeFile "Eval.curry" $ unlines $
    [ "module Eval where"
    , "import Control.AllValues ( getAllValues )"
    , "import " ++ fst funcname
    , "eval :: IO ()"
    , "eval = sequence_ ["
      ++ (intercalate ", " $ map (\i -> "eval" ++ (if post then "Post" else "")
                                        ++ show i ++ " >>= print")
                                 [1..length args]) ++ "]"
    ] ++ map (\(i, arg) -> "eval" ++ show i ++ " :: "
                           ++ pPrint (ppTypeExp functype')
                           ++ "\neval" ++ show i ++ " = getAllValues ("
                           ++ pPrint (ppQName funcname) ++ " " ++ unwords arg
                           ++ ")"
                           ++ (if post then
                                 "\nevalPost" ++ show i ++ " :: IO [Bool]\n"
                                 ++ "evalPost" ++ show i ++ " = eval" ++ show i
                                 ++ " >>= mapM (getAllValues . "
                                 ++ pPrint (ppQName $ toPostCondQName funcname)
                                 ++ " " ++ unwords arg ++ ") >>= return . map "
                                 ++ "((&&) <$> and <*> not . null)"
                               else ""
                              )
             ) (zip [1..] args)
  (i,o,e) <- execCmd "curry -q :l Eval :eval eval :quit"
  hClose i
  hGetContents e >>= putStrLn
  out <- hGetContents o
  removeFile "Eval.curry"
  return $ lines out

---------------------------------------------------------------------------
-- Translate a typed variables to an SMT declaration:
typedVar2SMT :: (Int,TypeExpr, Maybe (QName, Int, Int)) -> [SMT.Command]
typedVar2SMT (i, te, Just (name, argPos, _)) =
  [ SMT.Comment $
      (case argPos of
        0 -> "result"
        _ -> show argPos
             ++ (case argPos of {1 -> "st"; 2 -> "nd"; 3 -> "rd"; _ -> "th"})
             ++ " argument")
      ++ " of '" ++ snd name ++ "'"
  , SMT.DeclareConst (var2SMT i) (polytype2sort te)
  ]
typedVar2SMT (i, te, Nothing) =
  [SMT.DeclareConst (var2SMT i) (polytype2sort te)]

-- Gets all type constructors of a type expression.
tconsOfTypeExpr :: TypeExpr -> [QName]
tconsOfTypeExpr (TVar _) = []
tconsOfTypeExpr (FuncType a b) =  union (tconsOfTypeExpr a) (tconsOfTypeExpr b)
tconsOfTypeExpr (TCons qName texps) =
  foldr union [qName] (map tconsOfTypeExpr texps)
tconsOfTypeExpr (ForallType _ te) =  tconsOfTypeExpr te

-- Group a list of pairs together according to the first component.
groupPairs :: Eq a => [(a, b)] -> [(a, [b])]
groupPairs xs = foldr groupPairs' [] (map fst xs)
  where
    -- groupPairs' :: Eq a => a -> [(a, [b])] -> [(a, [b])]
    groupPairs' x xs' | x `elem` (map fst xs') = xs'
                      | otherwise              = (x, findElems x) : xs'

    -- findElems :: Eq a => a -> [b]
    findElems x' = map snd $ filter ((== x') . fst) xs

---------------------------------------------------------------------------
--- Shows a text with line numbers prefixed:
showWithLineNums :: String -> String
showWithLineNums txt =
  let txtlines  = lines txt
      maxlog    = ilog (length txtlines + 1)
      showNum n = (take (maxlog - ilog n) (repeat ' ')) ++ show n ++ ": "
  in unlines . map (uncurry (++)) . zip (map showNum [1..]) $ txtlines

--- The value of `ilog n` is the floor of the logarithm
--- in the base 10 of `n`.
--- Fails if `n &lt;= 0`.
--- For positive integers, the returned value is
--- 1 less the number of digits in the decimal representation of `n`.
---
--- @param n - The argument.
--- @return the floor of the logarithm in the base 10 of `n`.
ilog :: Int -> Int
ilog n | n>0 = if n<10 then 0 else 1 + ilog (n `div` 10)