Not fusing unless monadic #438

Open
WinstonHartnett opened this issue Jun 23, 2022 · 2 comments

WinstonHartnett commented Jun 23, 2022

From this issue on GHC:

The StackOverflow question "Is there any way to inline a recursive function" includes roughly the following trick by Matthew Pickering to inline the recursive function oldNTimes shown below:

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications #-}

module Main (main) where

import GHC.TypeLits
import qualified Data.Vector.Unboxed as V

{-# INLINE incAll #-}
incAll :: V.Vector Int -> V.Vector Int
incAll = V.map (+ 1)

-- Old definition: recursion on a runtime Int
oldNTimes :: Int -> (a -> a) -> a -> a
oldNTimes 0 f x = x
oldNTimes n f x = f (oldNTimes (n-1) f x)

-- New definition: recursion on a type-level Nat, unrolled by instance resolution
class Unroll (n :: Nat) where
  nTimes :: (a -> a) -> a -> a

instance Unroll 0 where
  nTimes f x = x

instance {-# OVERLAPPABLE #-} Unroll (p - 1) => Unroll p where
  nTimes f x = f (nTimes @(p - 1) f x)

main :: IO ()
main = do
  let size = 100000000 :: Int
  let array = V.replicate size 0 :: V.Vector Int
  print $ V.sum (nTimes @64 incAll array)
  -- print $ V.sum (oldNTimes 64 incAll array)  -- swap in to time the old definition
```
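As a quick sanity check of what the trick does, here is a minimal, base-only sketch that puts the two definitions side by side and compares them on a pure function. `agreeAt` is a hypothetical helper introduced only for this check; it reads `n` back with `natVal` so both definitions iterate the same number of times. The block has no `main`, so load it in GHCi:

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits

-- Value-level recursion: the iteration count is a runtime Int.
oldNTimes :: Int -> (a -> a) -> a -> a
oldNTimes 0 _ x = x
oldNTimes n f x = f (oldNTimes (n - 1) f x)

-- Type-level recursion: each instance step peels one off the Nat.
class Unroll (n :: Nat) where
  nTimes :: (a -> a) -> a -> a

instance Unroll 0 where
  nTimes _ x = x

instance {-# OVERLAPPABLE #-} Unroll (p - 1) => Unroll p where
  nTimes f x = f (nTimes @(p - 1) f x)

-- Hypothetical helper: True when both definitions agree for the type-level n.
agreeAt :: forall n. (KnownNat n, Unroll n) => Bool
agreeAt =
  nTimes @n (* 2) (1 :: Int)
    == oldNTimes (fromIntegral (natVal (Proxy @n))) (* 2) 1
```

For example, `nTimes @5 (* 2) (1 :: Int)` and `oldNTimes 5 (* 2) 1` both evaluate to 32, and `agreeAt @5` is `True`. The two definitions compute the same thing; the difference the issue is about lies only in how GHC compiles them, since `nTimes @64` gives GHC a non-recursive chain of instance calls it can inline.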

On GHC 8.2.2, nTimes takes 38.1 ms compared to 25.5 s for oldNTimes. On 9.2.2, however, this doesn't fuse: nTimes and oldNTimes both run in 4.3 s (113x slower), even though GHC is inlining the recursive calls.

Oddly, the lifted (monadic) versions of incAll and nTimes below still run in 38 ms on 9.2.2:

```haskell
-- ... (LANGUAGE pragmas, module header and imports as in the first example)

-- Lifted version of incAll
{-# INLINE incAllM #-}
incAllM :: Monad m => V.Vector Int -> m (V.Vector Int)
incAllM = pure . V.map (+ 1)

-- Lifted version of Unroll: each step binds the previous result
class UnrollM (n :: Nat) where
  nTimesM :: Monad m => (a -> m a) -> a -> m a

instance UnrollM 0 where
  nTimesM f x = pure x

instance {-# OVERLAPPABLE #-} UnrollM (p - 1) => UnrollM p where
  nTimesM f x = f =<< nTimesM @(p - 1) f x

main :: IO ()
main = do
  let size = 100000000 :: Int
  let array = V.replicate size 0 :: V.Vector Int
  print . V.sum =<< nTimesM @64 incAllM array
```
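The lifted variant can be checked the same way. This is a minimal, base-only sketch (again without `main`, so load it in GHCi). `nTimesViaIdentity` and `step` are hypothetical names used only for illustration: the first runs `nTimesM` in `Identity` to recover a plain n-fold iteration, and the second shows, in `Maybe`, that the innermost application of `f` runs first:

```haskell
{-# LANGUAGE DataKinds, KindSignatures, ScopedTypeVariables, TypeOperators,
             FlexibleInstances, FlexibleContexts, UndecidableInstances,
             AllowAmbiguousTypes, TypeApplications #-}

import Data.Functor.Identity (Identity (..))
import GHC.TypeLits

class UnrollM (n :: Nat) where
  nTimesM :: Monad m => (a -> m a) -> a -> m a

instance UnrollM 0 where
  nTimesM _ x = pure x

instance {-# OVERLAPPABLE #-} UnrollM (p - 1) => UnrollM p where
  nTimesM f x = f =<< nTimesM @(p - 1) f x

-- Hypothetical helper: a pure n-fold iteration recovered by running in Identity.
nTimesViaIdentity :: forall n a. UnrollM n => (a -> a) -> a -> a
nTimesViaIdentity f = runIdentity . nTimesM @n (Identity . f)

-- Hypothetical step function: succeeds while x < 2, then fails.
step :: Int -> Maybe Int
step x = if x < 2 then Just (x + 1) else Nothing
```

Here `nTimesViaIdentity @3 (+ 1) (0 :: Int)` gives 3, `nTimesM @2 step 0` gives `Just 2`, and `nTimesM @3 step 0` gives `Nothing` because the third step fails.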

Reproduction project here

WinstonHartnett commented Jun 23, 2022

Seems related to #416.

WinstonHartnett commented:

Actually, incAll is being inlined before specialization because it's a 0-arity binding (see the GHC issue above). Eta-expanding it to `incAll x = V.map (+1) x` restores performance.
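Applied to the original program, the workaround might look like the sketch below. It needs the vector package, as the original reproduction does. `sumAfter64` is a hypothetical helper that takes the vector size as an argument so the result can be checked on a small vector. The sketch only confirms that the result is unchanged (each element is incremented 64 times); the timings are the ones reported above, not something this block measures, and reproducing them would mean compiling with optimizations and the original size of 100000000.

```haskell
{-# LANGUAGE DataKinds, KindSignatures, ScopedTypeVariables, TypeOperators,
             FlexibleInstances, FlexibleContexts, UndecidableInstances,
             AllowAmbiguousTypes, TypeApplications #-}

import qualified Data.Vector.Unboxed as V
import GHC.TypeLits

class Unroll (n :: Nat) where
  nTimes :: (a -> a) -> a -> a

instance Unroll 0 where
  nTimes _ x = x

instance {-# OVERLAPPABLE #-} Unroll (p - 1) => Unroll p where
  nTimes f x = f (nTimes @(p - 1) f x)

-- Eta-expanded: incAll now has arity 1 instead of being a 0-arity binding.
{-# INLINE incAll #-}
incAll :: V.Vector Int -> V.Vector Int
incAll xs = V.map (+ 1) xs

-- Hypothetical helper: sum after applying incAll 64 times to a zero vector.
sumAfter64 :: Int -> Int
sumAfter64 size = V.sum (nTimes @64 incAll (V.replicate size 0))
```

For instance, `sumAfter64 10` is 640 (ten elements, each 64).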
