Skip to content

Commit

Permalink
Hoist Hof out of PrimOp and remove Lam case from generic op.
Browse files Browse the repository at this point in the history
This takes advantage of things becoming more first-order to reduce boilerplate.
  • Loading branch information
dougalm committed May 10, 2024
1 parent 69b31a8 commit cf06eb5
Show file tree
Hide file tree
Showing 14 changed files with 112 additions and 117 deletions.
6 changes: 3 additions & 3 deletions src/lib/CheapReduction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ visitAlt (Abs b body) = do
traverseOpTerm
:: (GenericOp e, Visitor m r i o, OpConst e r ~ OpConst e r)
=> e r i -> m (e r o)
traverseOpTerm e = traverseOp e visitGeneric visitGeneric visitGeneric
traverseOpTerm e = traverseOp e visitGeneric visitGeneric

visitTypeDefault
:: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m)
Expand Down Expand Up @@ -397,6 +397,7 @@ instance IRRep r => VisitGeneric (Expr r) r where
ApplyMethod et m i xs -> ApplyMethod <$> visitGeneric et <*> visitGeneric m <*> pure i <*> mapM visitGeneric xs
Project t i x -> Project <$> visitGeneric t <*> pure i <*> visitGeneric x
Unwrap t x -> Unwrap <$> visitGeneric t <*> visitGeneric x
Hof op -> Hof <$> visitGeneric op

instance IRRep r => VisitGeneric (PrimOp r) r where
visitGeneric = \case
Expand All @@ -405,8 +406,7 @@ instance IRRep r => VisitGeneric (PrimOp r) r where
MemOp op -> MemOp <$> visitGeneric op
VectorOp op -> VectorOp <$> visitGeneric op
MiscOp op -> MiscOp <$> visitGeneric op
Hof op -> Hof <$> visitGeneric op
RefOp r op -> RefOp <$> visitGeneric r <*> traverseOp op visitGeneric visitGeneric visitGeneric
RefOp r op -> RefOp <$> visitGeneric r <*> traverseOp op visitGeneric visitGeneric

instance IRRep r => VisitGeneric (TypedHof r) r where
visitGeneric (TypedHof eff hof) = TypedHof <$> visitGeneric eff <*> visitGeneric hof
Expand Down
8 changes: 4 additions & 4 deletions src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ instance IRRep r => CheckableE r (Expr r) where
resultTy'' <- snd <$> unwrapNewtypeType con
checkTypesEq resultTy' resultTy''
return $ Unwrap resultTy' x'
Hof (TypedHof effTy hof) -> do
effTy' <- checkE effTy
hof' <- checkHof effTy' hof
return $ Hof (TypedHof effTy' hof')

instance CheckableE CoreIR TyConParams where
checkE (TyConParams expls params) = TyConParams expls <$> mapM checkE params
Expand Down Expand Up @@ -441,10 +445,6 @@ instance CheckableE CoreIR NewtypeTyCon where

instance IRRep r => CheckableE r (PrimOp r) where
checkE = \case
Hof (TypedHof effTy hof) -> do
effTy' <- checkE effTy
hof' <- checkHof effTy' hof
return $ Hof (TypedHof effTy' hof')
VectorOp vOp -> VectorOp <$> checkE vOp
BinOp binop x y -> do
x' <- checkE x
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ liftLamExpr (TopLam d ty (LamExpr bs body)) f = liftM (TopLam d ty) $ liftEnvRea
fromNaryForExpr :: IRRep r => Int -> Expr r n -> Maybe (Int, LamExpr r n)
fromNaryForExpr maxDepth | maxDepth <= 0 = error "expected non-negative number of args"
fromNaryForExpr maxDepth = \case
PrimOp (Hof (TypedHof _ (For _ _ (UnaryLamExpr b body)))) ->
Hof (TypedHof _ (For _ _ (UnaryLamExpr b body))) ->
extend <|> (Just $ (1, LamExpr (Nest b Empty) body))
where
extend = do
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
TabApp _ _ _ -> error "Unexpected `TabApp` in Imp pass."
TabCon _ _ -> error "Unexpected `TabCon` in Imp pass."
Project _ i x -> reduceProj i =<< substM x
Hof hof -> toImpTypedHof hof

toImpRefOp :: Emits o
=> SAtom i -> RefOp SimpIR i -> SubstImpM i o (SAtom o)
Expand All @@ -336,7 +337,6 @@ toImpRefOp refDest' m = do

