{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE Unsafe #-}

-- | Variant of @MutVar@ that has one less indirection for primitive types.
-- The difference is illustrated by comparing @MutVar Int@ and @PrimVar Int@:
--
-- * @MutVar Int@: @MutVar# --> I#@
-- * @PrimVar Int@: @MutableByteArray#@
--
-- This module is adapted from a module in Edward Kmett\'s @prim-ref@ library.
module Data.Primitive.PrimVar
  (
  -- * Primitive References
    PrimVar(..)
  , newPrimVar
  , newPinnedPrimVar
  , newAlignedPinnedPrimVar
  , readPrimVar
  , writePrimVar
  , modifyPrimVar
  , primVarContents
  , primVarToMutablePrimArray
  -- * Atomic Operations
  -- $atomic
  , casInt
  , fetchAddInt
  , fetchSubInt
  , fetchAndInt
  , fetchNandInt
  , fetchOrInt
  , fetchXorInt
  , atomicReadInt
  , atomicWriteInt
  ) where

import Control.Monad.Primitive
import Data.Primitive
import GHC.Exts
import GHC.Ptr (castPtr)

--------------------------------------------------------------------------------
-- * Primitive References
--------------------------------------------------------------------------------

-- | A 'PrimVar' behaves like a single-element mutable primitive array.
--
-- It is a newtype around a 'MutablePrimArray'. Every constructor function in
-- this module allocates a one-element array, and every accessor here uses
-- index 0.
newtype PrimVar s a = PrimVar (MutablePrimArray s a)

-- Both parameters are nominal, so 'Data.Coerce.coerce' cannot change the state
-- token or the element type of a 'PrimVar', even between representationally
-- equal types. NOTE(review): presumably the element role is nominal because the
-- stored bytes are interpreted through the element type's 'Prim' instance,
-- which a newtype may define differently — confirm.
type role PrimVar nominal nominal

-- | Create a new (unpinned) 'PrimVar' holding the given initial value.
--
-- The backing array is allocated with 'newPrimArray', so it is not pinned.
-- Do not use 'primVarContents' on the result; allocate with
-- 'newPinnedPrimVar' or 'newAlignedPinnedPrimVar' if you need a stable address.
newPrimVar :: (PrimMonad m, Prim a) => a -> m (PrimVar (PrimState m) a)
newPrimVar a = do
  m <- newPrimArray 1
  writePrimArray m 0 a
  return (PrimVar m)
{-# INLINE newPrimVar #-}

-- | Create a new pinned 'PrimVar' holding the given initial value.
--
-- The backing array is allocated with 'newPinnedPrimArray', so
-- 'primVarContents' may be used on the result. No particular alignment is
-- requested; use 'newAlignedPinnedPrimVar' when the contents must be aligned
-- for the element type.
newPinnedPrimVar :: (PrimMonad m, Prim a) => a -> m (PrimVar (PrimState m) a)
newPinnedPrimVar a = do
  m <- newPinnedPrimArray 1
  writePrimArray m 0 a
  return (PrimVar m)
{-# INLINE newPinnedPrimVar #-}

-- | Create a new pinned 'PrimVar' holding the given initial value, with the
-- appropriate alignment for its contents.
--
-- The backing array is allocated with 'newAlignedPinnedPrimArray', so
-- 'primVarContents' may be used on the result.
newAlignedPinnedPrimVar :: (PrimMonad m, Prim a) => a -> m (PrimVar (PrimState m) a)
newAlignedPinnedPrimVar a = do
  m <- newAlignedPinnedPrimArray 1
  writePrimArray m 0 a
  return (PrimVar m)
{-# INLINE newAlignedPinnedPrimVar #-}

-- | Read the value currently stored in the 'PrimVar'.
--
-- This is a plain 'readPrimArray' of index 0. For an atomic read of an
-- 'Int' cell, use 'atomicReadInt'.
readPrimVar :: (PrimMonad m, Prim a) => PrimVar (PrimState m) a -> m a
readPrimVar (PrimVar m) = readPrimArray m 0
{-# INLINE readPrimVar #-}

-- | Overwrite the value stored in the 'PrimVar'.
--
-- This is a plain 'writePrimArray' to index 0. For an atomic write to an
-- 'Int' cell, use 'atomicWriteInt'.
writePrimVar :: (PrimMonad m, Prim a) => PrimVar (PrimState m) a -> a -> m ()
writePrimVar (PrimVar m) a = writePrimArray m 0 a
{-# INLINE writePrimVar #-}

-- | Mutate the contents of a 'PrimVar' by applying a function to them.
--
-- This is a separate read followed by a separate write, so it is /not/
-- atomic: another thread may write to the 'PrimVar' between the two steps,
-- and that write would be lost. For atomic updates of an 'Int' cell, use
-- 'casInt' or one of the @fetch*Int@ operations.
modifyPrimVar :: (PrimMonad m, Prim a) => PrimVar (PrimState m) a -> (a -> a) -> m ()
modifyPrimVar pv f = do
  x <- readPrimVar pv
  writePrimVar pv (f x)
{-# INLINE modifyPrimVar #-}

-- | Reference equality: two 'PrimVar's are equal exactly when they share the
-- same underlying mutable array (checked with 'sameMutablePrimArray'). Their
-- current contents are not compared.
instance Eq (PrimVar s a) where
  PrimVar m == PrimVar n = sameMutablePrimArray m n
  {-# INLINE (==) #-}

-- | Yield a pointer to the data of a 'PrimVar'.
--
-- This operation is only safe on pinned 'PrimVar's allocated by
-- 'newPinnedPrimVar' or 'newAlignedPinnedPrimVar'. An unpinned 'PrimVar' may
-- be moved by the garbage collector, which would leave the pointer dangling.
--
-- The returned 'Ptr' does not by itself keep the 'PrimVar' alive. The caller
-- must keep the 'PrimVar' reachable for as long as the pointer is in use.
primVarContents :: PrimVar s a -> Ptr a
primVarContents (PrimVar m) = castPtr (mutablePrimArrayContents m)
{-# INLINE primVarContents #-}

-- | Convert a 'PrimVar' to a one-element 'MutablePrimArray'.
--
-- No copy is made: the result shares storage with the 'PrimVar', so a write
-- through either one is visible through the other.
primVarToMutablePrimArray :: PrimVar s a -> MutablePrimArray s a
primVarToMutablePrimArray (PrimVar m) = m
{-# INLINE primVarToMutablePrimArray #-}

--------------------------------------------------------------------------------
-- * Atomic Operations
--------------------------------------------------------------------------------

-- $atomic
-- Atomic operations on @PrimVar s Int@. All atomic operations imply a full memory barrier.

-- | Given a primitive reference, the expected old value, and the new value,
-- perform an atomic compare-and-swap, i.e. write the new value if the current
-- value matches the provided old value.
--
-- Returns the value of the element before the operation. The swap therefore
-- succeeded exactly when the result equals the expected old value.
casInt :: PrimMonad m => PrimVar (PrimState m) Int -> Int -> Int -> m Int
casInt (PrimVar (MutablePrimArray m)) (I# old) (I# new) =
  primitive $ \s -> case casIntArray# m 0# old new s of
    (# s', result #) -> (# s', I# result #)
{-# INLINE casInt #-}

-- | Given a reference and a value to add, atomically add the value to the
-- element. Returns the value of the element before the operation.
fetchAddInt :: PrimMonad m => PrimVar (PrimState m) Int -> Int -> m Int
fetchAddInt (PrimVar (MutablePrimArray m)) (I# x) =
  primitive $ \s -> case fetchAddIntArray# m 0# x s of
    (# s', result #) -> (# s', I# result #)
{-# INLINE fetchAddInt #-}

-- | Given a reference and a value to subtract, atomically subtract the value
-- from the element. Returns the value of the element before the operation.
fetchSubInt :: PrimMonad m => PrimVar (PrimState m) Int -> Int -> m Int
fetchSubInt (PrimVar (MutablePrimArray m)) (I# x) =
  primitive $ \s -> case fetchSubIntArray# m 0# x s of
    (# s', result #) -> (# s', I# result #)
{-# INLINE fetchSubInt #-}

-- | Given a reference and a value to bitwise-and, atomically replace the
-- element with the bitwise AND of the element and the value. Returns the
-- value of the element before the operation.
fetchAndInt :: PrimMonad m => PrimVar (PrimState m) Int -> Int -> m Int
fetchAndInt (PrimVar (MutablePrimArray m)) (I# x) =
  primitive $ \s -> case fetchAndIntArray# m 0# x s of
    (# s', result #) -> (# s', I# result #)
{-# INLINE fetchAndInt #-}

-- | Given a reference and a value to bitwise-nand, atomically replace the
-- element with the bitwise NAND of the element and the value, i.e.
-- @complement (old .&. x)@. Returns the value of the element before the
-- operation.
fetchNandInt :: PrimMonad m => PrimVar (PrimState m) Int -> Int -> m Int
fetchNandInt (PrimVar (MutablePrimArray m)) (I# x) =
  primitive $ \s -> case fetchNandIntArray# m 0# x s of
    (# s', result #) -> (# s', I# result #)
{-# INLINE fetchNandInt #-}

-- | Given a reference and a value to bitwise-or, atomically replace the
-- element with the bitwise OR of the element and the value. Returns the
-- value of the element before the operation.
fetchOrInt :: PrimMonad m => PrimVar (PrimState m) Int -> Int -> m Int
fetchOrInt (PrimVar (MutablePrimArray m)) (I# x) =
  primitive $ \s -> case fetchOrIntArray# m 0# x s of
    (# s', result #) -> (# s', I# result #)
{-# INLINE fetchOrInt #-}

-- | Given a reference and a value to bitwise-xor, atomically replace the
-- element with the bitwise XOR of the element and the value. Returns the
-- value of the element before the operation.
fetchXorInt :: PrimMonad m => PrimVar (PrimState m) Int -> Int -> m Int
fetchXorInt (PrimVar (MutablePrimArray m)) (I# x) =
  primitive $ \s -> case fetchXorIntArray# m 0# x s of
    (# s', result #) -> (# s', I# result #)
{-# INLINE fetchXorInt #-}

-- | Atomically read the 'Int' stored in the 'PrimVar'.
--
-- Unlike 'readPrimVar', this goes through 'atomicReadIntArray#'.
atomicReadInt :: PrimMonad m => PrimVar (PrimState m) Int -> m Int
atomicReadInt (PrimVar (MutablePrimArray m)) =
  primitive $ \s -> case atomicReadIntArray# m 0# s of
    (# s', result #) -> (# s', I# result #)
{-# INLINE atomicReadInt #-}

-- | Atomically overwrite the 'Int' stored in the 'PrimVar'.
--
-- Unlike 'writePrimVar', this goes through 'atomicWriteIntArray#'.
atomicWriteInt :: PrimMonad m => PrimVar (PrimState m) Int -> Int -> m ()
atomicWriteInt (PrimVar (MutablePrimArray m)) (I# x) =
  primitive_ $ \s -> atomicWriteIntArray# m 0# x s
{-# INLINE atomicWriteInt #-}