Skip to content

Commit 4af12df

Browse files
authored
More efficient Eq, Ord for Set, Map (#1017)
* Add tests and benchmarks. * Implement Eq and Ord using foldMap + iterator. Effect on benchmark times, using GHC 9.6.3: Set Int, eq: -61%; Set Int, compare: -53%; Map Int Int, eq: -68%; Map Int Int, compare: -76%.
1 parent 549d22b commit 4af12df

File tree

9 files changed

+213
-37
lines changed

9 files changed

+213
-37
lines changed

containers-tests/benchmarks/Map.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ main = do
9595
, bench "fromDistinctDescList" $ whnf M.fromDistinctDescList elems_rev
9696
, bench "fromDistinctDescList:fusion" $ whnf (\n -> M.fromDistinctDescList [(i,i) | i <- [n,n-1..1]]) bound
9797
, bench "minView" $ whnf (\m' -> case M.minViewWithKey m' of {Nothing -> 0; Just ((k,v),m'') -> k+v+M.size m''}) (M.fromAscList $ zip [1..10::Int] [100..110::Int])
98+
, bench "eq" $ whnf (\m' -> m' == m') m -- worst case, compares everything
99+
, bench "compare" $ whnf (\m' -> compare m' m') m -- worst case, compares everything
98100
]
99101
where
100102
bound = 2^12

containers-tests/benchmarks/Set.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ main = do
5555
, bench "member.powerSet (16)" $ whnf (\ s -> all (flip S.member s) s) (S.powerSet (S.fromList [1..16]))
5656
, bench "member.powerSet (17)" $ whnf (\ s -> all (flip S.member s) s) (S.powerSet (S.fromList [1..17]))
5757
, bench "member.powerSet (18)" $ whnf (\ s -> all (flip S.member s) s) (S.powerSet (S.fromList [1..18]))
58+
, bench "eq" $ whnf (\s' -> s' == s') s -- worst case, compares everything
59+
, bench "compare" $ whnf (\s' -> compare s' s') s -- worst case, compares everything
5860
]
5961
where
6062
bound = 2^12

containers-tests/containers-tests.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ library
124124
Utils.Containers.Internal.PtrEquality
125125
Utils.Containers.Internal.State
126126
Utils.Containers.Internal.StrictMaybe
127+
Utils.Containers.Internal.EqOrdUtil
127128

128129
if impl(ghc)
129130
other-modules:

containers-tests/tests/map-properties.hs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import Test.Tasty
3333
import Test.Tasty.HUnit
3434
import Test.Tasty.QuickCheck
3535
import Test.QuickCheck.Function (apply)
36-
import Test.QuickCheck.Poly (A, B)
36+
import Test.QuickCheck.Poly (A, B, OrdA)
3737
import Control.Arrow (first)
3838

3939
default (Int)
@@ -250,6 +250,8 @@ main = defaultMain $ testGroup "map-properties"
250250
, testProperty "splitAt" prop_splitAt
251251
, testProperty "lookupMin" prop_lookupMin
252252
, testProperty "lookupMax" prop_lookupMax
253+
, testProperty "eq" prop_eq
254+
, testProperty "compare" prop_compare
253255
]
254256

255257
{--------------------------------------------------------------------
@@ -1636,3 +1638,9 @@ prop_fromArgSet :: [(Int, Int)] -> Bool
16361638
prop_fromArgSet ys =
16371639
let xs = List.nubBy ((==) `on` fst) ys
16381640
in fromArgSet (Set.fromList $ List.map (uncurry Arg) xs) == fromList xs
1641+
1642+
prop_eq :: Map Int A -> Map Int A -> Property
1643+
prop_eq m1 m2 = (m1 == m2) === (toList m1 == toList m2)
1644+
1645+
prop_compare :: Map Int OrdA -> Map Int OrdA -> Property
1646+
prop_compare m1 m2 = compare m1 m2 === compare (toList m1) (toList m2)

containers-tests/tests/set-properties.hs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ main = defaultMain $ testGroup "set-properties"
110110
, testProperty "strict foldr" prop_strictFoldr'
111111
, testProperty "strict foldl" prop_strictFoldl'
112112
#endif
113+
, testProperty "eq" prop_eq
114+
, testProperty "compare" prop_compare
113115
]
114116

115117
-- A type with a peculiar Eq instance designed to make sure keys
@@ -730,3 +732,9 @@ prop_strictFoldr' m = whnfHasNoThunks (foldr' (:) [] m)
730732
prop_strictFoldl' :: Set Int -> Property
731733
prop_strictFoldl' m = whnfHasNoThunks (foldl' (flip (:)) [] m)
732734
#endif
735+
736+
prop_eq :: Set Int -> Set Int -> Property
737+
prop_eq s1 s2 = (s1 == s2) === (toList s1 == toList s2)
738+
739+
prop_compare :: Set Int -> Set Int -> Property
740+
prop_compare s1 s2 = compare s1 s2 === compare (toList s1) (toList s2)