toImpOp :: forall i o . Emits o => PrimOp SimpIR i -> SubstImpM i o (SAtom o)
toImpOp op = case op of
Hof hof -> toImpTypedHof hof
RefOp refDest eff -> toImpRefOp refDest eff
BinOp binOp x y -> returnIExprVal =<< emitInstr =<< (IBinOp binOp <$> fsa x <*> fsa y)
UnOp unOp x -> returnIExprVal =<< emitInstr =<< (IUnOp unOp <$> fsa x)
Expand Down
6 changes: 5 additions & 1 deletion src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,11 @@ matchPrimApp = \case
Just x' <- return $ toMaybeType x
return $ Left x'
_ -> return $ Right x
return $ fromJust $ toOp $ GenericOpRep op tyArgs dataArgs []
let tyArgs' = case tyArgs of
[] -> Nothing
[t] -> Just t
_ -> error "Expected at most one type arg"
return $ fromJust $ toOp $ GenericOpRep op tyArgs' dataArgs

pattern ExplicitCoreLam :: Nest CBinder n l -> CExpr l -> CAtom n
pattern ExplicitCoreLam bs body <- Con (Lam (CoreLamExpr _ (LamExpr bs body)))
Expand Down
4 changes: 2 additions & 2 deletions src/lib/Inline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ inlineDeclsSubst = \case
-- since their main purpose is to force inlining in the simplifier, and if
-- one just stuck like this it has become equivalent to a `for` anyway.
ixDepthExpr :: Expr SimpIR n -> Int
ixDepthExpr (PrimOp (Hof (TypedHof _ (For _ _ (UnaryLamExpr _ body))))) = 1 + ixDepthExpr body
ixDepthExpr (Hof (TypedHof _ (For _ _ (UnaryLamExpr _ body)))) = 1 + ixDepthExpr body
ixDepthExpr _ = 0

