Skip to content

Commit 980c6f6

Browse files
authored
Inline anything (#431)
* Allow inlining of any function, do check dynamically * Add test RuntimeCast.agda relying on advanced inlining * [ #431 ] Eta-expand definition of rTail in EraseType.agda test case * [ #431 ] Comment out partial application examples in Inlining.agda test case
1 parent 41a090d commit 980c6f6

File tree

17 files changed

+167
-48
lines changed

17 files changed

+167
-48
lines changed

src/Agda2Hs/Compile.hs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import Data.IORef
1111
import Data.List ( isPrefixOf, group, sort )
1212

1313
import qualified Data.Map as M
14+
import qualified Data.Set as S
1415

1516
import Agda.Compiler.Backend
1617
import Agda.Compiler.Common ( curIF )
@@ -41,7 +42,8 @@ globalSetup :: Options -> TCM GlobalEnv
4142
globalSetup opts = do
4243
opts <- checkConfig opts
4344
ctMap <- liftIO $ newIORef M.empty
44-
return $ GlobalEnv opts ctMap
45+
ilMap <- liftIO $ newIORef S.empty
46+
return $ GlobalEnv opts ctMap ilMap
4547

4648
initCompileEnv :: GlobalEnv -> TopLevelModuleName -> SpecialRules -> CompileEnv
4749
initCompileEnv genv tlm rewrites = CompileEnv

src/Agda2Hs/Compile/Function.hs

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -394,31 +394,11 @@ checkTransparentPragma def = compileFun False def >>= \case
394394
"A transparent function must have exactly one non-erased argument and return it unchanged."
395395

396396

397-
-- | Ensure a definition can be defined as inline.
397+
-- | Mark a definition as one that should be inlined.
398398
checkInlinePragma :: Definition -> C ()
399-
checkInlinePragma def@Defn{defName = f} = do
400-
let Function{funClauses = cs} = theDef def
401-
case filter (isJust . clauseBody) cs of
402-
[c] ->
403-
unlessM (allowedPats (namedClausePats c)) $ agda2hsErrorM $
404-
"Cannot make function" <+> prettyTCM (defName def) <+> "inlinable." <+>
405-
"Inline functions can only use variable patterns or transparent record constructor patterns."
406-
_ ->
407-
agda2hsErrorM $
408-
"Cannot make function" <+> prettyTCM f <+> "inlinable." <+>
409-
"An inline function must have exactly one clause."
410-
411-
where allowedPat :: DeBruijnPattern -> C Bool
412-
allowedPat VarP{} = pure True
413-
-- only allow matching on (unboxed) record constructors
414-
allowedPat (ConP ch ci cargs) =
415-
isUnboxConstructor (conName ch) >>= \case
416-
Just _ -> allowedPats cargs
417-
Nothing -> pure False
418-
allowedPat _ = pure False
419-
420-
allowedPats :: NAPs -> C Bool
421-
allowedPats pats = allM (allowedPat . dget . dget) pats
399+
checkInlinePragma def@(Defn { defName = q , theDef = df }) = do
400+
let qs = fromMaybe [] $ getMutual_ df
401+
addInlineSymbols $ q : qs
422402

423403
checkCompileToFunctionPragma :: Definition -> String -> C ()
424404
checkCompileToFunctionPragma def s = noCheckNames $ do

src/Agda2Hs/Compile/Term.hs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,8 @@ compileDef f ty args | Just sem <- isSpecialDef f = do
171171
sem ty args
172172

173173
compileDef f ty args =
174-
ifM (isTransparentFunction f) (compileErasedApp ty args) $
175-
ifM (isInlinedFunction f) (compileInlineFunctionApp f ty args) $ do
174+
ifM (isTransparentFunction f) (compileErasedApp ty args) $ do
175+
176176
reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of regular function:" <+> prettyTCM f
177177

178178
let defMod = qnameModule f
@@ -458,14 +458,18 @@ compileTerm ty v = do
458458

459459
v <- instantiate v
460460

461+
toInline <- getInlineSymbols
462+
v <- locallyReduceDefs (OnlyReduceDefs toInline) $ reduce v
463+
461464
let bad s t = agda2hsErrorM $ vcat
462465
[ text "cannot compile" <+> text (s ++ ":")
463466
, nest 2 $ prettyTCM t
464467
]
465468

466469
reduceProjectionLike v >>= \case
467470

468-
Def f es -> do
471+
v@(Def f es) -> do
472+
whenM (isInlinedFunction f) $ bad "inlined function" v
469473
ty <- defType <$> getConstInfo f
470474
compileSpined (compileDef f ty) (Def f) ty es
471475

src/Agda2Hs/Compile/Types.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ data GlobalEnv = GlobalEnv
2929
{ globalOptions :: Options
3030
, compileToMap :: IORef (Map QName QName)
3131
-- ^ names with a compile-to pragma
32+
, inlineSymbols :: IORef (Set QName)
33+
-- ^ names of functions that should be inlined
3234
}
3335

3436
type ModuleEnv = TopLevelModuleName

src/Agda2Hs/Compile/Utils.hs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import Data.List ( isPrefixOf, stripPrefix )
1212
import Data.Maybe ( isJust )
1313
import qualified Data.Map as M
1414
import Data.String ( IsString(..) )
15+
import Data.Set ( Set )
16+
import qualified Data.Set as S
1517

1618
import GHC.Stack (HasCallStack)
1719

@@ -281,8 +283,28 @@ isTupleProjection q =
281283
isTransparentFunction :: QName -> C Bool
282284
isTransparentFunction q = (== TransparentPragma) <$> getPragma q
283285

286+
getInlineSymbols :: C (Set QName)
287+
getInlineSymbols = do
288+
ilSetRef <- asks $ inlineSymbols . globalEnv
289+
liftIO $ readIORef ilSetRef
290+
291+
debugInlineSymbols :: C ()
292+
debugInlineSymbols = do
293+
ilSetRef <- asks $ inlineSymbols . globalEnv
294+
ilSet <- liftIO $ readIORef ilSetRef
295+
reportSDoc "agda2hs.compile.inline" 50 $ text $
296+
show $ map prettyShow $ S.toList ilSet
297+
284298
isInlinedFunction :: QName -> C Bool
285-
isInlinedFunction q = (== InlinePragma) <$> getPragma q
299+
isInlinedFunction q = S.member q <$> getInlineSymbols
300+
301+
addInlineSymbols :: [QName] -> C ()
302+
addInlineSymbols qs = do
303+
reportSDoc "agda2hs.compile.inline" 15 $
304+
"Adding inline rules for" <+> pretty qs
305+
ilSetRef <- asks $ inlineSymbols . globalEnv
306+
liftIO $ modifyIORef ilSetRef $ \s -> foldr S.insert s qs
307+
286308

287309
debugCompileToMap :: C ()
288310
debugCompileToMap = do

test/AllTests.agda

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ import Issue409
101101
import Issue346
102102
import Issue408
103103
import CompileTo
104+
import RuntimeCast
104105

105106
{-# FOREIGN AGDA2HS
106107
import Issue14
@@ -199,4 +200,5 @@ import Issue409
199200
import Issue346
200201
import Issue408
201202
import CompileTo
203+
import RuntimeCast
202204
#-}

test/EraseType.agda

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,6 @@ testCong = singCong (1 +_) testSingleton
2929
{-# COMPILE AGDA2HS testCong #-}
3030

3131
rTail : {@0 x xs} Singleton {a = List Int} (x ∷ xs) Singleton xs
32-
rTail = singTail
32+
rTail ys = singTail ys
3333

3434
{-# COMPILE AGDA2HS rTail #-}

test/Fail/Inline.agda

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,8 @@ tail' : List a → List a
66
tail' (x ∷ xs) = xs
77
tail' [] = []
88
{-# COMPILE AGDA2HS tail' inline #-}
9+
10+
test : List a List a
11+
test = tail'
12+
13+
{-# COMPILE AGDA2HS test #-}

test/Fail/Inline2.agda

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,8 @@ open import Haskell.Prelude
55
tail' : (xs : List a) @0 {{ NonEmpty xs }} List a
66
tail' (x ∷ xs) = xs
77
{-# COMPILE AGDA2HS tail' inline #-}
8+
9+
test : (xs : List a) @0 {{ NonEmpty xs }} List a
10+
test = tail'
11+
12+
{-# COMPILE AGDA2HS test #-}

test/Inlining.agda

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ test2 x y = mapWrap2 _+_ x y
3434
{-# COMPILE AGDA2HS test2 #-}
3535

3636
-- partial application of inline function
37-
test3 : Wrap Int Wrap Int Wrap Int
38-
test3 x = mapWrap2 _+_ x
39-
{-# COMPILE AGDA2HS test3 #-}
40-
41-
test4 : Wrap Int Wrap Int Wrap Int
42-
test4 = mapWrap2 _+_
43-
{-# COMPILE AGDA2HS test4 #-}
37+
-- test3 : Wrap Int → Wrap Int → Wrap Int
38+
-- test3 x = mapWrap2 _+_ x
39+
-- {-# COMPILE AGDA2HS test3 #-}
40+
--
41+
-- test4 : Wrap Int → Wrap Int → Wrap Int
42+
-- test4 = mapWrap2 _+_
43+
-- {-# COMPILE AGDA2HS test4 #-}

0 commit comments

Comments
 (0)