diff --git a/containers-tests/benchmarks/Map.hs b/containers-tests/benchmarks/Map.hs index b6f433d4b..736a8e3b4 100644 --- a/containers-tests/benchmarks/Map.hs +++ b/containers-tests/benchmarks/Map.hs @@ -18,6 +18,7 @@ import Data.Coerce import Data.Tuple.Solo (Solo (MkSolo), getSolo) import System.Random (StdGen, mkStdGen, random, randoms) import Prelude hiding (lookup) +import Utils.Containers.Internal.Strict (StrictPair(..)) import Utils.Fold (foldBenchmarks, foldWithKeyBenchmarks) import Utils.Random (shuffle) @@ -26,9 +27,10 @@ main = do let m = M.fromList elems :: M.Map Int Int m_even = M.fromList elems_even :: M.Map Int Int m_odd = M.fromList elems_odd :: M.Map Int Int + s_odd_keys = M.keysSet m_odd :: Set.Set Int s_random = Set.fromList keys_random :: Set.Set Int evaluate $ rnf [m, m_even, m_odd] - evaluate $ rnf [s_random] + evaluate $ rnf [s_random, s_odd_keys] evaluate $ rnf [elems_distinct_asc, elems_distinct_desc, elems_asc, elems_desc] evaluate $ rnf [keys_random] @@ -139,6 +141,7 @@ main = do , bench "Lazy.fromSetA inner" $ whnf (getSolo . M.fromSetA (MkSolo . pred)) s_random , bench "Strict.fromSetA inner" $ whnf (getSolo . MS.fromSetA (MkSolo . pred)) s_random , bench "minView" $ whnf (\m' -> case M.minViewWithKey m' of {Nothing -> 0; Just ((k,v),m'') -> k+v+M.size m''}) (M.fromAscList $ zip [1..10::Int] [100..110::Int]) + , bench "eq" $ whnf (\m' -> m' == m') m -- worst case, compares everything , bench "compare" $ whnf (\m' -> compare m' m') m -- worst case, compares everything , bgroup "folds" $ foldBenchmarks M.foldr M.foldl M.foldr' M.foldl' foldMap m @@ -148,6 +151,10 @@ main = do , bench "mapKeys:desc" $ whnf (M.mapKeys (negate . (+1))) m , bench "mapKeysWith:asc" $ whnf (M.mapKeysWith (+) (`div` 2)) m , bench "mapKeysWith:desc" $ whnf (M.mapKeysWith (+) (negate . (`div` 2))) m + + , bench "restrictKeys" $ whnf (M.restrictKeys m) s_odd_keys + , bench "withoutKeys" $ whnf (M.withoutKeys m) s_odd_keys + , bench "partitionKeys" $ whnf (M.partitionKeys m) s_odd_keys ] where bound = 2^14 diff --git a/containers-tests/containers-tests.cabal b/containers-tests/containers-tests.cabal index e71780c7c..b6a9f92cc 100644 --- a/containers-tests/containers-tests.cabal +++ b/containers-tests/containers-tests.cabal @@ -121,12 +121,12 @@ library Data.Tree Utils.Containers.Internal.BitQueue Utils.Containers.Internal.BitUtil + Utils.Containers.Internal.Strict other-modules: Utils.Containers.Internal.Prelude Utils.Containers.Internal.PtrEquality Utils.Containers.Internal.State - Utils.Containers.Internal.Strict Utils.Containers.Internal.EqOrdUtil if impl(ghc >= 8.6) diff --git a/containers-tests/tests/map-properties.hs b/containers-tests/tests/map-properties.hs index c42887ef0..b04a9ff71 100644 --- a/containers-tests/tests/map-properties.hs +++ b/containers-tests/tests/map-properties.hs @@ -181,6 +181,7 @@ main = defaultMain $ testGroup "map-properties" , testProperty "withoutKeys" prop_withoutKeys , testProperty "intersection" prop_intersection , testProperty "restrictKeys" prop_restrictKeys + , testProperty "partitionKeys" prop_partitionKeys , testProperty "intersection model" prop_intersectionModel , testProperty "intersectionWith" prop_intersectionWith , testProperty "intersectionWithModel" prop_intersectionWithModel @@ -1168,6 +1169,15 @@ prop_withoutKeys m s0 = valid reduced .&&. (m `withoutKeys` s === filterWithKey s = keysSet s0 reduced = withoutKeys m s +prop_partitionKeys :: IMap -> IMap -> Property +prop_partitionKeys m s0 = + valid with .&&. + valid without .&&. + (m `partitionKeys` s === (m `restrictKeys` s, m `withoutKeys` s)) + where + s = keysSet s0 + (with, without) = partitionKeys m s + prop_intersection :: IMap -> IMap -> Bool prop_intersection t1 t2 = valid (intersection t1 t2) diff --git a/containers/src/Data/Map/Internal.hs b/containers/src/Data/Map/Internal.hs index 6852092e1..3afb7f829 100644 --- a/containers/src/Data/Map/Internal.hs +++ b/containers/src/Data/Map/Internal.hs @@ -7,6 +7,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE Trustworthy #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE ScopedTypeVariables #-} #define USE_MAGIC_PROXY 1 #endif @@ -300,6 +301,7 @@ module Data.Map.Internal ( , restrictKeys , withoutKeys + , partitionKeys , partition , partitionWithKey @@ -1933,6 +1935,50 @@ withoutKeys m (Set.Bin _ k ls rs) = case splitMember k m of !rm' = withoutKeys rm rs {-# INLINABLE withoutKeys #-} +-- | \(O\bigl(m \log\bigl(\frac{n}{m}+1\bigr)\bigr), \; 0 < m \leq n\). Partition the map according to a set. +-- The first map contains the input 'Map' restricted to those keys found in the 'Set', +-- the second map contains the input 'Map' without all keys in the 'Set'. +-- This is equivalent to using 'restrictKeys' and 'withoutKeys' together but is more efficient. +-- +-- @ +-- m \`partitionKeys\` s = (m ``restrictKeys`` s, m ``withoutKeys`` s) +-- @ +-- +-- @since FIXME +partitionKeys :: forall k a. Ord k => Map k a -> Set k -> (Map k a, Map k a) +partitionKeys xs ys = + case go xs ys of + xs' :*: ys' -> (xs', ys') + where + go :: Map k a -> Set k -> StrictPair (Map k a) (Map k a) + go Tip _ = Tip :*: Tip + go m Set.Tip = Tip :*: m + go m@(Bin _ k x lm rm) s@Set.Bin{} = + case b of + True -> with :*: without + where + with = + if lmWith `ptrEq` lm && rmWith `ptrEq` rm + then m + else link k x lmWith rmWith + without = + link2 lmWithout rmWithout + False -> with :*: without + where + with = link2 lmWith rmWith + without = + if lmWithout `ptrEq` lm && rmWithout `ptrEq` rm + then m + else link k x lmWithout rmWithout + where + !(lmWith :*: lmWithout) = go lm ls' + !(rmWith :*: rmWithout) = go rm rs' + + !(!ls', b, !rs') = Set.splitMember k s +#if __GLASGOW_HASKELL__ +{-# INLINABLE partitionKeys #-} +#endif + -- | \(O(n+m)\). Difference with a combining function. -- When two equal keys are -- encountered, the combining function is applied to the values of these keys. diff --git a/containers/src/Data/Map/Lazy.hs b/containers/src/Data/Map/Lazy.hs index d7b28e360..7542c8f07 100644 --- a/containers/src/Data/Map/Lazy.hs +++ b/containers/src/Data/Map/Lazy.hs @@ -247,6 +247,7 @@ module Data.Map.Lazy ( , filterWithKey , restrictKeys , withoutKeys + , partitionKeys , partition , partitionWithKey , takeWhileAntitone diff --git a/containers/src/Data/Map/Strict.hs b/containers/src/Data/Map/Strict.hs index 94476db69..8da370ff6 100644 --- a/containers/src/Data/Map/Strict.hs +++ b/containers/src/Data/Map/Strict.hs @@ -261,6 +261,7 @@ module Data.Map.Strict , filterWithKey , restrictKeys , withoutKeys + , partitionKeys , partition , partitionWithKey diff --git a/containers/src/Data/Map/Strict/Internal.hs b/containers/src/Data/Map/Strict/Internal.hs index 957c4bc4a..94df5760c 100644 --- a/containers/src/Data/Map/Strict/Internal.hs +++ b/containers/src/Data/Map/Strict/Internal.hs @@ -242,6 +242,7 @@ module Data.Map.Strict.Internal , filterWithKey , restrictKeys , withoutKeys + , partitionKeys , partition , partitionWithKey , takeWhileAntitone @@ -409,7 +410,8 @@ import Data.Map.Internal , toDescList , union , unions - , withoutKeys ) + , withoutKeys + , partitionKeys ) import Data.Map.Internal.Debug (valid)