diff --git a/src/lib/AbstractSyntax.hs b/src/lib/AbstractSyntax.hs index 99c922a06..428fa40cc 100644 --- a/src/lib/AbstractSyntax.hs +++ b/src/lib/AbstractSyntax.hs @@ -494,6 +494,10 @@ expr = propagateSrcE expr' where UApp (mkApp (ns $ fromString rangeName) (ns UHole)) lim expr' (CLambda args body) = dropSrcE <$> liftM2 buildLam (concat <$> mapM argument args) (block body) + expr' (CMap fun array) = do + fun' <- expr fun + array' <- expr array + return $ UMap fun' array' expr' (CFor KView indices body) = dropSrcE <$> (buildTabLam <$> mapM patOptAnn indices <*> block body) expr' (CFor kind indices body) = do diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index d72d1712a..5f4c0aca8 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -777,6 +777,14 @@ typeCheckPrimOp op = case op of typeCheckPrimHof :: Typer m => PrimHof (Atom i) -> m i o (Type o) typeCheckPrimHof hof = addContext ("Checking HOF:\n" ++ pprint hof) case hof of + Map fun array -> do + Pi (PiType (PiBinder b argTy PlainArrow) Pure resEltTy) <- getTypeE fun + let resEltTy' = ignoreHoistFailure $ hoist b resEltTy + TabPi (TabPiType binder argEltTy) <- getTypeE array + let argEltTy' = ignoreHoistFailure $ hoist binder argEltTy + checkAlphaEq argTy argEltTy' + refreshAbs (Abs binder UnitE) \binder' _ -> + return $ TabPi $ TabPiType binder' (sink resEltTy') For _ ixDict f -> do ixTy <- ixTyFromDict =<< substM ixDict Pi (PiType (PiBinder b argTy PlainArrow) eff eltTy) <- getTypeE f diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 446ad5d53..60cef64d7 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -114,6 +114,7 @@ data Group' | CPrefix SourceName Group -- covers unary - and unary + among others | CPostfix SourceName Group | CLambda [Group] CBlock -- The arguments do not have Juxtapose at the top level + | CMap Group Group -- unary fun, array | CFor ForKind [Group] CBlock -- also for_, rof, rof_, view | CCase Group [(Group, CBlock)] -- scrutinee, alternatives | CIf Group CBlock (Maybe CBlock) @@ -559,6 +560,14 @@ cLam = do body <- cBlock return $ CLambda bs body +cMap :: Parser Group' +cMap = do + keyWord MapKW + fun <- cGroupNoJuxt + keyWord OverKW + array <- cGroup + return $ CMap fun array + cFor :: Parser Group' cFor = do kw <- forKW @@ -704,6 +713,7 @@ leafGroupNoBrackets = do _ | isDigit next -> ( CNat <$> natLit <|> CFloat <$> doubleLit) '\\' -> cLam + 'm' -> cMap <|> CIdentifier <$> anyName -- For exprs include view, for, rof, for_, rof_ 'v' -> cFor <|> CIdentifier <$> anyName 'f' -> cFor <|> CIdentifier <$> anyName diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index d3c0fbcac..82c5831f5 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -560,6 +560,19 @@ toImpHof :: Emits o => Maybe (Dest o) -> PrimHof (Atom i) -> SubstImpM i o (Atom toImpHof maybeDest hof = do resultTy <- getTypeSubst (Hof hof) case hof of + Map (Lam (LamExpr b body)) array -> do + rDest <- allocDest maybeDest resultTy + TabPi (TabPiType (_:>ixTy) _) <- getTypeSubst array + array' <- substM array + n <- indexSetSizeImp ixTy + emitLoop noHint Fwd n \i -> do + idx <- unsafeFromOrdinalImp (sink ixTy) i + ithArg <- dropSubst $ translateExpr Nothing $ + TabApp (sink array') $ idx :| [] + ithDest <- destGet (sink rDest) idx + void $ extendSubst (b @> SubstVal ithArg) $ + translateBlock (Just ithDest) body + destToAtom rDest For d ixDict (Lam (LamExpr b body)) -> do ixTy <- ixTyFromDict =<< substM ixDict n <- indexSetSizeImp ixTy diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 028683a6d..2d0f671ba 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -918,6 +918,20 @@ getImplicitArg (PiBinder _ argTy arr) = case arr of return $ Just $ Con $ DictHole (AlwaysEqual ctx) argTy _ -> return Nothing +etaExpand :: EmitsInf n => Atom n -> InfererM i n (Atom n) +etaExpand fun = do + ty <- getType fun + case ty of + Pi (PiType (PiBinder b argTy arr) eff _) -> do + case fun of + Lam _ -> pure fun + _ -> buildLamInf noHint arr argTy + (\b' -> applySubst (b @> b') eff) + (\x -> do + Distinct <- getDistinct + app (sink fun) (Var x)) + _ -> error "atom must have pi type" + checkOrInferRho :: forall i o. EmitsBoth o => UExpr i -> RequiredTy RhoType o -> InfererM i o (Atom o) checkOrInferRho (WithSrcE pos expr) reqTy = do @@ -943,6 +957,23 @@ checkOrInferRho (WithSrcE pos expr) reqTy = do Infer -> inferULam Pure uLamExpr ixTy <- asIxType ty' matchRequirement $ TabLam $ TabLamExpr (b':>ixTy) body' + UMap fun array -> do + array' <- inferRho array + arrayTy <- getType array' + case arrayTy of + TabPi (TabPiType (b:>_) argElemTy) -> do + argElemTy' <- case hoist b argElemTy of + HoistSuccess ty -> return ty + HoistFailure _ -> throw TypeErr "expected non-dependent array type" + resElemVar <- liftM Var $ freshInferenceName (TC TypeKind) + funTy <- naryNonDepPiType PlainArrow Pure [argElemTy'] resElemVar + fun' <- checkOrInferRho fun (Check funTy) + -- Eta-expand `fun'` into a `Lam`. Later on we make use of the invariant + -- that the first argument of `Map` is a `Lam`. + fun'' <- etaExpand fun' + result <- liftM Var $ emit $ Hof $ Map fun'' array' + matchRequirement result + _ -> throw TypeErr "expected array type" UFor dir (UForExpr b body) -> do allowedEff <- getAllowedEffects let uLamExpr = ULamExpr PlainArrow b body diff --git a/src/lib/Lexing.hs b/src/lib/Lexing.hs index bf2093f87..2d2b52a28 100644 --- a/src/lib/Lexing.hs +++ b/src/lib/Lexing.hs @@ -69,11 +69,14 @@ data KeyWord = DefKW | ForKW | For_KW | RofKW | Rof_KW | CaseKW | OfKW | ViewKW | ImportKW | ForeignKW | NamedInstanceKW | EffectKW | HandlerKW | JmpKW | CtlKW | ReturnKW | ResumeKW | CustomLinearizationKW | CustomLinearizationSymbolicKW + | MapKW | OverKW deriving (Enum) keyWordToken :: KeyWord -> String keyWordToken = \case DefKW -> "def" + MapKW -> "map_" + OverKW -> "over_" ForKW -> "for" RofKW -> "rof" For_KW -> "for_" diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index d3d6e1654..34c6843ec 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -662,6 +662,8 @@ instance PrettyPrec (UExpr' n) where <+> nest 2 (pLowest body) where kw = case dir of Fwd -> "for" Rev -> "rof" + UMap fun array -> atPrec LowestPrec $ "map_" <+> nest 2 (pLowest fun) + <+> "over_" <+> nest 2 (pLowest array) UPi piType -> prettyPrec piType UTabPi piType -> prettyPrec piType UDecl declExpr -> prettyPrec declExpr diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index 303835a0b..4e5c6febf 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -698,6 +698,12 @@ getTypePrimHof hof = addContext ("Checking HOF:\n" ++ pprint hof) case hof of Pi (PiType (PiBinder b _ _) _ eltTy) <- getTypeE f ixTy <- ixTyFromDict =<< substM dict return $ TabTy (b:>ixTy) eltTy + Map fun array -> do + Pi (PiType (PiBinder b _ _) _ resEltTy) <- getTypeE fun + let resEltTy' = ignoreHoistFailure $ hoist b resEltTy + TabPi (TabPiType binder _) <- getTypeE array + refreshAbs (Abs binder UnitE) \binder' _ -> + return $ TabPi $ TabPiType binder' (sink resEltTy') While _ -> return UnitTy Linearize f -> do Pi (PiType (PiBinder binder a PlainArrow) Pure b) <- getTypeE f @@ -797,6 +803,7 @@ exprEffects expr = case expr of _ -> return Pure Hof hof -> case hof of For _ _ f -> functionEffs f + Map _ _ -> return Pure While body -> functionEffs body Linearize _ -> return Pure -- Body has to be a pure function Transpose _ -> return Pure -- Body has to be a pure function diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 3bf96a4c3..d59ee255d 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -835,6 +835,19 @@ projectDictMethod d i = do simplifyHof :: Emits o => Hof i -> SimplifyM i o (Atom o) simplifyHof hof = case hof of + Map fun array -> do + (fun', Abs b recon) <- simplifyLam fun + array' <- simplifyAtom array + ans <- liftM Var $ emit $ Hof $ Map fun' array' + case recon of + IdentityRecon -> return ans + LamRecon reconAbs -> do + TabPi (TabPiType (_:>ixTy) _) <- getType array' + buildTabLam noHint ixTy \i -> do + locals <- tabApp (sink ans) $ Var i + ithArg <- emitAtomToName =<< (tabApp (sink array') $ Var i) + reconAbs' <- applySubst (b @> ithArg) reconAbs + applyReconAbs reconAbs' locals For d ixDict lam -> do ixTy@(IxType _ ixDict') <- ixTyFromDict =<< substM ixDict (lam', Abs b recon) <- simplifyLam lam diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index b7a465fff..7dde587c6 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -207,6 +207,8 @@ instance SourceRenamableE UExpr' where UDecl (UDeclExpr decl rest) -> sourceRenameB decl \decl' -> UDecl <$> UDeclExpr decl' <$> sourceRenameE rest + UMap fun array -> UMap <$> sourceRenameE fun + <*> sourceRenameE array UFor d (UForExpr pat body) -> sourceRenameB pat \pat' -> UFor d <$> UForExpr pat' <$> sourceRenameE body diff --git a/src/lib/Types/Primitives.hs b/src/lib/Types/Primitives.hs index fada8281a..948bb82c9 100644 --- a/src/lib/Types/Primitives.hs +++ b/src/lib/Types/Primitives.hs @@ -158,6 +158,7 @@ traversePrimOp = inline traverse data PrimHof e = For ForAnn e e -- ix dict, body lambda + | Map e e -- lambda, array | While e | RunReader e e | RunWriter (Maybe e) (BaseMonoidP e) e diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index d18b6a454..f42fc581a 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -93,6 +93,7 @@ data UExpr' (n::S) = | UTabPi (UTabPiExpr n) | UTabApp (UExpr n) (UExpr n) | UDecl (UDeclExpr n) + | UMap (UExpr n) (UExpr n) | UFor Direction (UForExpr n) | UCase (UExpr n) [UAlt n] | UHole diff --git a/tests/mymap.dx b/tests/mymap.dx new file mode 100644 index 000000000..d5fc48c64 --- /dev/null +++ b/tests/mymap.dx @@ -0,0 +1,22 @@ +def my_map {a:Type} {b:Type} {n:Type} [Ix n] + (f:a -> b) + (x:n => a) : n => b = + for i:n. f x.i + +x0 = [1, 2, 3, 4, 5] + +my_map (\x. x+x) x0 +map_ (\x. x+x) over_ x0 + +my_map (\x. 2*x) x0 +map_ (\x. 2*x) over_ x0 + +my_map (\x. 2*x) (x0 + x0) +map_ (\x. 2*x) over_ (x0 + x0) +-- The following is also parsed as `map_ (\x. 2*x) over_ (x0 + x0)` ... not +-- intentionally so. +map_ (\x. 2*x) over_ x0 + x0 + +my_map (\x. 2*x) x0 + x0 +(map_ (\x. 2*x) over_ x0) + x0 +