{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Vector.Algorithms.Merge
( sort
, sortBy
, Comparison
) where
import Prelude hiding (read, length)
import Control.Monad.Primitive
import Data.Bits
import Data.Vector.Generic.Mutable
import Data.Vector.Algorithms.Common (Comparison, copyOffset, midPoint)
import qualified Data.Vector.Algorithms.Optimal as O
import qualified Data.Vector.Algorithms.Insertion as I
-- | Sorts a mutable vector in place into ascending order according to the
-- element type's 'Ord' instance.
--
-- This is exactly @'sortBy' 'compare'@. The merge step prefers the left run
-- on ties. NOTE(review): whole-sort stability also depends on the insertion
-- and optimal-network sorts used for short ranges — confirm those are stable.
sort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> m ()
sort = sortBy compare
{-# INLINABLE sort #-}
-- | Sorts a mutable vector in place using the given comparison.
--
-- The strategy depends on the vector's length:
--
-- * 0 or 1 elements: nothing to do.
-- * 2, 3 or 4 elements: the fixed-size sorts from
--   "Data.Vector.Algorithms.Optimal".
-- * fewer than 'threshold' elements: insertion sort over the whole vector.
-- * otherwise: a top-down merge sort ('mergeSortWithBuf') with one scratch
--   buffer of @(len + 1) `div` 2@ elements, allocated once up front.
sortBy :: (PrimMonad m, MVector v e) => Comparison e -> v (PrimState m) e -> m ()
sortBy cmp vec
  | len <= 1        = return ()
  | len == 2        = O.sort2ByOffset cmp vec 0
  | len == 3        = O.sort3ByOffset cmp vec 0
  | len == 4        = O.sort4ByOffset cmp vec 0
  | len < threshold = I.sortByBounds cmp vec 0 len
  | otherwise       = do
      buf <- new halfLen
      mergeSortWithBuf cmp vec buf
  where
    len :: Int
    len = length vec

    -- The buffer only ever holds the lower run of a merge. Rounding up
    -- guarantees it is large enough for either half of an odd-length range.
    halfLen :: Int
    halfLen = (len + 1) `div` 2
{-# INLINE sortBy #-}
-- | Top-down merge sort of the whole of @src@, using @buf@ as scratch space.
--
-- Ranges shorter than 'threshold' are insertion-sorted in place. Longer
-- ranges are split at their midpoint, each half is sorted recursively, and
-- the two sorted halves are then merged back into @src@ by 'merge'.
--
-- @buf@ must have room for the lower half (@mid - l@ elements) of every
-- range that gets merged; 'sortBy' allocates @(length src + 1) `div` 2@.
mergeSortWithBuf :: (PrimMonad m, MVector v e)
                 => Comparison e -> v (PrimState m) e -> v (PrimState m) e -> m ()
mergeSortWithBuf cmp src buf = loop 0 (length src)
  where
    -- Sort the half-open index range [l, u) of src.
    loop l u
      | len < threshold = I.sortByBounds cmp src l u
      | otherwise       = do
          loop l mid
          loop mid u
          -- Re-base the range to index 0 so 'merge' sees runs [0, mid - l)
          -- and [mid - l, len).
          merge cmp (unsafeSlice l len src) buf (mid - l)
      where
        len :: Int
        len = u - l

        -- midPoint comes from Data.Vector.Algorithms.Common; presumably it
        -- computes the midpoint of u and l without overflow — confirm there.
        mid :: Int
        mid = midPoint u l
{-# INLINE mergeSortWithBuf #-}
-- | Merges two adjacent sorted runs of @src@ in place.
--
-- @src[0, mid)@ and @src[mid, length src)@ must each already be sorted by
-- @cmp@, and @buf@ must have room for at least @mid@ elements.
--
-- How it works:
--
-- * The lower run is copied into the front of @buf@.
-- * Elements are written back into @src@ from the front, always taking the
--   smaller of the two current heads.
-- * On ties the lower-run element is written first, so the merge keeps
--   equal elements in their original relative order.
--
-- If either run is empty the range is already sorted and nothing is done.
-- Without that check, the initial reads of the two run heads would be out of
-- bounds, because they use 'unsafeRead'.
merge :: (PrimMonad m, MVector v e)
      => Comparison e -> v (PrimState m) e -> v (PrimState m) e
      -> Int -> m ()
merge cmp src buf mid
  | mid <= 0 || mid >= length src = return ()
  | otherwise = do
      unsafeCopy tmp lower
      eTmp <- unsafeRead tmp 0
      eUpp <- unsafeRead upper 0
      loop tmp 0 eTmp upper 0 eUpp 0
  where
    lower = unsafeSlice 0 mid src
    upper = unsafeSlice mid (length src - mid) src
    tmp   = unsafeSlice 0 mid buf

    -- An upper-run element was just written. Two cases:
    -- * The upper run is exhausted: copy the rest of the lower run (held in
    --   tmp, so it cannot overlap src) to the tail of src, and stop.
    -- * Otherwise: fetch the next upper element and keep merging.
    wroteHigh low iLow eLow high iHigh iIns
      | iHigh >= length high =
          unsafeCopy (unsafeSlice iIns (length low - iLow) src)
                     (unsafeSlice iLow (length low - iLow) low)
      | otherwise = do
          eHigh <- unsafeRead high iHigh
          loop low iLow eLow high iHigh eHigh iIns

    -- A lower-run element was just written. The invariant is
    -- iIns == iLow + iHigh. So once the lower run is exhausted (iLow == mid),
    -- the unread upper elements src[mid + iHigh ..] already sit at
    -- src[iIns ..], and nothing is left to move.
    wroteLow low iLow high iHigh eHigh iIns
      | iLow >= length low = return ()
      | otherwise = do
          eLow <- unsafeRead low iLow
          loop low iLow eLow high iHigh eHigh iIns

    -- Write the smaller current head to src[iIns]. Only a strict LT takes
    -- the upper element, which is what keeps ties in lower-run-first order.
    loop !low !iLow !eLow !high !iHigh !eHigh !iIns = case cmp eHigh eLow of
      LT -> do
        unsafeWrite src iIns eHigh
        wroteHigh low iLow eLow high (iHigh + 1) (iIns + 1)
      _ -> do
        unsafeWrite src iIns eLow
        wroteLow low (iLow + 1) high iHigh eHigh (iIns + 1)
{-# INLINE merge #-}
-- | Length below which a range is insertion-sorted instead of being split
-- further by the merge sort.
--
-- 'sortBy' also uses it to decide whether to allocate a merge buffer at all:
-- vectors shorter than this never reach the merge-sort path.
threshold :: Int
threshold = 25
{-# INLINE threshold #-}