module STM
           (STM, TVar, newTVar, readTVar, writeTVar, 
            atomically, retry, orElse, proc, catchSTM) where


-- for ghc
import Prelude hiding (catch)
import Control.Concurrent
import Control.Exception hiding (throw)
import qualified Control.Exception (throw)
import Data.IORef
import System.IO.Unsafe


 {-
-- for hugs
import Concurrent
import IOExts

-- Additionally, fNotify near the end of this file must be switched to use
-- safePutMVar instead of tryPutMVar, which is not available in Hugs.
 -}

import List
-----------------------
-- The STM interface --
-----------------------

-- | Identifier drawn from the global counter; used to tell TVars apart.
type ID
  = Integer

-- The STM monad itself
-- | A transaction step: given the bookkeeping state accumulated so far,
-- run in IO and report how it ended.
data STM a
  = STM (StmState -> IO (STMResult a))

-- | Sequencing of transaction steps.  'return' succeeds without touching
-- the state.  For '>>=', the continuation runs only after a 'Success';
-- 'Retry', 'InValid' and 'Exception' are passed through unchanged.  An
-- exception raised while running the continuation is turned into an
-- 'Exception' result carrying the state reached before it.
instance Monad STM where
  return x = STM (\st -> return (Success st x))

  (STM first) >>= k = STM step
    where
      step st = do
        outcome <- first st
        case outcome of
          Success st' x ->
            -- lazy pattern: 'k x' is only evaluated inside 'catch'
            let STM next = k x
            in catch (next st') (\err -> return (Exception st' err))
          Retry st'         -> return (Retry st')
          InValid           -> return InValid
          Exception st' err -> return (Exception st' err)

-- | Bookkeeping threaded through a running transaction.
data StmState = TST
  { touchedTVars :: [ID]
    -- ^ identifiers of the TVars this transaction has read or written
  , isValid      :: IO Bool
    -- ^ re-checks that every TVar read so far still holds the reference
    --   that was read
  , commits      :: [IO ()]
    -- ^ buffered writes; writeTVar chains each new write onto the head
    --   entry, and 'atomically' runs the list in reverse when committing
  , notifys      :: [IO ()]
    -- ^ wake-ups for threads waiting on written TVars, kept like 'commits'
  , restores     :: [IO ()]
    -- ^ undo actions for 'commits', used by readTVar when it applies the
    --   buffered writes temporarily
  , wait         :: IO ()
    -- ^ registers 'retryMVar' on the wait queue of every TVar read
  , retryMVar    :: MVar ()
    -- ^ taken by 'atomically' to block after 'retry'
  }

-- | Outcome of running a transaction (or a piece of one).
data STMResult a
  = Retry StmState               -- ^ 'retry' was called; the state says what to wait on
  | InValid                      -- ^ an inconsistent read was detected; 'atomically' reruns
  | Success StmState a           -- ^ finished normally with a result
  | Exception StmState Exception -- ^ an exception was raised

-- | Fresh bookkeeping state for one run of 'atomically': nothing read or
-- written yet, a read set that is trivially valid, a single empty entry in
-- each of the commit/notify/restore lists, nothing to wait on, and a new
-- empty 'MVar' to block on after 'retry'.
--
-- (An identifier used to be drawn from the global counter here but was
-- never used; dropping it saves a round trip on that shared MVar for
-- every transaction.)
initialState :: IO StmState
initialState = do
  rMVar <- newEmptyMVar
  return (TST { touchedTVars = []
              , isValid      = return True
              , commits      = [return ()]
              , notifys      = [return ()]
              , restores     = [return ()]
              , wait         = return ()
              , retryMVar    = rMVar
              })

-- Transactional variables
-- | A transactional variable.
data TVar a
  = TVar (MVar (IORef a)) -- ^ committed value; every commit installs a fresh IORef
         ID               -- ^ identifier of this TVar
         (MVar [MVar ()]) -- ^ retry MVars of transactions waiting on this TVar