containers/containers.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ Library
8080
Utils.Containers.Internal.StrictMaybe
8181
Utils.Containers.Internal.PtrEquality
8282
Utils.Containers.Internal.Coercions
83+
Utils.Containers.Internal.EqOrdUtil
8384
if impl(ghc)
8485
other-modules:
8586
Utils.Containers.Internal.TypeError

containers/src/Data/Map/Internal.hs

Lines changed: 75 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ import Utils.Containers.Internal.PtrEquality (ptrEq)
401401
import Utils.Containers.Internal.StrictPair
402402
import Utils.Containers.Internal.StrictMaybe
403403
import Utils.Containers.Internal.BitQueue
404+
import Utils.Containers.Internal.EqOrdUtil (EqM(..), OrdM(..))
404405
#ifdef DEFINE_ALTERF_FALLBACK
405406
import Utils.Containers.Internal.BitUtil (wordSize)
406407
#endif
@@ -4118,6 +4119,31 @@ deleteFindMax t = case maxViewWithKey t of
41184119
Nothing -> (error "Map.deleteFindMax: can not return the maximal element of an empty map", Tip)
41194120
Just res -> res
41204121

4122+
{--------------------------------------------------------------------
4123+
Iterator
4124+
--------------------------------------------------------------------}
4125+
4126+
-- See Note [Iterator] in Data.Set.Internal
4127+
4128+
iterDown :: Map k a -> Stack k a -> Stack k a
4129+
iterDown (Bin _ kx x l r) stk = iterDown l (Push kx x r stk)
4130+
iterDown Tip stk = stk
4131+
4132+
-- Create an iterator from a Map, starting at the smallest key.
4133+
iterator :: Map k a -> Stack k a
4134+
iterator m = iterDown m Nada
4135+
4136+
-- Get the next key-value and the remaining iterator.
4137+
iterNext :: Stack k a -> Maybe (StrictPair (KeyValue k a) (Stack k a))
4138+
iterNext (Push kx x r stk) = Just $! KeyValue kx x :*: iterDown r stk
4139+
iterNext Nada = Nothing
4140+
{-# INLINE iterNext #-}
4141+
4142+
-- Whether there are no more key-values in the iterator.
4143+
iterNull :: Stack k a -> Bool
4144+
iterNull (Push _ _ _ _) = False
4145+
iterNull Nada = True
4146+
41214147
{--------------------------------------------------------------------
41224148
[balance l x r] balances two trees with value x.
41234149
The sizes of the trees should balance after decreasing the
@@ -4284,41 +4310,69 @@ bin k x l r
42844310

42854311

42864312
{--------------------------------------------------------------------
4287-
Eq converts the tree to a list. In a lazy setting, this
4288-
actually seems one of the faster methods to compare two trees
4289-
and it is certainly the simplest :-)
4313+
Eq
42904314
--------------------------------------------------------------------}
4315+
42914316
instance (Eq k,Eq a) => Eq (Map k a) where
4292-
t1 == t2 = (size t1 == size t2) && (toAscList t1 == toAscList t2)
4317+
m1 == m2 = liftEq2 (==) (==) m1 m2
4318+
{-# INLINABLE (==) #-}
42934319

4294-
{--------------------------------------------------------------------
4295-
Ord
4296-
--------------------------------------------------------------------}
4320+
-- | @since 0.5.9
4321+
instance Eq k => Eq1 (Map k) where
4322+
liftEq = liftEq2 (==)
4323+
{-# INLINE liftEq #-}
42974324

4298-
instance (Ord k, Ord v) => Ord (Map k v) where
4299-
compare m1 m2 = compare (toAscList m1) (toAscList m2)
4325+
-- | @since 0.5.9
4326+
instance Eq2 Map where
4327+
liftEq2 keq eq m1 m2 = size m1 == size m2 && sameSizeLiftEq2 keq eq m1 m2
4328+
{-# INLINE liftEq2 #-}
4329+
4330+
-- Assumes the maps are of equal size, which lets it skip the final check that the second map's iterator is exhausted.
4331+
sameSizeLiftEq2
4332+
:: (ka -> kb -> Bool) -> (a -> b -> Bool) -> Map ka a -> Map kb b -> Bool
4333+
sameSizeLiftEq2 keq eq m1 m2 =
4334+
case runEqM (foldMapWithKey f m1) (iterator m2) of e :*: _ -> e
4335+
where
4336+
f kx x = EqM $ \it -> case iterNext it of
4337+
Nothing -> False :*: it
4338+
Just (KeyValue ky y :*: it') -> (keq kx ky && eq x y) :*: it'
4339+
{-# INLINE sameSizeLiftEq2 #-}
43004340

43014341
{--------------------------------------------------------------------
4302-
Lifted instances
4342+
Ord
43034343
--------------------------------------------------------------------}
43044344

4305-
-- | @since 0.5.9
4306-
instance Eq2 Map where
4307-
liftEq2 eqk eqv m n =
4308-
size m == size n && liftEq (liftEq2 eqk eqv) (toList m) (toList n)
4345+
instance (Ord k, Ord v) => Ord (Map k v) where
4346+
compare m1 m2 = liftCmp2 compare compare m1 m2
4347+
{-# INLINABLE compare #-}
43094348

43104349
-- | @since 0.5.9
4311-
instance Eq k => Eq1 (Map k) where
4312-
liftEq = liftEq2 (==)
4350+
instance Ord k => Ord1 (Map k) where
4351+
liftCompare = liftCmp2 compare
4352+
{-# INLINE liftCompare #-}
43134353

43144354
-- | @since 0.5.9
43154355
instance Ord2 Map where
4316-
liftCompare2 cmpk cmpv m n =
4317-
liftCompare (liftCompare2 cmpk cmpv) (toList m) (toList n)
4356+
liftCompare2 = liftCmp2
4357+
{-# INLINE liftCompare2 #-}
4358+
4359+
liftCmp2
4360+
:: (ka -> kb -> Ordering)
4361+
-> (a -> b -> Ordering)
4362+
-> Map ka a
4363+
-> Map kb b
4364+
-> Ordering
4365+
liftCmp2 kcmp cmp m1 m2 = case runOrdM (foldMapWithKey f m1) (iterator m2) of
4366+
o :*: it -> o <> if iterNull it then EQ else LT
4367+
where
4368+
f kx x = OrdM $ \it -> case iterNext it of
4369+
Nothing -> GT :*: it
4370+
Just (KeyValue ky y :*: it') -> (kcmp kx ky <> cmp x y) :*: it'
4371+
{-# INLINE liftCmp2 #-}
43184372

4319-
-- | @since 0.5.9
4320-
instance Ord k => Ord1 (Map k) where
4321-
liftCompare = liftCompare2 compare
4373+
{--------------------------------------------------------------------
4374+
Lifted instances
4375+
--------------------------------------------------------------------}
43224376

43234377
-- | @since 0.5.9
43244378
instance Show2 Map where

containers/src/Data/Set/Internal.hs

Lines changed: 77 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ import Control.DeepSeq (NFData(rnf))
252252

253253
import Utils.Containers.Internal.StrictPair
254254
import Utils.Containers.Internal.PtrEquality
255+
import Utils.Containers.Internal.EqOrdUtil (EqM(..), OrdM(..))
255256

256257
#if __GLASGOW_HASKELL__
257258
import GHC.Exts ( build, lazy )
@@ -1272,19 +1273,90 @@ foldl'Stack f = go
12721273
{-# INLINE foldl'Stack #-}
12731274

12741275
{--------------------------------------------------------------------
1275-
Eq converts the set to a list. In a lazy setting, this
1276-
actually seems one of the faster methods to compare two trees
1277-
and it is certainly the simplest :-)
1276+
Iterator
12781277
--------------------------------------------------------------------}
1278+
1279+
-- Note [Iterator]
1280+
-- ~~~~~~~~~~~~~~~
1281+
-- Iteration, using a Stack as an iterator, is an efficient way to consume a Set
1282+
-- one element at a time. Alternatively, this may be done by toAscList. toAscList
1283+
-- when consumed via List.foldr will rewrite to Set.foldr (thanks to rewrite
1284+
-- rules), which is quite efficient. However, sometimes that is not possible,
1285+
-- such as in the second arg of '==' or 'compare', where manifesting the list
1286+
-- cons cells is unavoidable and makes things slower.
1287+
--
1288+
-- Concretely, compare on Set Int using toAscList takes ~21% more time compared
1289+
-- to using Iterator, on GHC 9.6.3.
1290+
--
1291+
-- The heart of this implementation is the `iterDown` function. It walks down
1292+
-- the left spine of the tree, pushing the value and right child on the stack,
1293+
-- until a Tip is reached. The next value is now at the top of the stack. To get
1294+
-- to the value after that, `iterDown` is called again with the right child and
1295+
-- the remaining stack.
1296+
1297+
iterDown :: Set a -> Stack a -> Stack a
1298+
iterDown (Bin _ x l r) stk = iterDown l (Push x r stk)
1299+
iterDown Tip stk = stk
1300+
1301+
-- Create an iterator from a Set, starting at the smallest element.
1302+
iterator :: Set a -> Stack a
1303+
iterator s = iterDown s Nada
1304+
1305+
-- Get the next element and the remaining iterator.
1306+
iterNext :: Stack a -> Maybe (StrictPair a (Stack a))
1307+
iterNext (Push x r stk) = Just $! x :*: iterDown r stk
1308+
iterNext Nada = Nothing
1309+
{-# INLINE iterNext #-}
1310+
1311+
-- Whether there are no more elements in the iterator.
1312+
iterNull :: Stack a -> Bool
1313+
iterNull (Push _ _ _) = False
1314+
iterNull Nada = True
1315+
1316+
{--------------------------------------------------------------------
1317+
Eq
1318+
--------------------------------------------------------------------}
1319+
12791320
instance Eq a => Eq (Set a) where
1280-
t1 == t2 = (size t1 == size t2) && (toAscList t1 == toAscList t2)
1321+
s1 == s2 = liftEq (==) s1 s2
1322+
{-# INLINABLE (==) #-}
1323+
1324+
-- | @since 0.5.9
1325+
instance Eq1 Set where
1326+
liftEq eq s1 s2 = size s1 == size s2 && sameSizeLiftEq eq s1 s2
1327+
{-# INLINE liftEq #-}
1328+
1329+
-- Assumes the sets are of equal size, which lets it skip the final check that the second set's iterator is exhausted.
1330+
sameSizeLiftEq :: (a -> b -> Bool) -> Set a -> Set b -> Bool
1331+
sameSizeLiftEq eq s1 s2 =
1332+
case runEqM (foldMap f s1) (iterator s2) of e :*: _ -> e
1333+
where
1334+
f x = EqM $ \it -> case iterNext it of
1335+
Nothing -> False :*: it
1336+
Just (y :*: it') -> eq x y :*: it'
1337+
{-# INLINE sameSizeLiftEq #-}
12811338

12821339
{--------------------------------------------------------------------
12831340
Ord
12841341
--------------------------------------------------------------------}
12851342

12861343
instance Ord a => Ord (Set a) where
1287-
compare s1 s2 = compare (toAscList s1) (toAscList s2)
1344+
compare s1 s2 = liftCmp compare s1 s2
1345+
{-# INLINABLE compare #-}
1346+
1347+
-- | @since 0.5.9
1348+
instance Ord1 Set where
1349+
liftCompare = liftCmp
1350+
{-# INLINE liftCompare #-}
1351+
1352+
liftCmp :: (a -> b -> Ordering) -> Set a -> Set b -> Ordering
1353+
liftCmp cmp s1 s2 = case runOrdM (foldMap f s1) (iterator s2) of
1354+
o :*: it -> o <> if iterNull it then EQ else LT
1355+
where
1356+
f x = OrdM $ \it -> case iterNext it of
1357+
Nothing -> GT :*: it
1358+
Just (y :*: it') -> cmp x y :*: it'
1359+
{-# INLINE liftCmp #-}
12881360

12891361
{--------------------------------------------------------------------
12901362
Show
@@ -1293,16 +1365,6 @@ instance Show a => Show (Set a) where
12931365
showsPrec p xs = showParen (p > 10) $
12941366
showString "fromList " . shows (toList xs)
12951367

1296-
-- | @since 0.5.9
1297-
instance Eq1 Set where
1298-
liftEq eq m n =
1299-
size m == size n && liftEq eq (toList m) (toList n)
1300-
1301-
-- | @since 0.5.9
1302-
instance Ord1 Set where
1303-
liftCompare cmp m n =
1304-
liftCompare cmp (toList m) (toList n)
1305-
13061368
-- | @since 0.5.9
13071369
instance Show1 Set where
13081370
liftShowsPrec sp sl d m =
containers/src/Utils/Containers/Internal/EqOrdUtil.hs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
{-# LANGUAGE CPP #-}
2+
module Utils.Containers.Internal.EqOrdUtil
3+
( EqM(..)
4+
, OrdM(..)
5+
) where
6+
7+
#if !MIN_VERSION_base(4,11,0)
8+
import Data.Semigroup (Semigroup(..))
9+
#endif
10+
import Utils.Containers.Internal.StrictPair
11+
12+
newtype EqM a = EqM { runEqM :: a -> StrictPair Bool a }
13+
14+
-- | Composes left-to-right, short-circuits on False
15+
instance Semigroup (EqM a) where
16+
f <> g = EqM $ \x -> case runEqM f x of
17+
r@(e :*: x') -> if e then runEqM g x' else r
18+
19+
instance Monoid (EqM a) where
20+
mempty = EqM (True :*:)
21+
#if !MIN_VERSION_base(4,11,0)
22+
mappend = (<>)
23+
#endif
24+
25+
newtype OrdM a = OrdM { runOrdM :: a -> StrictPair Ordering a }
26+
27+
-- | Composes left-to-right, short-circuits on non-EQ
28+
instance Semigroup (OrdM a) where
29+
f <> g = OrdM $ \x -> case runOrdM f x of
30+
r@(o :*: x') -> case o of
31+
EQ -> runOrdM g x'
32+
_ -> r
33+
34+
instance Monoid (OrdM a) where
35+
mempty = OrdM (EQ :*:)
36+
#if !MIN_VERSION_base(4,11,0)
37+
mappend = (<>)
38+
#endif

0 commit comments

Comments
 (0)