Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Educational: built-in map primitive #1093

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,14 @@ typeCheckPrimOp op = case op of

typeCheckPrimHof :: Typer m => PrimHof (Atom i) -> m i o (Type o)
typeCheckPrimHof hof = addContext ("Checking HOF:\n" ++ pprint hof) case hof of
Map f -> getTypeE f
Map fun array -> do
Pi (PiType (PiBinder b argTy PlainArrow) Pure resEltTy) <- getTypeE fun
let resEltTy' = ignoreHoistFailure $ hoist b resEltTy
TabPi (TabPiType binder argEltTy) <- getTypeE array
let argEltTy' = ignoreHoistFailure $ hoist binder argEltTy
checkAlphaEq argTy argEltTy'
refreshAbs (Abs binder UnitE) \binder' _ ->
return $ TabPi $ TabPiType binder' (sink resEltTy')
For _ ixDict f -> do
ixTy <- ixTyFromDict =<< substM ixDict
Pi (PiType (PiBinder b argTy PlainArrow) eff eltTy) <- getTypeE f
Expand Down
16 changes: 8 additions & 8 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -560,17 +560,17 @@ toImpHof :: Emits o => Maybe (Dest o) -> PrimHof (Atom i) -> SubstImpM i o (Atom
toImpHof maybeDest hof = do
resultTy <- getTypeSubst (Hof hof)
case hof of
Map (TabLam (TabLamExpr (b:>ixTy) body)) -> do
-- TODO: The following code block is identical to the `For` case below.
-- Reuse the code for the `For` case by generating `For (Lam ...)`, with
-- suitable `Lam ...`, when currently a `Map (TabLam ...)` is generated.
Map (Lam (LamExpr b body)) array -> do
rDest <- allocDest maybeDest resultTy
ixTy' <- substM ixTy
n <- indexSetSizeImp ixTy'
TabPi (TabPiType (_:>ixTy) _) <- getTypeSubst array
array' <- substM array
n <- indexSetSizeImp ixTy
emitLoop noHint Fwd n \i -> do
idx <- unsafeFromOrdinalImp (sink ixTy') i
idx <- unsafeFromOrdinalImp (sink ixTy) i
ithArg <- dropSubst $ translateExpr Nothing $
TabApp (sink array') $ idx :| []
ithDest <- destGet (sink rDest) idx
void $ extendSubst (b @> SubstVal idx) $
void $ extendSubst (b @> SubstVal ithArg) $
translateBlock (Just ithDest) body
destToAtom rDest
For d ixDict (Lam (LamExpr b body)) -> do
Expand Down
92 changes: 30 additions & 62 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,20 @@ getImplicitArg (PiBinder _ argTy arr) = case arr of
return $ Just $ Con $ DictHole (AlwaysEqual ctx) argTy
_ -> return Nothing

etaExpand :: EmitsInf n => Atom n -> InfererM i n (Atom n)
etaExpand fun = do
ty <- getType fun
case ty of
Pi (PiType (PiBinder b argTy arr) eff _) -> do
case fun of
Lam _ -> pure fun
_ -> buildLamInf noHint arr argTy
(\b' -> applySubst (b @> b') eff)
(\x -> do
Distinct <- getDistinct
app (sink fun) (Var x))
_ -> error "atom must have pi type"

checkOrInferRho :: forall i o. EmitsBoth o
=> UExpr i -> RequiredTy RhoType o -> InfererM i o (Atom o)
checkOrInferRho (WithSrcE pos expr) reqTy = do
Expand All @@ -944,68 +958,22 @@ checkOrInferRho (WithSrcE pos expr) reqTy = do
ixTy <- asIxType ty'
matchRequirement $ TabLam $ TabLamExpr (b':>ixTy) body'
UMap fun array -> do
argElemVar <- liftM Var $ freshInferenceName (TC TypeKind)
resElemVar <- liftM Var $ freshInferenceName (TC TypeKind)
funTy <- naryNonDepPiType PlainArrow Pure [argElemVar] resElemVar
fun' <- checkOrInferRho fun (Check funTy)

arrayReqTy <- case reqTy of
Check (TabPi (TabPiType (b:>ixTy) resElemTy)) -> do
-- TODO: Throw a graceful error if `resElemTy` depends on `b`.
let resElemTy' = ignoreHoistFailure $ hoist b resElemTy
constrainEq resElemVar resElemTy'
liftM (Check . TabPi) $ nonDepTabPiType ixTy argElemVar
Check _ -> return Infer
Infer -> return Infer
array' <- checkOrInferRho array arrayReqTy

-- Construct the `TabLam` for `Map`. This should probably not be done here
-- alongside the inference code; perhaps `AbstractSyntax.hs` would be a
-- better place for this. However, if we replaced replaced the `fun` and
-- `array` arguments to `UMap` with a `UTabLam`, this would make typing
-- failures less informative at the source code level (since the source code
-- contains concrete syntax for `fun` and `array`, but not for the
-- constructed `UTabLam`.) The cleanest alternative, however, would probably
-- be to place the construction of the `TabLam` (or of an equivalent block
-- paired with binders for its free variables) in `Lower.hs` or `Imp.hs`.
-- However, at that point we would require the expressions in the block
-- (including the expressions in the decls inside the block) to be
-- simplified; and the `TabApp` and `App` expressions below are apparently
-- not simplified.
TabPi (TabPiType (bA:>ixTy) argElemTy) <- getType array'
-- TODO: Throw a graceful error if `argElemTy` depends on `bA`.
let argElemTy' = ignoreHoistFailure $ hoist bA argElemTy
-- NOTE: In the definition of `arrayReqTy` we have already introduced a
-- constaint for `resElemVar`; but we still need to constrain `argElemVar`.
-- (Additionally `resElemVar` should also be constrained by the
-- `matchRequirement` below, but this is not the case for `argElemVar`.)
constrainEq argElemVar argElemTy'
Pi (PiType (PiBinder bF _ _) _ resElemTy) <- getType fun'
-- TODO: Throw a graceful error if `resElemTy` depends on `bF`.
let resElemTy' = ignoreHoistFailure $ hoist bF resElemTy

f <- withFreshBinder noHint ixTy \b0 -> do
let binder = b0:>ixTy
-- I am having trouble getting the following to work when trying to use a
-- single call to `withFreshBinders` (plural!) only. The problem appears
-- to be that `withFreshBinders` does not make available evidence of
-- `Distinct` for the intermediate scope, i.e. the scope that has `b1` but
-- not `b2`. Without this evidence, `sink fun'` in the let-block below is
-- not valid.
body <- withFreshBinder noHint (sink argElemTy') \b1 ->
withFreshBinder noHint (sink resElemTy') \b2 ->
let indexName = binderName b0
argElem = TabApp (sink array') $ (Var indexName) :| []
declArgElem = Let b1 (DeclBinding PlainLet (sink argElemTy') argElem)
funApp = App (sink fun') $ (Var $ binderName b1) :| []
declResElem = Let b2 (DeclBinding PlainLet (sink resElemTy') funApp)
ann = BlockAnn (sink resElemTy') Pure
block = Block ann (Nest declArgElem (Nest declResElem Empty)) (Var $ binderName b2)
in return block
return $ TabLam $ TabLamExpr binder body

result <- liftM Var $ emit $ Hof $ Map f
matchRequirement result
array' <- inferRho array
arrayTy <- getType array'
case arrayTy of
TabPi (TabPiType (b:>_) argElemTy) -> do
argElemTy' <- case hoist b argElemTy of
HoistSuccess ty -> return ty
HoistFailure _ -> throw TypeErr "expected non-dependent array type"
resElemVar <- liftM Var $ freshInferenceName (TC TypeKind)
funTy <- naryNonDepPiType PlainArrow Pure [argElemTy'] resElemVar
fun' <- checkOrInferRho fun (Check funTy)
-- Eta-expand `fun'` into a `Lam`. Later on we make use of the invariant
-- that the first argument of `Map` is a `Lam`.
fun'' <- etaExpand fun'
result <- liftM Var $ emit $ Hof $ Map fun'' array'
matchRequirement result
_ -> throw TypeErr "expected array type"
UFor dir (UForExpr b body) -> do
allowedEff <- getAllowedEffects
let uLamExpr = ULamExpr PlainArrow b body
Expand Down
9 changes: 7 additions & 2 deletions src/lib/QueryType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,12 @@ getTypePrimHof hof = addContext ("Checking HOF:\n" ++ pprint hof) case hof of
Pi (PiType (PiBinder b _ _) _ eltTy) <- getTypeE f
ixTy <- ixTyFromDict =<< substM dict
return $ TabTy (b:>ixTy) eltTy
Map f -> getTypeE f
Map fun array -> do
Pi (PiType (PiBinder b _ _) _ resEltTy) <- getTypeE fun
let resEltTy' = ignoreHoistFailure $ hoist b resEltTy
TabPi (TabPiType binder _) <- getTypeE array
refreshAbs (Abs binder UnitE) \binder' _ ->
return $ TabPi $ TabPiType binder' (sink resEltTy')
While _ -> return UnitTy
Linearize f -> do
Pi (PiType (PiBinder binder a PlainArrow) Pure b) <- getTypeE f
Expand Down Expand Up @@ -798,7 +803,7 @@ exprEffects expr = case expr of
_ -> return Pure
Hof hof -> case hof of
For _ _ f -> functionEffs f
Map _ -> return Pure
Map _ _ -> return Pure
While body -> functionEffs body
Linearize _ -> return Pure -- Body has to be a pure function
Transpose _ -> return Pure -- Body has to be a pure function
Expand Down
16 changes: 13 additions & 3 deletions src/lib/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -835,9 +835,19 @@ projectDictMethod d i = do

simplifyHof :: Emits o => Hof i -> SimplifyM i o (Atom o)
simplifyHof hof = case hof of
Map f -> do
f' <- simplifyAtom f
liftM Var $ emit $ Hof $ Map f'
Map fun array -> do
(fun', Abs b recon) <- simplifyLam fun
array' <- simplifyAtom array
ans <- liftM Var $ emit $ Hof $ Map fun' array'
case recon of
IdentityRecon -> return ans
LamRecon reconAbs -> do
TabPi (TabPiType (_:>ixTy) _) <- getType array'
buildTabLam noHint ixTy \i -> do
locals <- tabApp (sink ans) $ Var i
ithArg <- emitAtomToName =<< (tabApp (sink array') $ Var i)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of emitting to name, you could b @> SubstVal ithArg below

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.

reconAbs' <- applySubst (b @> ithArg) reconAbs
applyReconAbs reconAbs' locals
For d ixDict lam -> do
ixTy@(IxType _ ixDict') <- ixTyFromDict =<< substM ixDict
(lam', Abs b recon) <- simplifyLam lam
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Types/Primitives.hs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ traversePrimOp = inline traverse

data PrimHof e =
For ForAnn e e -- ix dict, body lambda
| Map e -- body tab-lambda
| Map e e -- lambda, array
| While e
| RunReader e e
| RunWriter (Maybe e) (BaseMonoidP e) e
Expand Down