-- | Create a new 'TVar' holding the given value, with a fresh identifier
-- and an empty wait queue.  The creation is not buffered: it happens
-- immediately.
newTVar :: a -> STM (TVar a)
newTVar v = STM $ \st -> do
  ident  <- getGlobalId
  cell   <- newIORef v
  holder <- newMVar cell
  queue  <- newMVar []
  return (Success st (TVar holder ident queue))

-- | Read a 'TVar' inside a transaction.
--
-- * First access to this TVar in the transaction: read the committed
--   value, add the TVar to 'touchedTVars', extend 'isValid' with a check
--   that the TVar still holds the same 'IORef' (every committed write
--   installs a fresh 'IORef'), and extend 'wait' so that a later 'retry'
--   registers this transaction on the TVar's wait queue.
--
-- * TVar already touched: its current value may exist only as a buffered
--   write.  While holding 'globalLock', and only if the read set is still
--   valid, the buffered writes are applied to the real TVars, the value is
--   read, and the writes are undone again.  If validation fails the result
--   is 'InValid'.
--
-- The 'Show' constraint is not used by the implementation.
readTVar  :: Show a => TVar a -> STM a
readTVar (TVar tVarRef tvId waitQ) = STM (\stmState -> do
  if elem tvId (touchedTVars stmState)
    then do
      takeMVar globalLock
      valid <- isValid stmState
      if valid
        then do
          -- 'commits' has its newest (innermost orElse) entry first; apply
          -- the entries oldest first, exactly like the commit in
          -- 'atomically', so that later writes win.
          sequence_ (reverse (commits stmState))
          tVarVal <- readMVar tVarRef
          val <- readIORef tVarVal
          -- 'restores' is newest first, i.e. the exact reverse of the order
          -- the writes were just applied in.
          sequence_ (restores stmState)
          putMVar globalLock ()
          return (Success stmState val)
        else do
          putMVar globalLock ()
          return InValid
    else do
      tVarVal <- readMVar tVarRef
      let newState =
            stmState{touchedTVars = tvId : touchedTVars stmState,
                     isValid = do
                       tVarVal' <- readMVar tVarRef
                       (return (tVarVal == tVarVal') >>+ isValid stmState),
                     wait = do
                       queue <- takeMVar waitQ
                       putMVar waitQ (retryMVar stmState : queue)
                       wait stmState}
      val <- readIORef tVarVal
      return (Success newState val))

-- | Buffer a write to a 'TVar'.  Nothing is written globally here: a write
-- action is chained onto the head entry of 'commits', together with a
-- wake-up for the TVar's waiters in 'notifys' and an undo action in
-- 'restores', and the TVar is added to 'touchedTVars'.
--
-- Each time the write action runs it records the 'IORef' it displaces, and
-- the undo action puts exactly that one back.  Restoring the value seen at
-- 'writeTVar' time instead would silently revert any commit another
-- transaction made to this TVar in the meantime whenever 'readTVar'
-- temporarily applies the buffered writes.
--
-- The 'Show' constraint is not used by the implementation.
writeTVar :: Show a => TVar a -> a -> STM ()
writeTVar (TVar tVarRef tvId waitQ) v = STM (\stmState -> do
  let (co:cos)     = commits stmState
      (no:nos)     = notifys stmState
      (rest:rests) = restores stmState
  displaced <- readMVar tVarRef >>= newIORef
  let newState =
        stmState{touchedTVars = if elem tvId (touchedTVars stmState)
                                  then touchedTVars stmState
                                  else tvId : touchedTVars stmState,
                 commits  = (do co
                                newTVarVal <- newIORef v
                                old <- takeMVar tVarRef
                                writeIORef displaced old
                                putMVar tVarRef newTVarVal) : cos,
                 notifys  = (no >> fNotify waitQ) : nos,
                 restores = (do old <- readIORef displaced
                                takeMVar tVarRef
                                putMVar tVarRef old
                                rest) : rests}
  return (Success newState ()))

-- Running STM computations
-- | Run an 'STM' computation as one atomic transaction and return its
-- result.
--
-- The computation is run against a fresh 'StmState' and the outcome is
-- handled as follows:
--
--   * 'Exception': re-thrown; none of the buffered writes is committed.
--   * 'Retry': block until a 'TVar' read by the attempt is written, then
--     run the computation again from scratch.
--   * 'InValid': run the computation again immediately.
--   * 'Success': under 'globalLock', re-validate the read set; if it is
--     still valid, run the buffered writes (in reverse list order), then
--     the wake-ups for threads waiting on the written TVars, and return the
--     result; otherwise run the computation again.
atomically :: STM a -> IO a
atomically stmAction = do
  iState <- initialState
  atomically' stmAction iState
  where
    atomically' :: STM a -> StmState -> IO a
    atomically' stmAction state = do
      stmResult <- catch (startSTM stmAction state)
                         (\e -> return (Exception state e))
      case stmResult of
        Exception _ e ->
          Control.Exception.throw e
        Retry newState -> do
          suspend newState
          atomically' stmAction state
        InValid ->
          atomically' stmAction state
        Success newState res -> do
          takeMVar globalLock
          valid <- isValid newState
          if valid
            then do
              sequence_ (reverse (commits newState))
              sequence_ (reverse (notifys newState))
              putMVar globalLock ()
              return res
            else do
              putMVar globalLock ()
              atomically' stmAction state

    -- Wait for a wake-up after 'retry'.  Committing transactions run their
    -- wake-ups ('notifys') while holding 'globalLock'.  Re-validating and
    -- registering on the wait queues under the same lock therefore ensures
    -- that no commit can land between "the read set is still current" and
    -- "we are on the wait queues", so a wake-up can no longer be lost.  If
    -- the read set is already stale, return at once so the caller reruns
    -- the transaction.
    suspend :: StmState -> IO ()
    suspend st = do
      let signal = retryMVar st
      takeMVar globalLock
      -- Wake-ups are delivered by commits, which hold 'globalLock', so a
      -- token present now is left over from an earlier attempt; drop it to
      -- avoid a spurious rerun.
      noToken <- isEmptyMVar signal
      if noToken then return () else takeMVar signal
      valid <- isValid st
      if valid
        then do
          wait st
          putMVar globalLock ()
          takeMVar signal
        else putMVar globalLock ()

-- | Signal that the transaction cannot proceed yet.  The read set is first
-- validated while holding 'globalLock'.  If it is still valid the result
-- is 'Retry' (and 'atomically' blocks until something read is written);
-- otherwise it is 'InValid', so the transaction is simply rerun.
retry :: STM a
retry = STM check
  where
    check st = do
      takeMVar globalLock
      ok <- isValid st
      putMVar globalLock ()
      return (if ok then Retry st else InValid)


-- | @orElse stm1 stm2@ runs @stm1@.  If it calls 'retry', the writes it
-- buffered are discarded and @stm2@ runs instead.  Any other outcome of
-- @stm1@ (success, exception, invalid) is the result of the whole
-- expression.
--
-- When @stm1@ retries, @stm2@ continues from @stm1@'s final state, so
-- @stm1@'s reads stay in the read set and in the 'wait' registration.
-- However, the commit/notify/restore lists are reset to the ones in force
-- when orElse was entered.  Those lists are immutable, so the reset drops
-- everything @stm1@ added, however deeply it nested other orElse calls.
-- (Merely popping one entry, as before, kept writes that @stm1@ made
-- before a nested orElse that succeeded.)
orElse :: STM a -> STM a -> STM a
orElse (STM stm1) (STM stm2) =
  STM (\stmState@TST{commits  = fCommits,
                     notifys  = fNotifys,
                     restores = fRestores} -> do
         stm1Res <- stm1 stmState
         case stm1Res of
           Retry newState ->
             stm2 newState{commits  = fCommits,
                           notifys  = fNotifys,
                           restores = fRestores}
           _ -> return stm1Res)

{-
-- Throwing STM Exceptions
throw :: Exception -> STM a
throw e = STM (\stmState -> do
  takeMVar globalLock
  valid <- isValid stmState
  putMVar globalLock ()
  if valid
    then return (Exception stmState e)
    else return InValid)
-}

-- | Run an STM action; if it raises an exception, run the handler instead.
--
-- Exceptions thrown directly by the action (e.g. by its first 'proc' step)
-- are caught as well; previously only exceptions from later steps of a
-- '>>=' chain reached the handler.
--
-- As with GHC's catchSTM, writes buffered by the failed action are rolled
-- back: the handler continues from the action's final state, but with the
-- commit/notify/restore lists reset to those in force when catchSTM was
-- entered.  Writes made before catchSTM persist.  Reads made by the failed
-- action stay in the read set.
catchSTM :: STM a -> (Exception -> STM a) -> STM a
catchSTM (STM stm) eHandler =
  STM (\stmState@TST{commits  = fCommits,
                     notifys  = fNotifys,
                     restores = fRestores} -> do
         res <- catch (stm stmState)
                      (\e -> return (Exception stmState e))
         case res of
           Exception newState e ->
             let (STM stmEx) = eHandler e
             in stmEx newState{commits  = fCommits,
                               notifys  = fNotifys,
                               restores = fRestores}
           _ -> return res)

-------------------
-- Miscellaneous --
-------------------

-- | Unwrap an 'STM' action and run it from the given state.
startSTM :: STM a -> StmState -> IO (STMResult a)
startSTM (STM run) st = run st

-- | Lift an arbitrary IO action into a transaction.  The action runs
-- immediately each time the transaction is (re)executed; it is neither
-- buffered nor undone.
proc :: IO a -> STM a
proc io = STM (\st -> fmap (Success st) io)

-- | Process-wide lock, held while validating a read set, while committing,
-- and while readTVar temporarily applies buffered writes.  Full means free.
--
-- NOINLINE is required for this 'unsafePerformIO' global.  If the
-- definition were inlined, use sites could each get their own fresh 'MVar',
-- and there would be no mutual exclusion at all.
{-# NOINLINE globalLock #-}
globalLock :: MVar ()
globalLock = unsafePerformIO (newMVar ())

-- | Counter from which getGlobalId hands out identifiers, starting at 0.
--
-- NOINLINE is required for this 'unsafePerformIO' global.  If the
-- definition were inlined, use sites could each create their own 'MVar'
-- starting at 0 and hand out duplicate identifiers.
{-# NOINLINE globalCount #-}
globalCount :: MVar Integer
globalCount = unsafePerformIO (newMVar 0)

-- | Hand out the next identifier from 'globalCount' (0, 1, 2, ...).
--
-- The incremented counter is forced before it is stored.  Otherwise every
-- call whose result is never inspected leaves another unevaluated @(+1)@
-- thunk in the MVar, a chain that grows without bound and can overflow the
-- stack when it is finally forced.  'modifyMVar' also puts the counter
-- back if an exception arrives while it is taken, so the MVar is never left
-- empty.
getGlobalId :: IO Integer
getGlobalId =
  modifyMVar globalCount (\n -> let next = n + 1 in next `seq` return (next, n))

-- | Short-circuiting monadic AND: runs the second check only when the
-- first one returned 'True'.
(>>+) :: IO Bool -> IO Bool -> IO Bool
first >>+ second = first >>= \ok -> if ok then second else return False

-- | Signal every waiter registered in a TVar's wait queue, then empty the
-- queue.  'tryPutMVar' leaves an already-signalled waiter alone.  Hugs has
-- no 'tryPutMVar', so there the commented-out 'safePutMVar' line must be
-- used instead.
fNotify :: MVar [MVar ()] -> IO ()
fNotify waitQ = do
  waiters <- takeMVar waitQ
  mapM_ signal waiters             -- more efficient in ghc
  --mapM_ safePutMVar waiters      -- has to be used in Hugs
  putMVar waitQ []
  where
    signal m = tryPutMVar m ()

-- | Fill an 'MVar' with @()@ unless it is already full, i.e. a stand-in for
-- @tryPutMVar m ()@ on Hugs, which lacks 'tryPutMVar'.  The check and the
-- put are two separate steps, so this is only safe when no other thread can
-- fill the MVar in between (e.g. while a global lock or the relevant TVars
-- are held).
safePutMVar :: MVar () -> IO ()
safePutMVar mVar = isEmptyMVar mVar >>= \vacant ->
  if vacant then putMVar mVar () else return ()
