{-# LANGUAGE CPP             #-}
{-# LANGUAGE LambdaCase      #-}
{-# LANGUAGE NamedFieldPuns  #-}
{-# LANGUAGE RecordWildCards #-}
-- {-# OPTIONS_GHC -ddump-simpl -ddump-to-file -dsuppress-all #-}
module PureSAT.SparseMaxHeap (
    SparseHeap,
    Weight,
    sizeofSparseHeap,
    newSparseHeap,
    cloneSparseHeap,
    memberSparseHeap,
    insertSparseHeap,
    deleteSparseHeap,
    popSparseHeap,
    popSparseHeap_,
    elemsSparseHeap,
    clearSparseHeap,
    extendSparseHeap,
    drainSparseHeap,
    modifyWeightSparseHeap,
    scaleWeightsSparseHeap,
) where

import Data.Bits
import Data.Primitive.PrimVar

import PureSAT.Base
import PureSAT.Utils
import PureSAT.Prim

-- | Per-element weight (priority). The heap is ordered by weight with
-- larger weights towards the root (see 'le'), so popping yields an element
-- of maximal weight.
type Weight = Word

-- import Debug.Trace

-- #define CHECK_INVARIANTS

-- $setup
-- >>> import Control.Monad.ST (runST)

-- | A sparse set (https://research.swtch.com/sparse) of integers in
-- @[0, capacity)@ whose live elements are additionally arranged as a binary
-- max-heap (https://en.wikipedia.org/wiki/Heap_(data_structure)) keyed by
-- per-element weights.
--
-- i.e. pop returns the element with the largest weight.
--
data SparseHeap s = SH
    { size   :: {-# UNPACK #-} !(PrimVar s Int)
      -- ^ number of live elements; slots @[0, size)@ of 'dense' are in use
    , dense  :: {-# UNPACK #-} !(MutablePrimArray s Int)
      -- ^ heap slot -> element, laid out as an implicit binary tree
    , sparse :: {-# UNPACK #-} !(MutablePrimArray s Int)
      -- ^ element -> heap slot; only trusted when it round-trips through 'dense'
    , weight :: {-# UNPACK #-} !(MutablePrimArray s Weight)
      -- ^ element -> weight; kept even while the element is not in the heap
    }

-- | Heap order on (element, weight) pairs: @le x u y v@ holds when element
-- @x@ with weight @u@ may sit above element @y@ with weight @v@, that is
-- when @u >= v@. The elements themselves are ignored, so equal weights
-- never force a swap.
le :: Int -> Weight -> Int -> Weight -> Bool
le _x !u _y !v = v <= u
{-
Alternative order that breaks weight ties by the smaller element:

le !x !u !y !v
    | u > v          = True
    | u == v, x <= y = True
    | otherwise      = False
-}

-- | Wrap a heap operation with invariant checks.
--
-- When compiled with @CHECK_INVARIANTS@, '_invariant' is run before and
-- after the action, tagged @tag ++ " pre"@ and @tag ++ " post"@; otherwise
-- the action is returned unchanged. The @CHECK(tag,heap)@ macro is defined
-- in the same two variants for code that cannot use this wrapper.
checking :: String -> SparseHeap s -> ST s a -> ST s a
{-# INLINE checking #-}

#ifdef CHECK_INVARIANTS

#define CHECK(tag,heap) _invariant tag heap

checking tag heap action =
    _invariant (tag ++ " pre") heap *> action <* _invariant (tag ++ " post") heap

#else

#define CHECK(tag,heap)

checking _tag _heap action = action

#endif

-- | Validate the internal invariants of a heap, calling 'error' on the
-- first violation found. @tag@ is included in heap-order error messages.
--
-- Checked, in this order:
--
-- * the size fits in @dense@ and all three arrays have the same length;
-- * every live slot @i < size@ holds an element @x < capacity@ with
--   @sparse[x] == i@;
-- * every live slot dominates its live children according to 'le'.
_invariant :: forall s. String -> SparseHeap s -> ST s ()
_invariant tag SH {..} = do
    count <- readPrimVar size
    cap   <- getSizeofMutablePrimArray dense
    capS  <- getSizeofMutablePrimArray sparse
    capW  <- getSizeofMutablePrimArray weight

    let capacitiesOk = count <= cap && cap == capS && cap == capW
    unless capacitiesOk $
        error $ "capacities " ++ show (count, cap, capS, capW)

    structureFrom cap count 0
    heapFrom count 0
  where
    -- every live dense slot must be pointed back to by sparse
    structureFrom :: Int -> Int -> Int -> ST s ()
    structureFrom cap count slot
        | slot >= count = return ()
        | otherwise = do
            elt <- readPrimArray dense slot
            unless (elt < cap) $ error $ "x < capacity" ++ show (elt, cap)
            back <- readPrimArray sparse elt
            unless (slot == back) $ error $ "i == j" ++ show (slot, back)
            structureFrom cap count (slot + 1)

    -- every live slot must satisfy the heap order with respect to its children
    heapFrom :: Int -> Int -> ST s ()
    heapFrom count slot
        | slot >= count = return ()
        | otherwise = do
            elt <- readPrimArray dense slot
            w   <- readPrimArray weight elt
            checkChildren count slot elt w
            heapFrom count (slot + 1)

    checkChildren :: Int -> Int -> Int -> Weight -> ST s ()
    checkChildren count slot elt w = do
        let !left  = 2 * slot + 1
            !right = left + 1

        when (left < count) $ do
            c  <- readPrimArray dense left
            cw <- readPrimArray weight c
            unless (le elt w c cw) $
                error $ "heap 1 " ++ tag ++ " " ++ show (count, slot, elt, w, left, c, cw)

        when (right < count) $ do
            c  <- readPrimArray dense right
            cw <- readPrimArray weight c
            unless (le elt w c cw) $
                error $ "heap 2 " ++ tag ++ " " ++ show (count, slot, elt, w, right, c, cw)

-- | Create new sparse heap.
--
-- The heap starts empty and every element's weight starts at zero. At
-- least 1024 element slots are always allocated.
--
-- >>> runST $ newSparseHeap 100 >>= elemsSparseHeap
-- []
--
newSparseHeap
    :: Int -- ^ capacity: elements must lie in @[0, max 1024 capacity)@
    -> ST s (SparseHeap s)
newSparseHeap !requested = do
    let !cap = max 1024 requested
    sizeVar   <- newPrimVar 0
    denseArr  <- newPrimArray cap
    sparseArr <- newPrimArray cap
    weightArr <- newPrimArray cap
    -- dense and sparse may stay uninitialised: membership is always
    -- validated by the dense/sparse round trip, bounded by the size
    setPrimArray weightArr 0 cap 0
    return SH { size = sizeVar, dense = denseArr, sparse = sparseArr, weight = weightArr }

-- | Make an independent deep copy of the heap.
--
-- The copy gets its own size counter and its own @dense@, @sparse@ and
-- @weight@ arrays, so mutating either heap afterwards does not affect the
-- other.
cloneSparseHeap :: SparseHeap s -> ST s (SparseHeap s)
cloneSparseHeap SH {..} = do
    -- all three arrays share one length (checked by '_invariant')
    capacity <- getSizeofMutablePrimArray dense

    size' <- readPrimVar size >>= newPrimVar

    -- Allocate fresh arrays. Resizing to the *same* length with
    -- 'resizeMutablePrimArray' may resize in place and hand back the very
    -- same array, which would make the "clone" alias the original (and
    -- leave the original in a state that must no longer be used).
    dense'  <- newPrimArray capacity
    sparse' <- newPrimArray capacity
    weight' <- newPrimArray capacity

    copyMutablePrimArray dense'  0 dense  0 capacity
    copyMutablePrimArray sparse' 0 sparse 0 capacity
    copyMutablePrimArray weight' 0 weight 0 capacity

    return SH { size = size', dense = dense', sparse = sparse', weight = weight' }

-- | Number of elements currently in the heap.
--
-- >>> runST $ do { set <- newSparseHeap 100; mapM_ (insertSparseHeap set) [3,5,7,11,13,11]; sizeofSparseHeap set }
-- 5
--
sizeofSparseHeap :: SparseHeap s -> ST s Int
sizeofSparseHeap = readPrimVar . size

-- | Extend sparse heap to fit new capacity.
--
-- After this call, elements in @[0, capacity)@ can be inserted. If the
-- current arrays already have at least @capacity@ slots, the heap is
-- returned as is. Otherwise all three arrays are grown to
-- @nextPowerOf2 capacity@, and the weights of the new slots are set to zero.
--
-- The result shares its size counter with the argument. When growing, it is
-- built with 'resizeMutablePrimArray', which may invalidate the old arrays,
-- so callers must continue with the returned heap only.
extendSparseHeap
    :: Int -- ^ new capacity
    -> SparseHeap s
    -> ST s (SparseHeap s)
extendSparseHeap capacity1 heap@SH {..} = do
    capacity2 <- getSizeofMutablePrimArray dense

    -- Grow only when the requested capacity does not fit yet. Rounding the
    -- *current* capacity up to a power of two before comparing would
    -- reallocate heaps that are already large enough but whose capacity is
    -- not a power of two (e.g. one created with 'newSparseHeap' 1500).
    if capacity1 <= capacity2
    then return heap
    else do
        let !capacity = nextPowerOf2 capacity1

        dense'  <- resizeMutablePrimArray dense capacity
        sparse' <- resizeMutablePrimArray sparse capacity
        weight' <- resizeMutablePrimArray weight capacity
        -- new slots start with zero weight; new dense/sparse slots need no
        -- initialisation because membership is validated by round trip
        setPrimArray weight' capacity2 (capacity - capacity2) 0

        return SH { size, dense = dense', sparse = sparse', weight = weight' }

-- | Test whether an element is currently in the heap.
--
-- The element must be within the heap's capacity.
--
-- >>> runST $ do { set <- newSparseHeap 100; mapM_ (insertSparseHeap set) [3,5,7,11,13,11]; memberSparseHeap set 10 }
-- False
--
-- >>> runST $ do { set <- newSparseHeap 100; mapM_ (insertSparseHeap set) [3,5,7,11,13,11]; memberSparseHeap set 13 }
-- True
--
memberSparseHeap :: SparseHeap s -> Int -> ST s Bool
memberSparseHeap heap@SH {..} x = checking "member" heap $ do
    n    <- readPrimVar size
    slot <- readPrimArray sparse x
    -- a stale or uninitialised sparse entry either points outside the live
    -- prefix or at a slot holding a different element
    if slot < 0 || slot >= n
        then return False
        else (== x) <$> readPrimArray dense slot

-- | Insert an element into the heap, using its currently stored weight.
-- Inserting an element that is already present does nothing.
--
-- >>> runST $ do { set <- newSparseHeap 100; mapM_ (insertSparseHeap set) [3,5,7,11,13,11]; elemsSparseHeap set }
-- [3,5,7,11,13]
--
insertSparseHeap :: forall s. SparseHeap s -> Int -> ST s ()
insertSparseHeap heap@SH {..} x = checking "insert" heap $ do
    n    <- readPrimVar size
    slot <- readPrimArray sparse x
    present <-
        if 0 <= slot && slot < n
        then (== x) <$> readPrimArray dense slot
        else return False
    unless present $ append n
  where
    -- put x into the first free slot, grow the size, then sift it up
    {-# INLINE append #-}
    append :: Int -> ST s ()
    append !n = do
        writePrimArray dense n x
        writePrimArray sparse x n
        writePrimVar size (n + 1)
        u <- readPrimArray weight x
        swim (n + 1) dense sparse weight n x u

-- | Delete element from the heap. Deleting an absent element does nothing;
-- the element's stored weight is left untouched.
--
-- >>> runST $ do { set <- newSparseHeap 100; deleteSparseHeap set 10; elemsSparseHeap set }
-- []
--
-- >>> let insert heap x = modifyWeightSparseHeap heap x (\_ -> fromIntegral $ 100 - x) >> insertSparseHeap heap x;
--
-- >>> runST $ do { set <- newSparseHeap 100; mapM_ (insert set) [3,5,7,11,13,11]; deleteSparseHeap set 10; elemsSparseHeap set }
-- [3,5,7,11,13]
--
-- >>> runST $ do { set <- newSparseHeap 100; mapM_ (insert set) [3,5,7,11,13,11]; deleteSparseHeap set 13; elemsSparseHeap set }
-- [3,5,7,11]
--
-- >>> runST $ do { set <- newSparseHeap 100; mapM_ (insert set) [3,5,7,11,13,11]; deleteSparseHeap set 11; elemsSparseHeap set }
-- [3,5,7,13]
--
-- >>> runST $ do { set <- newSparseHeap 100; mapM_ (insert set) [3,5,7,11,13,11]; deleteSparseHeap set 3; elemsSparseHeap set }
-- [5,11,7,13]
--
-- >>> runST $ do { set <- newSparseHeap 100; mapM_ (insert set) $ [0,2..20] ++ [19,17..3]; deleteSparseHeap set 10; elemsSparseHeap set }
-- [0,2,4,5,3,17,12,9,6,8,20,19,18,15,13,14,11,16,7]
--
deleteSparseHeap :: forall s. SparseHeap s -> Int -> ST s ()
deleteSparseHeap heap@SH {..} x = checking "delete" heap $ do
    n    <- readPrimVar size
    slot <- readPrimArray sparse x
    when (0 <= slot && slot < n) $ do
        x' <- readPrimArray dense slot
        when (x == x') $ remove slot n
  where
    -- shrink the heap by one; if x was not in the last live slot, bring it
    -- to the root and then evict it like a pop would
    {-# INLINE remove #-}
    remove :: Int -> Int -> ST s ()
    remove !slot !n = do
        let !lastSlot = n - 1
        writePrimVar size lastSlot
        unless (slot == lastSlot) $ do
            liftToRoot slot
            evictRoot lastSlot

    -- Move x up to the root by unconditionally swapping it with its parent,
    -- as if it had maximal weight.
    liftToRoot :: Int -> ST s ()
    liftToRoot slot
        | slot <= 0 = return ()
        | otherwise = do
            -- parent slot = floor ((slot - 1) / 2)
            let !parent = unsafeShiftR (slot - 1) 1
            y <- readPrimArray dense parent
            swap' dense sparse slot x parent y
            liftToRoot parent

    -- x is at the root: swap it with the element in lastSlot (now just past
    -- the live prefix) and sink that element back into place
    evictRoot :: Int -> ST s ()
    evictRoot lastSlot = do
        y <- readPrimArray dense lastSlot
        v <- readPrimArray weight y
        swap' dense sparse 0 x lastSlot y
        sink lastSlot dense sparse weight 0 y v


{-# INLINE swap' #-}
-- | Exchange two heap slots: element @x@ moves into slot @j@ and element
-- @y@ moves into slot @i@, and the @sparse@ back-pointers of both are
-- updated to match. Callers pass @x@ as the element currently in slot @i@
-- and @y@ as the one in slot @j@.
--
-- Only @dense@ and @sparse@ are written; weights are indexed by element and
-- do not move.
swap' :: MutablePrimArray s Int -> MutablePrimArray s Int -> Int -> Int -> Int -> Int -> ST s ()
swap' !dense !sparse !slotX !x !slotY !y =
    writePrimArray dense slotY x
        *> writePrimArray dense slotX y
        *> writePrimArray sparse x slotY
        *> writePrimArray sparse y slotX

-- | Sift down: restore heap order below slot @i@, which holds element @x@
-- of weight @u@, within the live prefix @[0, n)@ of @dense@.
--
-- If @x@ is lighter than one of its children (per 'le'), it is swapped
-- with the heavier child (the left one on ties) and sinking continues from
-- that child's slot.
sink :: Int -> MutablePrimArray s Int -> MutablePrimArray s Int -> MutablePrimArray s Weight -> Int -> Int -> Weight -> ST s ()
sink !n !dense !sparse !weight !i !x !u
    -- both children are live
    | right < n = do
        l <- readPrimArray dense left
        r <- readPrimArray dense right
        v <- readPrimArray weight l
        w <- readPrimArray weight r
        let xOverL = le x u l v
        if xOverL && le x u r w
            then return ()                  -- x dominates both children
            else if not xOverL && le l v r w
                then descendTo left l       -- left child is the heavier one
                else descendTo right r      -- right child is the heavier one

    -- only the left child is live; it is the last live slot, so it has no
    -- children of its own and at most one swap is needed
    | left < n = do
        l <- readPrimArray dense left
        v <- readPrimArray weight l
        unless (le x u l v) $ swap' dense sparse i x left l

    -- leaf
    | otherwise = return ()
  where
    !left  = 2 * i + 1
    !right = left + 1

    descendTo slot e = do
        swap' dense sparse i x slot e
        sink n dense sparse weight slot x u

-- | Sift up: move element @x@ of weight @u@, currently in slot @i@, towards
-- the root for as long as its parent does not dominate it (per 'le').
--
-- The first argument is never inspected. Presumably it is kept for symmetry
-- with 'sink'.
swim :: Int -> MutablePrimArray s Int -> MutablePrimArray s Int -> MutablePrimArray s Weight -> Int -> Int -> Weight -> ST s ()
swim !_size !dense !sparse !weight !i !x !u
    | i > 0 = do
        -- parent slot = floor ((i - 1) / 2)
        let !parent = (i - 1) `unsafeShiftR` 1
        y <- readPrimArray dense parent
        v <- readPrimArray weight y
        if le y v x u
            then return ()
            else do
                swap' dense sparse i x parent y
                swim _size dense sparse weight parent x u
    | otherwise = return ()

-- | Modify weight of the element by applying @f@ to its stored weight.
--
-- If the element is in the heap and its weight changed, it is moved up
-- (heavier) or down (lighter) to restore the heap order.
--
-- >>> let insert heap x = modifyWeightSparseHeap heap x (\_ -> fromIntegral $ 100 - x) >> insertSparseHeap heap x;
-- >>> let populate heap = mapM_ (insert heap) [5,3,7,11,13,11]
-- >>> let populate' heap = mapM_ (insertSparseHeap heap) [5,3,7,11,13,11]
--
-- >>> runST $ do { heap <- newSparseHeap 100; populate heap; popSparseHeap heap }
-- Just 3
--
-- >>> runST $ do { heap <- newSparseHeap 100; populate heap; modifyWeightSparseHeap heap 3 (\_ -> 0); popSparseHeap heap }
-- Just 5
--
-- Weights are preserved even if the element is not in the heap at the moment:
--
-- >>> runST $ do { heap <- newSparseHeap 100; modifyWeightSparseHeap heap 7 (\_ -> 100); populate' heap; popSparseHeap heap }
-- Just 7
--
modifyWeightSparseHeap :: forall s. SparseHeap s -> Int -> (Weight -> Weight) -> ST s ()
modifyWeightSparseHeap heap@SH {..} !x f = checking "modify" heap $ do
    old <- readPrimArray weight x
    let !new = f old
    writePrimArray weight x new

    -- the new weight is stored unconditionally; re-balancing is only needed
    -- when it actually changed and x is currently in the heap
    when (new /= old) $ do
        n    <- readPrimVar size
        slot <- readPrimArray sparse x
        when (0 <= slot && slot < n) $ do
            x' <- readPrimArray dense slot
            when (x == x') $
                if new > old
                    then swim n dense sparse weight slot x new
                    else sink n dense sparse weight slot x new
{-# INLINE modifyWeightSparseHeap #-}

-- | Apply @f@ to the stored weight of every element slot in the heap's
-- capacity, including elements that are not currently in the heap.
--
-- The heap is not re-balanced afterwards. @f@ must therefore be monotone
-- (@a >= b@ implies @f a >= f b@) for the heap order to be preserved.
scaleWeightsSparseHeap :: forall s. SparseHeap s -> (Weight -> Weight) -> ST s ()
scaleWeightsSparseHeap heap@SH{..} f = checking "scale" heap $ do
    !capacity <- getSizeofMutablePrimArray weight
    go capacity 0
  where
    go :: Int -> Int -> ST s ()
    go !n !i
        | i >= n    = return ()
        | otherwise = do
            u <- readPrimArray weight i
            writePrimArray weight i (f u)
            -- continue with the next slot (previously missing, so only
            -- slot 0 was ever scaled)
            go n (i + 1)

-- | Remove and return an element of maximal weight, or 'Nothing' if the
-- heap is empty.
--
-- >>> let insert heap x = modifyWeightSparseHeap heap x (\_ -> - fromIntegral x) >> insertSparseHeap heap x;
--
-- >>> runST $ do { heap <- newSparseHeap 100; mapM_ (insert heap) [5,3,7,11,13,11]; popSparseHeap heap }
-- Just 3
--
-- >>> runST $ do { heap <- newSparseHeap 500; mapM_ (insert heap) [1..400]; drainSparseHeap heap }
-- [1,2...,400]
popSparseHeap :: SparseHeap s -> ST s (Maybe Int)
popSparseHeap heap = popSparseHeap_ heap (pure Nothing) (pure . Just)

{-# INLINE popSparseHeap_ #-}
-- | Continuation-passing pop. If the heap is empty, run @no@. Otherwise
-- remove the element at the root (one of maximal weight) and pass it to
-- @yes@.
--
-- The root is swapped with the last live element and the size is
-- decremented. The moved element is then sunk back into place. The removed
-- element's stored weight is left untouched.
popSparseHeap_ :: SparseHeap s -> ST s r -> (Int -> ST s r) -> ST s r
popSparseHeap_ _heap@SH {..} no yes = do
    CHECK("pop pre", _heap)

    count <- readPrimVar size

    -- xs <- freezePrimArray dense 0 count
    -- traceM $ "pop" ++ show (take 15 $ primArrayToList xs)

    if count > 0
    then do
        let !lastSlot = count - 1
        writePrimVar size lastSlot

        top   <- readPrimArray dense 0
        moved <- readPrimArray dense lastSlot
        w     <- readPrimArray weight moved
        swap' dense sparse 0 top lastSlot moved
        sink lastSlot dense sparse weight 0 moved w

        CHECK("pop post", _heap)
        yes top
    else no

-- | Remove all elements from the heap by resetting its size to zero.
-- Stored weights are kept.
--
-- >>> runST $ do { set <- newSparseHeap 100; mapM_ (insertSparseHeap set) [3,5,7,11,13,11]; clearSparseHeap set; elemsSparseHeap set }
-- []
--
clearSparseHeap :: SparseHeap s -> ST s ()
clearSparseHeap heap = writePrimVar (size heap) 0

-- | Elements of the heap.
--
-- Returns the live elements in the order they are stored in the heap's
-- internal array (slot 0 first), not sorted by weight.
--
elemsSparseHeap :: forall s. SparseHeap s -> ST s [Int]
elemsSparseHeap SH {..} = readPrimVar size >>= collect [] . subtract 1
  where
    -- walk the live prefix backwards, consing, so no reverse is needed
    collect :: [Int] -> Int -> ST s [Int]
    collect !acc !slot
        | slot < 0  = return acc
        | otherwise = do
            x <- readPrimArray dense slot
            collect (x : acc) (slot - 1)

-- | Pop every element, returning them in pop order (largest weight first)
-- and leaving the heap empty.
--
-- >>> let insert heap x = modifyWeightSparseHeap heap x (\_ -> - fromIntegral x) >> insertSparseHeap heap x;
--
-- >>> runST $ do { set <- newSparseHeap 100; mapM_ (insert set) [3,5,7,11,13,11]; drainSparseHeap set }
-- [3,5,7,11,13]
--
drainSparseHeap :: forall s. SparseHeap s -> ST s [Int]
drainSparseHeap heap = loop []
  where
    -- accumulate in reverse pop order, flip once at the end
    loop :: [Int] -> ST s [Int]
    loop popped =
        popSparseHeap_ heap
            (return (reverse popped))
            (\x -> loop (x : popped))