From ac9f4f3271547249984a7d789b57d16c1e408753 Mon Sep 17 00:00:00 2001
From: David Feuer <David.Feuer@gmail.com>
Date: Sun, 22 Aug 2021 15:48:16 -0400
Subject: [PATCH] Unzip Map and IntMap more efficiently

The previous `Map` (same for `IntMap` throughout) instance
would first map eagerly over tha `Map`, producing an entire
`Map` full of thunks to apply the `unzipWith` function. Then
it would build two more entire `Map`s full of thunks to select
components of each pair. Depending on inlining and such, the
resulting maps may or may not have contained selector thunks;
if not, they could leak memory. Fix that.

NOTE: This PR is an alternative to #163. This one preserves
the precise laziness properties of the previous implementation.
---
 semialign/semialign.cabal                |  5 +-
 semialign/src/Data/Semialign/Internal.hs | 67 +++++++++++++++++++++++-
 2 files changed, 66 insertions(+), 6 deletions(-)

diff --git a/semialign/semialign.cabal b/semialign/semialign.cabal
index 8a06793..3c0f5df 100644
--- a/semialign/semialign.cabal
+++ b/semialign/semialign.cabal
@@ -79,10 +79,7 @@ library
     , tagged                         >=0.8.6    && <0.9
     , unordered-containers           >=0.2.8.0  && <0.3
     , vector                         >=0.12.0.2 && <0.13
-
-  -- base shims
-  if !impl(ghc >=8.2)
-    build-depends: bifunctors >=5.5.4 && <5.6
+    , bifunctors                     >=5.5.4 && <5.6
 
   if !impl(ghc >=8.0)
     build-depends:
diff --git a/semialign/src/Data/Semialign/Internal.hs b/semialign/src/Data/Semialign/Internal.hs
index d19d223..524c84d 100644
--- a/semialign/src/Data/Semialign/Internal.hs
+++ b/semialign/src/Data/Semialign/Internal.hs
@@ -1,4 +1,5 @@
 {-# LANGUAGE CPP                        #-}
+{-# LANGUAGE DeriveFunctor              #-}
 {-# LANGUAGE FlexibleInstances          #-}
 {-# LANGUAGE FunctionalDependencies     #-}
 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
@@ -19,6 +20,7 @@ import qualified Prelude as Prelude
 
 import Control.Applicative               (ZipList (..), pure, (<$>))
 import Data.Bifunctor                    (Bifunctor (..))
+import Data.Biapplicative                (Biapplicative (..), traverseBia)
 import Data.Functor.Compose              (Compose (..))
 import Data.Functor.Identity             (Identity (..))
 import Data.Functor.Product              (Product (..))
@@ -577,7 +579,47 @@ instance (Ord k) => Align (Map k) where
 instance Ord k => Unalign (Map k) where
     unalign xs = (Map.mapMaybe justHere xs, Map.mapMaybe justThere xs)
 
-instance Ord k => Unzip (Map k) where unzip = unzipDefault
+-- A copy of (,) with a stricter bimap.
+newtype SBPair a b = SBPair { unSBPair :: (a, b) }
+
+instance Bifunctor SBPair where
+  bimap f g (SBPair (a, b)) = SBPair (f a, g b)
+
+instance Biapplicative SBPair where
+  bipure a b = SBPair (a, b)
+  biliftA2 f g (SBPair (a, b)) (SBPair (c, d)) =
+    SBPair (f a c, g b d)
+
+instance Ord k => Unzip (Map k) where
+  -- Map has a strict spine, so we have to build a whole one at
+  -- once. The default instance would first build an entire
+  -- Map filled with thunks, each of which will produce a pair,
+  -- and then build two maps, each filled with thunks to extract
+  -- a value from the pair. We instead build both maps at once,
+  -- each of which will be filled with selector thunks, along
+  -- with thunks (not in any Map) holding the applications of
+  -- `f`.
+  unzipWith f xs = (l, r)
+    where
+      ~(l, r) = unSBPair . traverseBia (SBPair . blah) $ xs
+      blah c = let
+        {-# NOINLINE fc #-} -- make sure the result of f c is shared,
+                            -- and that nothing weird happens to
+                            -- keep us from getting selector thunks.
+        {-# NOINLINE a #-} -- make sure we get selector thunks
+        {-# NOINLINE b #-}
+        fc = f c
+        ~(a, b) = fc
+        in (a, b)
+
+  unzip xs = (l, r)
+    where
+      ~(l, r) = unSBPair . traverseBia (SBPair . blah) $ xs
+      blah ab = let
+        {-# NOINLINE a #-} -- make sure we get selector thunks
+        {-# NOINLINE b #-}
+        ~(a, b) = ab
+        in (a, b)
 
 instance Ord k => Zip (Map k) where
     zipWith = Map.intersectionWith
@@ -601,7 +643,28 @@ instance Align IntMap where
 instance Unalign IntMap where
     unalign xs = (IntMap.mapMaybe justHere xs, IntMap.mapMaybe justThere xs)
 
-instance Unzip IntMap where unzip = unzipDefault
+instance Unzip IntMap where
+  -- See notes at the Map instance
+  unzipWith f xs = (l, r)
+    where
+      ~(l, r) = unSBPair . traverseBia (SBPair . blah) $ xs
+      blah c = let
+        {-# NOINLINE fc #-} -- make sure the result of f c is shared,
+                            -- and that nothing weird happens to
+                            -- keep us from getting selector thunks.
+        {-# NOINLINE a #-} -- make sure we get selector thunks
+        {-# NOINLINE b #-}
+        fc = f c
+        ~(a, b) = fc
+        in (a, b)
+  unzip xs = (l, r)
+    where
+      ~(l, r) = unSBPair . traverseBia (SBPair . blah) $ xs
+      blah ab = let
+        {-# NOINLINE a #-} -- make sure we get selector thunks
+        {-# NOINLINE b #-}
+        ~(a, b) = ab
+        in (a, b)
 
 instance Zip IntMap where
     zipWith = IntMap.intersectionWith