-- Should we decide to inline this binding wherever it appears, before we even
Expand Down Expand Up @@ -316,7 +316,7 @@ reconstruct ctx e = case ctx of
reconstructTabApp :: Emits o
=> Context SExpr e o -> SExpr o -> SAtom i -> InlineM i o (e o)
reconstructTabApp ctx expr i = case expr of
PrimOp (Hof (TypedHof _ (For _ _ (UnaryLamExpr b body)))) -> do
Hof (TypedHof _ (For _ _ (UnaryLamExpr b body))) -> do
-- See NoteReconstructTabAppDecisions
AtomVar i' _ <- inline (EmitToNameCtx Stop) i
dropSubst $ extendSubst (b@>Rename i') do
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Linearize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -404,10 +404,10 @@ linearizeExpr expr = case expr of
Project _ i x -> do
x' <- linearizeAtom x
emitBoth x' \x'' -> mkProject i x''
Hof (TypedHof _ e) -> linearizeHof e

linearizeOp :: Emits o => PrimOp SimpIR i -> LinM i o SAtom SAtom
linearizeOp op = case op of
Hof (TypedHof _ e) -> linearizeHof e
RefOp ref m -> do
ref' <- linearizeAtom ref
case m of
Expand Down
2 changes: 1 addition & 1 deletion src/lib/OccAnalysis.hs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ instance HasOCC SExpr where
ty' <- occTy ty
countFreeVarsAsOccurrences effs
return $ Case scrut' alts' (EffTy effs ty')
PrimOp (Hof op) -> PrimOp . Hof <$> occ a op
Hof op -> Hof <$> occ a op
PrimOp (RefOp ref op) -> do
ref' <- occ a ref
PrimOp . RefOp ref' <$> occ a op
Expand Down
4 changes: 2 additions & 2 deletions src/lib/Optimize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ instance ExprVisitorEmits (ULM i o) SimpIR i o where
-- constant-foldable after inlining don't count towards it.
ulExpr :: Emits o => SExpr i -> ULM i o (SAtom o)
ulExpr expr = case expr of
PrimOp (Hof (TypedHof _ (For Fwd ixTy body))) ->
Hof (TypedHof _ (For Fwd ixTy body)) ->
case ixTypeDict ixTy of
DictCon (IxRawFin (IdxRepVal n)) -> do
(body', bodyCost) <- withLocalAccounting $ visitLamEmits body
Expand Down Expand Up @@ -133,7 +133,7 @@ hoistLoopInvariant lam = liftLamExpr lam hoistLoopInvariantExpr

licmExpr :: Emits o => SExpr i -> LICMM i o (SAtom o)
licmExpr = \case
PrimOp (Hof (TypedHof _ (For dir ix (LamExpr (UnaryNest b) body)))) -> undefined
Hof (TypedHof _ (For dir ix (LamExpr (UnaryNest b) body))) -> undefined
-- ix' <- substM ix
-- Abs hdecls destsAndBody <- visitBinders (UnaryNest b) \(UnaryNest b') -> do
-- Abs decls ans <- buildScoped $ visitExprEmits body
Expand Down
4 changes: 2 additions & 2 deletions src/lib/QueryTypePure.hs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ instance IRRep r => HasType r (Expr r) where
ApplyMethod (EffTy _ t) _ _ _ -> t
Project t _ _ -> t
Unwrap t _ -> t
Hof (TypedHof (EffTy _ ty) _) -> ty

instance IRRep r => HasType r (RepVal r) where
getType (RepVal ty _) = ty
Expand All @@ -148,7 +149,6 @@ instance IRRep r => HasType r (PrimOp r) where
getType primOp = case primOp of
BinOp op x _ -> TyCon $ BaseType $ typeBinOp op $ getTypeBaseType x
UnOp op x -> TyCon $ BaseType $ typeUnOp op $ getTypeBaseType x
Hof (TypedHof (EffTy _ ty) _) -> ty
MemOp op -> getType op
MiscOp op -> getType op
VectorOp op -> getType op
Expand Down Expand Up @@ -258,6 +258,7 @@ instance IRRep r => HasEffects (Expr r) r where
PrimOp primOp -> getEffects primOp
Project _ _ _ -> Pure
Unwrap _ _ -> Pure
Hof (TypedHof (EffTy eff _) _) -> eff

instance IRRep r => HasEffects (DeclBinding r) r where
getEffects (DeclBinding _ expr) = getEffects expr
Expand Down Expand Up @@ -291,5 +292,4 @@ instance IRRep r => HasEffects (PrimOp r) r where
MPut _ -> Effectful
IndexRef _ _ -> Pure
ProjRef _ _ -> Pure
Hof (TypedHof (EffTy eff _) _) -> eff
{-# INLINE getEffects #-}
4 changes: 2 additions & 2 deletions src/lib/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ simplifyExpr = \case
simplifyTabApp f' x'
Atom x -> simplifyAtom x
PrimOp op -> simplifyOp op
Hof (TypedHof (EffTy _ ty) hof) -> simplifyHof hof
ApplyMethod (EffTy _ ty) dict i xs -> do
xs' <- mapM simplifyAtom xs
SimpCCon (WithSubst s (DictConAtom d)) <- simplifyAtom dict
Expand Down Expand Up @@ -408,7 +409,6 @@ simplifyLam (LamExpr bsTop body) = case bsTop of

simplifyOp :: Emits o => PrimOp CoreIR i -> SimplifyM i o (SimpVal o)
simplifyOp op = case op of
Hof (TypedHof (EffTy _ ty) hof) -> simplifyHof hof
MemOp op' -> simplifyGenericOp op'
VectorOp op' -> simplifyGenericOp op'
RefOp ref eff -> do
Expand All @@ -433,7 +433,7 @@ simplifyGenericOp
=> op CoreIR i
-> SimplifyM i o (SimpVal o)
simplifyGenericOp op = do
op' <- traverseOp op getRepType toDataAtom (error "shouldn't have lambda left")
op' <- traverseOp op getRepType toDataAtom
SimpAtom <$> emit op'
{-# INLINE simplifyGenericOp #-}

Expand Down
2 changes: 1 addition & 1 deletion src/lib/Transpose.hs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ transposeExpr expr ct = case expr of
forM_ (enumerate es) \(ordinalIdx, e) -> do
i <- unsafeFromOrdinal idxTy (IdxRepVal $ fromIntegral ordinalIdx)
tabApp ct i >>= transposeAtom e
Hof (TypedHof _ hof) -> transposeHof hof ct
TabApp _ _ _ -> error "should have been handled by reference projection"
Project _ _ _ -> error "should have been handled by reference projection"

Expand All @@ -193,7 +194,6 @@ transposeOp op ct = case op of
void $ emitEff $ MPut zero
IndexRef _ _ -> notImplemented
ProjRef _ _ -> notImplemented
Hof (TypedHof _ hof) -> transposeHof hof ct
MiscOp miscOp -> transposeMiscOp miscOp ct
UnOp FNeg x -> transposeAtom x =<< (emitLin $ UnOp FNeg ct)
UnOp _ _ -> notLinear
Expand Down
Loading

0 comments on commit cf06eb5

Please sign in to comment.