{-# LANGUAGE CPP, DeriveDataTypeable, FlexibleInstances, MultiParamTypeClasses #-}

#if __GLASGOW_HASKELL__ >= 701
{-# LANGUAGE Trustworthy #-}
#endif

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 904
#define HAS_UNLIFTED_ARRAY 1
#endif

#if defined(HAS_UNLIFTED_ARRAY)
{-# LANGUAGE MagicHash, UnboxedTuples #-}
#endif

-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Concurrent.STM.TArray
-- Copyright   :  (c) The University of Glasgow 2005
-- License     :  BSD-style (see the file libraries/base/LICENSE)
--
-- Maintainer  :  libraries@haskell.org
-- Stability   :  experimental
-- Portability :  non-portable (requires STM)
--
-- TArrays: transactional arrays, for use in the STM monad.
--
-----------------------------------------------------------------------------

module Control.Concurrent.STM.TArray (
    TArray
) where

import Control.Monad.STM (STM, atomically)
import Data.Typeable (Typeable)
#if defined(HAS_UNLIFTED_ARRAY)
import Control.Concurrent.STM.TVar (readTVar, readTVarIO, writeTVar)
import Data.Array.Base (safeRangeSize, MArray(..))
import Data.Ix (Ix)
import GHC.Conc (STM(..), TVar(..))
import GHC.Exts
import GHC.IO (IO(..))
#else
import Control.Concurrent.STM.TVar (TVar, newTVar, newTVarIO, readTVar, readTVarIO, writeTVar)
import Data.Array (Array, bounds, listArray)
import Data.Array.Base (safeRangeSize, unsafeAt, MArray(..), IArray(numElements))
#endif

-- | 'TArray' is a transactional array, supporting the usual 'MArray'
-- interface for mutable arrays.
--
-- It is conceptually implemented as @Array i (TVar e)@.
#if defined(HAS_UNLIFTED_ARRAY)
data TArray i e
    = TArray
        !i                           -- lower bound
        !i                           -- upper bound
        !Int                         -- number of elements in the range
        (Array# (TVar# RealWorld e)) -- one TVar# per element, addressed by Int offset
    deriving (Typeable)

-- Two 'TArray's are equal when they are the same array.  Every 'TArray'
-- allocates its own fresh 'TVar#'s, so for non-empty arrays it is enough to
-- compare the bounds and the identity of the first element's 'TVar#'.
-- Empty arrays all compare equal to each other.
instance (Eq i, Eq e) => Eq (TArray i e) where
    (TArray l1 u1 n1 arr1#) == (TArray l2 u2 n2 arr2#)
        | n1 == 0   = n2 == 0
        | otherwise =
            l1 == l2 && u1 == u2
                && isTrue# (sameTVar# (unsafeFirstT arr1#) (unsafeFirstT arr2#))
      where
        -- Reads slot 0 without a bounds check.  arr1# is non-empty in this
        -- branch.  arr2# is only reached after its bounds matched (the '&&'
        -- short-circuits), so it is non-empty too, assuming a lawful 'Ix'.
        unsafeFirstT :: Array# (TVar# RealWorld e) -> TVar# RealWorld e
        unsafeFirstT arr# = case indexArray# arr# 0# of
            (# tv# #) -> tv#

-- | Allocate a 'TArray' spanning the index range @b@.  Every element gets its
-- own fresh 'TVar#' initialised to @e@.
--
-- 'newArray#' needs an initial element, so the mutable array is first filled
-- with a single 'TVar#'.  Slot 0 keeps that one.  Slots @1 .. n-1@ are then
-- overwritten with distinct 'TVar#'s, so no two indices share state.  Finally
-- the array is frozen in place and wrapped together with its bounds and size.
newTArray# :: Ix i => (i, i) -> e -> State# RealWorld -> (# State# RealWorld, TArray i e #)
newTArray# b@(l, u) e = \s1# ->
    case safeRangeSize b of
        n@(I# n#) -> case newTVar# e s1# of
            (# s2#, initial_tvar# #) -> case newArray# n# initial_tvar# s2# of
                (# s3#, marr# #) ->
                    -- Store a fresh TVar# in slots i# .. n# - 1.
                    -- Only entered when n > 1, starting at slot 1.
                    let go :: Int# -> State# RealWorld -> State# RealWorld
                        go i# = \s4# -> case newTVar# e s4# of
                            (# s5#, tvar# #) -> case writeArray# marr# i# tvar# s5# of
                                s6# -> if isTrue# (i# ==# n# -# 1#)
                                           then s6#
                                           else go (i# +# 1#) s6#
                    in case unsafeFreezeArray# marr# (if n <= 1 then s3# else go 1# s3#) of
                        (# s7#, arr# #) -> (# s7#, TArray l u n arr# #)

-- Transactional access: reads and writes run inside the caller's STM
-- transaction.  Indices passed to unsafeRead/unsafeWrite are Int offsets
-- into the backing Array#; they are not bounds-checked here.
instance MArray TArray e STM where
    getBounds (TArray l u _ _) = return (l, u)
    getNumElements (TArray _ _ n _) = return n
    newArray b e = STM $ newTArray# b e
    unsafeRead (TArray _ _ _ arr#) (I# i#) = case indexArray# arr# i# of
        (# tvar# #) -> readTVar (TVar tvar#)
    unsafeWrite (TArray _ _ _ arr#) (I# i#) e = case indexArray# arr# i# of
        (# tvar# #) -> writeTVar (TVar tvar#) e

-- | Writes are slow in `IO`: every 'unsafeWrite' runs its own 'atomically'
-- transaction.  Reads go through 'readTVarIO'.
instance MArray TArray e IO where
    getBounds (TArray l u _ _) = return (l, u)
    getNumElements (TArray _ _ n _) = return n
    newArray b e = IO $ newTArray# b e
    unsafeRead (TArray _ _ _ arr#) (I# i#) = case indexArray# arr# i# of
        (# tvar# #) -> readTVarIO (TVar tvar#)
    unsafeWrite (TArray _ _ _ arr#) (I# i#) e = case indexArray# arr# i# of
        (# tvar# #) -> atomically $ writeTVar (TVar tvar#) e
#else
-- Fallback representation: an immutable boxed array holding one TVar per
-- index.  The derived Eq compares bounds and the TVars themselves, not
-- their current contents.
newtype TArray i e = TArray (Array i (TVar e))
    deriving (Eq, Typeable)

-- Transactional access: every operation runs inside the caller's STM
-- transaction.  Element access is delegated to the TVar stored at each
-- (unchecked) Int offset.
instance MArray TArray e STM where
    getBounds (TArray arr) = return $ bounds arr
    getNumElements (TArray arr) = return $ numElements arr
    newArray bnds initial =
        fmap (TArray . listArray bnds) (rep (safeRangeSize bnds) (newTVar initial))
    unsafeRead (TArray arr) ix = readTVar (unsafeAt arr ix)
    unsafeWrite (TArray arr) ix val = writeTVar (unsafeAt arr ix) val

    {-# INLINE newArray #-}

-- | Writes are slow in `IO`: every 'unsafeWrite' is wrapped in its own
-- 'atomically' transaction.
instance MArray TArray e IO where
    getBounds (TArray arr) = return $ bounds arr
    getNumElements (TArray arr) = return $ numElements arr
    -- One fresh TVar per slot, allocated directly in IO.
    newArray bnds initial =
        fmap (TArray . listArray bnds) (rep (safeRangeSize bnds) (newTVarIO initial))
    unsafeRead (TArray arr) ix = readTVarIO (unsafeAt arr ix)
    unsafeWrite (TArray arr) ix val = atomically (writeTVar (unsafeAt arr ix) val)

    {-# INLINE newArray #-}

-- | Like 'replicateM', but uses an accumulator to prevent stack overflows.
-- Unlike 'replicateM', the returned list is in reverse order of execution.
-- This doesn't matter here, because the function is only used to create
-- arrays whose elements are all identical.
--
-- A non-positive count runs the action zero times and yields @[]@.
rep :: Monad m => Int -> m a -> m [a]
rep n m = go n []
    where
      -- Stop on any count <= 0.  Matching only on a literal 0 would loop
      -- forever on a negative count.
      go i xs
        | i <= 0    = return xs
        | otherwise = do
            x <- m
            go (i - 1) (x : xs)
#endif