diff --git a/cabal.project b/cabal.project index 1e33d64aec..2a759b23fc 100644 --- a/cabal.project +++ b/cabal.project @@ -5,3 +5,8 @@ package futhark ghc-options: -j -fwrite-ide-info -hiedir=.hie allow-newer: base, template-haskell + +source-repository-package + type: git + location: https://github.com/jyp/glpk-hs.git + tag: 1f276aa19861203ea8367dc27a6ad4c8a31c9062 diff --git a/default.nix b/default.nix index 842e8937a4..4d27d336c9 100644 --- a/default.nix +++ b/default.nix @@ -37,6 +37,12 @@ let zlib = haskellPackagesNew.callPackage ./nix/zlib.nix {zlib=pkgs.zlib;}; + gasp = + haskellPackagesNew.callPackage ./nix/gasp.nix {}; + + glpk-hs = + haskellPackagesNew.callPackage ./nix/glpk-hs.nix {}; + futhark = # callCabal2Nix does not do a great job at determining # which files must be included as source, which causes @@ -75,6 +81,7 @@ let "--extra-lib-dirs=${pkgs.glibc.static}/lib" "--extra-lib-dirs=${pkgs.gmp6.override { withStatic = true; }}/lib" "--extra-lib-dirs=${pkgs.libffi.overrideAttrs (old: { dontDisableStatic = true; })}/lib" + "--extra-lib-dirs=${pkgs.glpk.overrideAttrs (old: { dontDisableStatic = true; })}/lib" # The ones below are due to GHC's runtime system # depending on libdw (DWARF info), which depends on # a bunch of compression algorithms. diff --git a/docs/language-reference.rst b/docs/language-reference.rst index ded7882545..91c53037d5 100644 --- a/docs/language-reference.rst +++ b/docs/language-reference.rst @@ -1002,9 +1002,11 @@ Syntactic sugar for ``let a = a with [i] = v in a``. ............................... Bind ``f`` to a function with the given parameters and definition -(``e``) and evaluate ``body``. The function will be treated as -aliasing any free variables in ``e``. The function is not in scope of -itself, and hence cannot be recursive. +(``e``) and evaluate ``body``. The function will be treated as +aliasing any free variables in ``e``. The function is not in scope of +itself, and hence cannot be recursive. While the function can be made +polymorphic by putting in explicit size parameters, it is not +automatically generalised the way top level functions are. ``loop pat = initial for x in a do loopbody`` ............................................. diff --git a/futhark.cabal b/futhark.cabal index c717050d96..31fb133563 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -379,6 +379,11 @@ library Futhark.Pkg.Types Futhark.Profile Futhark.Script + Futhark.Solve.GLPK + Futhark.Solve.LP + Futhark.Solve.Matrix + Futhark.Solve.Simplex + Futhark.Solve.BranchAndBound Futhark.Test Futhark.Test.Spec Futhark.Test.Values @@ -419,11 +424,15 @@ library Language.Futhark.Tuple Language.Futhark.TypeChecker Language.Futhark.TypeChecker.Consumption + Language.Futhark.TypeChecker.Constraints + Language.Futhark.TypeChecker.Error Language.Futhark.TypeChecker.Names Language.Futhark.TypeChecker.Match Language.Futhark.TypeChecker.Modules Language.Futhark.TypeChecker.Monad + Language.Futhark.TypeChecker.Rank Language.Futhark.TypeChecker.Terms + Language.Futhark.TypeChecker.Terms2 Language.Futhark.TypeChecker.Terms.Loop Language.Futhark.TypeChecker.Terms.Monad Language.Futhark.TypeChecker.Terms.Pat @@ -496,6 +505,9 @@ library , mwc-random , prettyprinter >= 1.7 , prettyprinter-ansi-terminal >= 1.1 + -- remove me later + , glpk-hs + , silently executable futhark import: common @@ -531,6 +543,8 @@ test-suite unit Futhark.Optimise.ArrayLayoutTests Futhark.Pkg.SolveTests Futhark.ProfileTests + Futhark.Solve.BranchAndBoundTests + Futhark.Solve.SimplexTests Language.Futhark.CoreTests Language.Futhark.PrimitiveTests Language.Futhark.SemanticTests @@ -549,3 +563,4 @@ test-suite unit , tasty-hunit , tasty-quickcheck , text + , vector >=0.12 diff --git a/nix/gasp.nix b/nix/gasp.nix new file mode 100644 index 0000000000..526b047bdc --- /dev/null +++ b/nix/gasp.nix @@ -0,0 +1,14 @@ +{ mkDerivation, adjunctions, base, binary, constraints, containers +, distributive, lib, mtl, QuickCheck +}: +mkDerivation { + pname = "gasp"; + version = "1.4.0.0"; + sha256 = "9a73a6ea7eb844493deb76c85c50249915e5ca29a6734a0b133a0e136c232f9f"; + libraryHaskellDepends = [ + adjunctions base binary constraints containers distributive mtl + QuickCheck + ]; + description = "A framework of algebraic classes"; + license = lib.licenses.bsd3; +} diff --git a/nix/glpk-hs.nix b/nix/glpk-hs.nix new file mode 100644 index 0000000000..6f5a2b0081 --- /dev/null +++ b/nix/glpk-hs.nix @@ -0,0 +1,23 @@ +{ mkDerivation, array, base, containers, deepseq, fetchgit, gasp +, glpk, lib, mtl +}: +mkDerivation { + pname = "glpk-hs"; + version = "0.8"; + src = fetchgit { + url = "https://github.com/ludat/glpk-hs.git"; + sha256 = "0nly5nifdb93f739vr3jzgi16fccqw5l0aabf5lglsdkdad713q1"; + rev = "efcb8354daa1205de2b862898353da2e4beb76b2"; + fetchSubmodules = true; + }; + isLibrary = true; + isExecutable = true; + libraryHaskellDepends = [ array base containers deepseq gasp mtl ]; + librarySystemDepends = [ glpk ]; + executableHaskellDepends = [ + array base containers deepseq gasp mtl + ]; + description = "Comprehensive GLPK linear programming bindings"; + license = lib.licenses.bsd3; + mainProgram = "glpk-hs-example"; +} diff --git a/prelude/soacs.fut b/prelude/soacs.fut index 02576f09a9..c2cfb22a44 100644 --- a/prelude/soacs.fut +++ b/prelude/soacs.fut @@ -48,7 +48,7 @@ import "zip" -- -- **Span:** *O(S(f))* def map 'a [n] 'x (f: a -> x) (as: [n]a) : *[n]x = - intrinsics.map f as + f as -- | Apply the given function to each element of a single array. -- diff --git a/prelude/zip.fut b/prelude/zip.fut index fdd0abbfe5..5ccbacc17b 100644 --- a/prelude/zip.fut +++ b/prelude/zip.fut @@ -6,11 +6,6 @@ -- The main reason this module exists is that we need it to define -- SOACs like `map2`. --- We need a map to define some of the zip variants, but this file is --- depended upon by soacs.fut. So we just define a quick-and-dirty --- internal one here that uses the intrinsic version. -local def internal_map 'a [n] 'x (f: a -> x) (as: [n]a) : *[n]x = - intrinsics.map f as -- | Construct an array of pairs from two arrays. def zip [n] 'a 'b (as: [n]a) (bs: [n]b) : *[n](a, b) = @@ -22,15 +17,15 @@ def zip2 [n] 'a 'b (as: [n]a) (bs: [n]b) : *[n](a, b) = -- | As `zip2`@term, but with one more array. def zip3 [n] 'a 'b 'c (as: [n]a) (bs: [n]b) (cs: [n]c) : *[n](a, b, c) = - internal_map (\(a, (b, c)) -> (a, b, c)) (zip as (zip2 bs cs)) + (\(a, (b, c)) -> (a, b, c)) (zip as (zip2 bs cs)) -- | As `zip3`@term, but with one more array. def zip4 [n] 'a 'b 'c 'd (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) : *[n](a, b, c, d) = - internal_map (\(a, (b, c, d)) -> (a, b, c, d)) (zip as (zip3 bs cs ds)) + (\(a, (b, c, d)) -> (a, b, c, d)) (zip as (zip3 bs cs ds)) -- | As `zip4`@term, but with one more array. def zip5 [n] 'a 'b 'c 'd 'e (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) (es: [n]e) : *[n](a, b, c, d, e) = - internal_map (\(a, (b, c, d, e)) -> (a, b, c, d, e)) (zip as (zip4 bs cs ds es)) + (\(a, (b, c, d, e)) -> (a, b, c, d, e)) (zip as (zip4 bs cs ds es)) -- | Turn an array of pairs into two arrays. def unzip [n] 'a 'b (xs: [n](a, b)) : ([n]a, [n]b) = @@ -42,18 +37,18 @@ def unzip2 [n] 'a 'b (xs: [n](a, b)) : ([n]a, [n]b) = -- | As `unzip2`@term, but with one more array. def unzip3 [n] 'a 'b 'c (xs: [n](a, b, c)) : ([n]a, [n]b, [n]c) = - let (as, bcs) = unzip (internal_map (\(a, b, c) -> (a, (b, c))) xs) + let (as, bcs) = unzip ((\(a, b, c) -> (a, (b, c))) xs) let (bs, cs) = unzip bcs in (as, bs, cs) -- | As `unzip3`@term, but with one more array. def unzip4 [n] 'a 'b 'c 'd (xs: [n](a, b, c, d)) : ([n]a, [n]b, [n]c, [n]d) = - let (as, bs, cds) = unzip3 (internal_map (\(a, b, c, d) -> (a, b, (c, d))) xs) + let (as, bs, cds) = unzip3 ((\(a, b, c, d) -> (a, b, (c, d))) xs) let (cs, ds) = unzip cds in (as, bs, cs, ds) -- | As `unzip4`@term, but with one more array. def unzip5 [n] 'a 'b 'c 'd 'e (xs: [n](a, b, c, d, e)) : ([n]a, [n]b, [n]c, [n]d, [n]e) = - let (as, bs, cs, des) = unzip4 (internal_map (\(a, b, c, d, e) -> (a, b, c, (d, e))) xs) + let (as, bs, cs, des) = unzip4 ((\(a, b, c, d, e) -> (a, b, c, (d, e))) xs) let (ds, es) = unzip des in (as, bs, cs, ds, es) diff --git a/shell.nix b/shell.nix index 2ff2206b31..7c14d47936 100644 --- a/shell.nix +++ b/shell.nix @@ -2,7 +2,28 @@ let sources = import ./nix/sources.nix; pkgs = import sources.nixpkgs {}; - python = pkgs.python311Packages; + python = pkgs.python311.withPackages (ps: with ps; [ + ( + buildPythonPackage rec { + pname = "PuLP"; + version = "2.7.0"; + src = fetchPypi { + inherit pname version; + sha256 = "sha256-5z7msy1jnJuM9LSt7TNLoVi+X4MTVE4Fb3lqzgoQrmM="; + }; + doCheck = false; + } + ) + ps.mypy + black + cycler + numpy + pyopencl + matplotlib + jsonschema + sphinx + sphinxcontrib-bibtex + ]); haskell = pkgs.haskell.packages.ghc96; in pkgs.stdenv.mkDerivation { @@ -22,6 +43,7 @@ pkgs.stdenv.mkDerivation { haskell.haskell-language-server haskellPackages.graphmod haskellPackages.apply-refact + python xdot hlint pkg-config @@ -31,17 +53,8 @@ pkgs.stdenv.mkDerivation { ghcid niv ispc - python.python - python.mypy - python.black - python.cycler - python.numpy - python.pyopencl - python.matplotlib - python.jsonschema - python.sphinx - python.sphinxcontrib-bibtex imagemagick # needed for literate tests + glpk ] ++ lib.optionals (stdenv.isLinux) [ opencl-headers diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 82b7633e71..65f0375348 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -750,7 +750,7 @@ etaExpand e_t e = do M.fromList . zip (retDims ret) $ map (ExpSubst . flip sizeFromName mempty . qualName) ext' ret' = applySubst (`M.lookup` extsubst) ret - e' = mkApply e (map (Nothing,) vars) $ AppRes (toStruct $ retType ret') ext' + e' = mkApply e (map (\v -> (Nothing, mempty, v)) vars) $ AppRes (toStruct $ retType ret') ext' pure (params, e', ret) where getType (RetType _ (Scalar (Arrow _ p d t1 t2))) = @@ -856,7 +856,7 @@ unRetType (RetType ext t) = do defuncApplyFunction :: Exp -> Int -> DefM (Exp, StaticVal) defuncApplyFunction e@(Var qn (Info t) loc) num_args = do - let (argtypes, rettype) = unfoldFunType t + let (argtypes, rettype) = first (map snd) $ unfoldFunType t sv <- lookupVar (toStruct t) (qualLeaf qn) case sv of @@ -908,9 +908,9 @@ liftedName _ _ = "defunc" defuncApplyArg :: String -> (Exp, StaticVal) -> - ((Maybe VName, Exp), [ParamType]) -> + (((Maybe VName, AutoMap), Exp), [ParamType]) -> DefM (Exp, StaticVal) -defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) ((argext, arg), _) = do +defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, _), arg), _) = do (arg', arg_sv) <- defuncExp arg let env' = alwaysMatchPatSV pat arg_sv dims = mempty @@ -961,18 +961,18 @@ defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) ((argext, ar callret <- unRetType lifted_rettype pure - ( mkApply fname' [(Nothing, f'), (argext, arg')] callret, + ( mkApply fname' [(Nothing, mempty, f'), (argext, mempty, arg')] callret, sv ) -- If 'f' is a dynamic function, we just leave the application in -- place, but we update the types since it may be partially -- applied or return a higher-order value. -defuncApplyArg _ (f', DynamicFun _ sv) ((argext, arg), argtypes) = do +defuncApplyArg _ (f', DynamicFun _ sv) (((argext, _), arg), argtypes) = do (arg', _) <- defuncExp arg let (argtypes', rettype) = dynamicFunType sv argtypes restype = foldFunType argtypes' (RetType [] rettype) callret = AppRes restype [] - apply_e = mkApply f' [(argext, arg')] callret + apply_e = mkApply f' [(argext, mempty, arg')] callret pure (apply_e, sv) -- defuncApplyArg fname_s (_, sv) ((_, arg), _) = @@ -989,7 +989,7 @@ updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) = AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2) updateReturn _ e = e -defuncApply :: Exp -> NE.NonEmpty (Maybe VName, Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) +defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) defuncApply f args appres loc = do (f', f_sv) <- defuncApplyFunction f (length args) case f_sv of @@ -1006,8 +1006,8 @@ defuncApply f args appres loc = do (argtypes, _) = unfoldFunType $ typeOf f fmap (first $ updateReturn appres) $ foldM (defuncApplyArg fname) (f', f_sv) $ - NE.zip args $ - NE.tails argtypes + NE.zip args . NE.tails . map snd $ + argtypes where intrinsicOrHole e' = do -- If the intrinsic is fully applied, then we are done. diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index b805ea777f..2c62dc2074 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -356,21 +356,6 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = -- Some functions are magical (overloaded) and we handle that here. case () of () - -- Short-circuiting operators are magical. - | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - baseString (qualLeaf qfname) == "&&", - [(x, _), (y, _)] <- args -> - internaliseExp desc $ - E.AppExp - (E.If x y (E.Literal (E.BoolValue False) mempty) mempty) - (Info $ AppRes (E.Scalar $ E.Prim E.Bool) []) - | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - baseString (qualLeaf qfname) == "||", - [(x, _), (y, _)] <- args -> - internaliseExp desc $ - E.AppExp - (E.If x (E.Literal (E.BoolValue True) mempty) y mempty) - (Info $ AppRes (E.Scalar $ E.Prim E.Bool) []) -- Overloaded and intrinsic functions never take array -- arguments (except equality, but those cannot be -- existential), so we can safely ignore the existential @@ -1499,7 +1484,7 @@ findFuncall (E.Apply f args _) | E.Hole (Info _) loc <- f = (FunctionHole loc, map onArg $ NE.toList args) where - onArg (Info argext, e) = (e, argext) + onArg (Info (argext, _), e) = (e, argext) findFuncall e = error $ "Invalid function expression in application:\n" ++ prettyString e @@ -1596,12 +1581,15 @@ isOverloadedFunction qname desc loc = do handle name | Just bop <- find ((name ==) . prettyString) [minBound .. maxBound :: E.BinOp] = Just $ \[(x_t, [x']), (y_t, [y'])] -> - case (x_t, y_t) of + case (arrayElem x_t, arrayElem y_t) of (E.Scalar (E.Prim t1), E.Scalar (E.Prim t2)) -> internaliseBinOp loc desc bop x' y' t1 t2 _ -> error "Futhark.Internalise.internaliseExp: non-primitive type in BinOp." handle _ = Nothing + arrayElem (E.Array _ _ t) = E.Scalar t + arrayElem t = t + -- | Handle intrinsic functions. These are only allowed to be called -- in the prelude, and their internalisation may involve inspecting -- the AST. @@ -1610,7 +1598,7 @@ isIntrinsicFunction :: [E.Exp] -> SrcLoc -> Maybe (String -> InternaliseM [SubExp]) -isIntrinsicFunction qname args loc = do +isIntrinsicFunction qname all_args loc = do guard $ baseTag (qualLeaf qname) <= maxIntrinsicTag let handlers = [ handleSign, @@ -1620,7 +1608,7 @@ isIntrinsicFunction qname args loc = do handleAD, handleRest ] - msum [h args $ baseString $ qualLeaf qname | h <- handlers] + msum [h all_args $ baseString $ qualLeaf qname | h <- handlers] where handleSign [x] "sign_i8" = Just $ toSigned I.Int8 x handleSign [x] "sign_i16" = Just $ toSigned I.Int16 x @@ -1651,12 +1639,29 @@ isIntrinsicFunction qname args loc = do fmap pure $ letSubExp desc $ I.BasicOp $ I.ConvOp conv x' handleOps _ _ = Nothing - handleSOACs [lam, arr] "map" = Just $ \desc -> do - arr' <- internaliseExpToVars "map_arr" arr - arr_ts <- mapM lookupType arr' - lam' <- internaliseLambdaCoerce lam $ map rowType arr_ts - let w = arraysSize 0 arr_ts - letTupExp' desc $ I.Op $ I.Screma w arr' (I.mapSOAC lam') + handleSOACs (lam : args) "map" = Just $ \desc -> do + arg_ses <- concat <$> mapM (internaliseExp "arg") args + arg_ts <- mapM subExpType arg_ses + let param_ts = map rowType arg_ts + map_dim = head $ I.shapeDims $ I.arrayShape $ head arg_ts + + arg_ses' <- + zipWithM + ( \p a -> + ensureShape "" mempty (arrayOfRow p map_dim) "" a + ) + param_ts + arg_ses + + args_v'' <- mapM (letExp "" . BasicOp . SubExp) arg_ses' + + lambda <- internaliseLambdaCoerce lam param_ts + + letTupExp' + desc + $ Op + $ Screma map_dim args_v'' + $ mapSOAC lambda handleSOACs [k, lam, arr] "partition" = do k' <- fromIntegral <$> fromInt32 k Just $ \_desc -> do diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 8d2871adfa..582ebef4d1 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -20,14 +20,19 @@ -- still needed in monomorphisation for now. module Futhark.Internalise.FullNormalise (transformProg) where +import Control.Monad import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor +import Data.List (zip4) import Data.List.NonEmpty qualified as NE import Data.Map qualified as M +import Data.Maybe import Data.Text qualified as T import Futhark.MonadFreshNames +import Futhark.Util.Pretty import Language.Futhark +import Language.Futhark.Primitive (intValue) import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Types @@ -212,13 +217,13 @@ getOrdering final (Lambda params body mte ret loc) = do nameExp final $ Lambda params body' mte ret loc getOrdering _ (OpSection qn ty loc) = pure $ Var qn ty loc -getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext), Info (yp, yty)) (Info (RetType dims ret), Info exts) loc) = do +getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext, _), Info (yp, yty)) (Info (RetType dims ret), Info exts) loc) = do x <- getOrdering False e yn <- newNameFromString "y" let y = Var (qualName yn) (Info $ toStruct yty) mempty ret' = applySubst (pSubst x y) ret body = - mkApply (Var op ty mempty) [(xext, x), (Nothing, y)] $ + mkApply (Var op ty mempty) [(xext, mempty, x), (Nothing, mempty, y)] $ AppRes (toStruct ret') exts nameExp final $ Lambda [Id yn (Info yty) mempty] body Nothing (Info (RetType dims ret')) loc where @@ -226,12 +231,12 @@ getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext), Info (yp, yty)) (I | Named p <- xp, p == vn = Just $ ExpSubst x | Named p <- yp, p == vn = Just $ ExpSubst y | otherwise = Nothing -getOrdering final (OpSectionRight op ty e (Info (xp, xty), Info (yp, _, yext)) (Info (RetType dims ret)) loc) = do +getOrdering final (OpSectionRight op ty e (Info (xp, xty), Info (yp, _, yext, _)) (Info (RetType dims ret)) loc) = do xn <- newNameFromString "x" y <- getOrdering False e let x = Var (qualName xn) (Info $ toStruct xty) mempty ret' = applySubst (pSubst x y) ret - body = mkApply (Var op ty mempty) [(Nothing, x), (yext, y)] $ AppRes (toStruct ret') [] + body = mkApply (Var op ty mempty) [(Nothing, mempty, x), (yext, mempty, y)] $ AppRes (toStruct ret') [] nameExp final $ Lambda [Id xn (Info xty) mempty] body Nothing (Info (RetType dims ret')) loc where pSubst x y vn @@ -307,7 +312,10 @@ getOrdering final (AppExp (Loop sizes pat einit form body loc) resT) = do While e -> While <$> transformBody e body' <- transformBody body nameExp final $ AppExp (Loop sizes pat (LoopInitExplicit einit') form' body' loc) resT -getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info elp) (er, Info erp) loc) (Info resT)) = do +getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info (elp, _)) (er, Info (erp, _)) loc) (Info resT)) = do + -- Rewrite short-circuiting boolean operators on scalars to explicit + -- if-then-else. Automapped cases are turned into applications of + -- intrinsic functions. expr' <- case (isOr, isAnd) of (True, _) -> do el' <- naming "or_lhs" $ getOrdering True el @@ -320,7 +328,7 @@ getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info elp) (er, Info erp) lo (False, False) -> do el' <- naming (prettyString op <> "_lhs") $ getOrdering False el er' <- naming (prettyString op <> "_rhs") $ getOrdering False er - pure $ mkApply (Var op opT oloc) [(elp, el'), (erp, er')] resT + pure $ mkApply (Var op opT oloc) [(elp, mempty, el'), (erp, mempty, er')] resT nameExp final expr' where isOr = baseName (qualLeaf op) == "||" @@ -353,7 +361,7 @@ getOrdering final (AppExp (Match expr cs loc) resT) = do -- a complete separtion of states. transformBody :: (MonadFreshNames m) => Exp -> m Exp transformBody e = do - (e', pre_eval) <- runOrdering (getOrdering True e) + (e', pre_eval) <- runOrdering $ getOrdering True e pure $ foldl f e' pre_eval where appRes = case e of @@ -367,9 +375,325 @@ transformBody e = do transformValBind :: (MonadFreshNames m) => ValBind -> m ValBind transformValBind valbind = do - body' <- transformBody $ valBindBody valbind + body' <- transformBody <=< expandAMAnnotations $ valBindBody valbind pure $ valbind {valBindBody = body'} -- | Fully normalise top level bindings. transformProg :: (MonadFreshNames m) => [ValBind] -> m [ValBind] transformProg = mapM transformValBind + +--- | Expansion of 'AutoMap'-annotated applications. +--- +--- Each application @f x@ has an annotation with @AutoMap R M F@ where +--- @R, M, F@ are the autorep, automap, and frame shapes, +--- respectively. +--- +--- The application @f x@ will have type @F t@ for some @t@, i.e. @(f +--- x) : F t@. The frame @F@ is a prefix of the type of @f x@; namely +--- it is the total accumulated shape that is due to implicit maps. +--- Another way of thinking about that is that @|F|@ is is the level +--- of the automap-nest that @f x@ is in. For example, if @|F| = 2@ +--- then we know that @f x@ implicitly stands for +--- +--- > map (\x' -> map (\x'' -> f x'') x') x +--- +--- For an application with a non-empty autorep annotation, the frame +--- tells about how many dimensions of the replicate can be eliminated. +--- For example, @[[1,2],[3,4]] + 5@ will yield the following annotations: +--- +--- > ([[1,2],[3,4]] +) -- AutoMap {R = mempty, M = [2][2], F = [2][2]} +--- > (([[1,2],[3,4]] +) 5) -- AutoMap {R = [2][2], M = mempty, F = [2][2]} +--- +--- All replicated arguments are pushed down the auto-map nest. Each +--- time a replicated argument is pushed down a level of an +--- automap-nest, one fewer replicates is needed (i.e., the outermost +--- dimension of @R@ can be dropped). Replicated arguments are pushed +--- down the nest until either 1) the bottom of the nest is encountered +--- or 2) no replicate dimensions remain. For example, in the second +--- application above @R@ = @F@, so we can push the replicated argument +--- down two levels. Since each level effectively removes a dimension +--- of the replicate, no replicates will be required: +--- +--- > map (\xs -> map (\x -> f x'' 5) xs) [[1,2],[3,4]] +--- +--- The number of replicates that are actually required is given by +--- max(|R| - |F|, 0). +--- +--- An expression's "true level" is the level at which that expression +--- will appear in the automap-nest. The bottom of a mapnest is level 0. +--- +--- * For annotations with @R = mempty@, the true level is @|F|@. +--- * For annotations with @M = mempty@, the true level is @|F| - |R|@. +--- +--- If @|R| > |F|@ then actual replicates (namely @|R| - |F|@ of them) +--- will be required at the bottom of the mapnest. +--- +--- Note that replicates can only appear at the bottom of a mapnest; any +--- expression of the form +--- +--- > map (\ls x' rs -> e) (replicate x) +--- +--- can always be written as +--- +--- > map (\ls rs -> e[x' -> x]) +--- +--- Let's look at another example. Consider (with exact sizes omitted for brevity) +--- +--- > f : a -> a -> a -> []a -> [][][]a -> a +--- > xss : [][]a +--- > ys : []a +--- > zsss : [][][]a +--- > w : a +--- > vss : [][]a +--- +--- and the application +--- +--- > f xss ys zsss w vss +--- +--- which will have the following annotations +--- +--- > (f xss) -- AutoMap {R = mempty, M = [][], F = [][]} (1) +--- > ((f xss) ys) -- AutoMap {R = [], M = mempty, F = [][]} (2) +--- > (((f xss) ys) zsss) -- AutoMap {R = mempty, M = [], F = [][][]} (3) +--- > ((((f xss) ys) zsss) w) -- AutoMap {R = [][][][], M = mempty, F = [][][]} (4) +--- > (((((f xss) ys) zsss) w) vss) -- AutoMap {R = [], M = mempty, F = [][][]} (5) +--- +--- This will yield the following mapnest. +--- +--- > map (\zss -> +--- > map (\xs zs vs -> +--- > map (\x y z v -> f x y z (replicate w) v) xs ys zs v) xss zss vss) zsss +--- +--- Let's see how we'd construct this mapnest from the annotations. We construct +--- the nest bottom-up. We have: +--- +--- Application | True level +--- --------------------------- +--- (1) | |[][]| = 2 +--- (2) | |[][]| - |[]| = 1 +--- (3) | |[][][]| = 3 +--- (4) | |[][][]| - |[][][][]| = -1 +--- (5) | |[][][]| - |[]| = 2 +--- +--- We start at level 0. +--- * Any argument with a negative true level of @-n@ will be replicated @n@ times; +--- the exact shapes can be found by removing the @F@ postfix from @R@, +--- i.e. @R = shapes_to_rep_by <> F@. +--- * Any argument with a 0 true level will be included. +--- * For any argument @arg@ with a positive true level, we construct a new parameter +--- whose type is @arg@ with the leading @n@ dimensions (where @n@ is the true level) +--- removed. +--- +--- Following the rules above, @w@ will be replicated once. For the remaining arguments, +--- we create new parameters @x : a, y : a, z : a , v : a@. Hence, level 0 becomes +--- +--- > f x y z (replicate w) v +--- +--- At level l > 0: +--- * There are no replicates. +--- * Any argument with l true level will be included verbatim. +--- * Any argument with true level > l will have a new parameter constructed for it, +--- whose type has the leading @n - l@ dimensions (where @n@ is the true level) removed. +--- * We surround the previous level with a map that binds that levels' new parameters +--- and is passed the current levels' arguments. +--- +--- Following the above recipe for level 1, we create parameters +--- @xs : []a, zs : []a, vs :[]a@ and obtain +--- +--- > map (\x y z v -> f x y z (replicate w) v) xs ys zs vs +--- +--- This process continues until the level is greater than the maximum +--- true level of any application, at which we terminate. + +-- | Expands 'AutoMap' annotations into explicit @map@s and @replicates@. +expandAMAnnotations :: (MonadFreshNames m) => Exp -> m Exp +expandAMAnnotations e = + case e of + (AppExp (Apply f args _) (Info res)) + | ((exts, ams), arg_es) <- + first unzip $ unzip $ map (first unInfo) $ NE.toList args, + any (/= mempty) ams -> do + f' <- expandAMAnnotations f + arg_es' <- mapM expandAMAnnotations arg_es + let diets = funDiets $ typeOf f + withMapNest (zip4 exts ams arg_es' diets) $ \args' -> do + let rettype = + case unfoldFunTypeWithRet $ typeOf f' of + Nothing -> error "Function type expected." + Just (ptypes, f_ret) -> + let parsubsts = mapMaybe parSub $ zip ptypes args' + in applySubst (`lookup` parsubsts) $ + foldFunType (drop (length args') $ map snd ptypes) f_ret + when (appResExt res /= []) $ + error "expandAMAnnotations: cannot handle existential yet." + pure $ + mkApply f' (zip3 exts (repeat mempty) args') $ + res {appResType = rettype} + (AppExp (BinOp op (Info t) (x, Info (xext, xam)) (y, Info (yext, yam)) loc) (Info res)) -> do + x' <- expandAMAnnotations x + y' <- expandAMAnnotations y + withMapNest [(xext, xam, x', Observe), (yext, yam, y', Observe)] $ \[x'', y''] -> + pure $ + AppExp + ( BinOp + op + (Info t) + (x'', Info (xext, mempty)) + (y'', Info (yext, mempty)) + loc + ) + (Info res {appResType = stripArray (shapeRank $ autoFrame yam) (appResType res)}) + _ -> astMap identityMapper {mapOnExp = expandAMAnnotations} e + where + parSub ((Named v, Scalar (Prim (Signed Int64))), arg) = + Just (v, ExpSubst arg) + parSub _ = Nothing + + funDiets :: TypeBase dim as -> [Diet] + funDiets (Scalar (Arrow _ _ d _ (RetType _ t2))) = d : funDiets t2 + funDiets _ = [] + +type Level = Int + +newtype AutoMapArg = AutoMapArg + { amArg :: Exp + } + deriving (Show) + +data AutoMapParam = AutoMapParam + { amParam :: Pat ParamType, + amMapDim :: Size, + amDiet :: Diet + } + deriving (Show) + +-- | Builds a map-nest based on the 'AutoMap' annotations. +withMapNest :: + forall m. + (MonadFreshNames m) => + [(Maybe VName, AutoMap, Exp, Diet)] -> + ([Exp] -> m Exp) -> + m Exp +withMapNest nest_args f = do + (param_map, arg_map) <- + bimap combineMaps combineMaps . unzip <$> mapM buildArgMap nest_args + buildMapNest param_map arg_map $ maximum $ M.keys arg_map + where + combineMaps :: (Ord k) => [M.Map k v] -> M.Map k [v] + combineMaps = M.unionsWith (<>) . (fmap . fmap) pure + + buildMapNest :: + M.Map Level [AutoMapParam] -> + M.Map Level [AutoMapArg] -> + Level -> + m Exp + buildMapNest _ arg_map 0 = + f $ map amArg $ arg_map M.! 0 + buildMapNest param_map arg_map l = + case map amMapDim $ param_map M.! l of + [] -> error "Malformed param map." + (map_dim : _) -> do + let params = map (\p -> (amDiet p, amParam p)) $ param_map M.! l + args = map amArg $ arg_map M.! l + body <- buildMapNest param_map arg_map (l - 1) + pure $ + mkMap params body args $ + RetType [] $ + arrayOfWithAliases Nonunique (Shape [map_dim]) (typeOf body) + + buildArgMap :: + (Maybe VName, AutoMap, Exp, Diet) -> + m (M.Map Level AutoMapParam, M.Map Level AutoMapArg) + buildArgMap (_ext, am, arg, arg_diet) = + foldM mkArgsAndParams mempty $ reverse [0 .. trueLevel am] + where + mkArgsAndParams (p_map, a_map) l + | l == 0 = do + let arg' = maybe arg (paramToExp . amParam) (p_map M.!? 1) + rarg <- mkReplicateShape (autoRep am `shapePrefix` autoFrame am) arg' + pure (p_map, M.insert 0 (AutoMapArg rarg) a_map) + | l == trueLevel am = do + p <- mkAMParam (typeOf arg) l + let d = outerDim am l + pure + ( M.insert l (AutoMapParam p d arg_diet) p_map, + M.insert l (AutoMapArg arg) a_map + ) + | l < trueLevel am && l > 0 = do + p <- mkAMParam (typeOf arg) l + let d = outerDim am l + let arg' = + paramToExp $ + amParam $ + p_map M.! (l + 1) + pure + ( M.insert l (AutoMapParam p d arg_diet) p_map, + M.insert l (AutoMapArg arg') a_map + ) + | otherwise = error "Impossible." + + mkAMParam t level = + mkParam ("p_" <> show level) $ argType (level - 1) am t + + trueLevel :: AutoMap -> Int + trueLevel am + | autoMap am == mempty = + max 0 $ shapeRank (autoFrame am) - shapeRank (autoRep am) + | otherwise = + shapeRank $ autoFrame am + + outerDim :: AutoMap -> Int -> Size + outerDim am level = + (!! (trueLevel am - level)) $ shapeDims $ autoFrame am + + argType level am = stripArray (trueLevel am - level) + +mkParam :: (MonadFreshNames m) => String -> TypeBase Size u -> m (Pat ParamType) +mkParam desc t = do + x <- newVName desc + pure $ Id x (Info $ toParam Observe t) mempty + +mkReplicateShape :: (MonadFreshNames m) => Shape Size -> Exp -> m Exp +mkReplicateShape s e = foldM (flip mkReplicate) e s + +mkReplicate :: (MonadFreshNames m) => Exp -> Exp -> m Exp +mkReplicate dim e = do + x <- mkParam "x" (Scalar $ Prim $ Unsigned Int64) + pure $ + mkMap [(Observe, x)] e [xs] $ + RetType mempty (arrayOfWithAliases Unique (Shape [dim]) (typeOf e)) + where + xs = + AppExp + ( Range + (Literal (UnsignedValue $ intValue Int64 (0 :: Int)) mempty) + Nothing + (UpToExclusive dim) + mempty + ) + ( Info $ AppRes (arrayOf (Shape [dim]) (Scalar $ Prim $ Unsigned Int64)) [] + ) + +mkMap :: [(Diet, Pat ParamType)] -> Exp -> [Exp] -> ResRetType -> Exp +mkMap params body arrs rettype = + mkApply mapN args (AppRes (toStruct $ retType rettype) []) + where + args = map (Nothing,mempty,) $ lambda : arrs + mapt = foldFunType (zipWith toParam (Observe : map fst params) (typeOf lambda : map typeOf arrs)) rettype + mapN = Var (QualName [] $ VName "map" 0) (Info mapt) mempty + lambda = + Lambda + (map snd params) + body + Nothing + ( Info $ + RetType + (retDims rettype) + (typeOf body `setUniqueness` uniqueness (retType rettype)) + ) + mempty + +paramToExp :: Pat ParamType -> Exp +paramToExp (Id vn (Info t) loc) = + Var (QualName [] vn) (Info $ toStruct t) loc +paramToExp p = error $ prettyString p diff --git a/src/Futhark/Internalise/LiftLambdas.hs b/src/Futhark/Internalise/LiftLambdas.hs index 68532bd9e6..335db53fd3 100644 --- a/src/Futhark/Internalise/LiftLambdas.hs +++ b/src/Futhark/Internalise/LiftLambdas.hs @@ -138,7 +138,7 @@ liftFunction fname tparams params (RetType dims ret) funbody = do apply _ f [] = f apply orig_type f (p : rem_ps) = let inner_ret = AppRes (augType rem_ps orig_type) mempty - inner = mkApply f [(Nothing, freeVar p)] inner_ret + inner = mkApply f [(Nothing, mempty, freeVar p)] inner_ret in apply orig_type inner rem_ps transformSubExps :: ASTMapper LiftM diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs index 3058c93551..b67d2fe354 100644 --- a/src/Futhark/Internalise/Monomorphise.hs +++ b/src/Futhark/Internalise/Monomorphise.hs @@ -119,10 +119,10 @@ entryAssert (x : xs) body = andop = Var (qualName (intrinsicVar "&&")) (Info opt) mempty eqop = Var (qualName (intrinsicVar "==")) (Info opt) mempty logAnd x' y = - mkApply andop [(Nothing, x'), (Nothing, y)] $ + mkApply andop [(Nothing, mempty, x'), (Nothing, mempty, y)] $ AppRes bool [] cmpExp (ReplacedExp x', y) = - mkApply eqop [(Nothing, x'), (Nothing, y')] $ + mkApply eqop [(Nothing, mempty, x'), (Nothing, mempty, y')] $ AppRes bool [] where y' = Var (qualName y) (Info i64) mempty @@ -398,7 +398,7 @@ transformFName loc fname ft = do ( i - 1, mkApply f - [(Nothing, size_arg)] + [(Nothing, mempty, size_arg)] (AppRes (foldFunType (replicate i i64) (RetType [] t)) []) ) @@ -500,7 +500,7 @@ transformAppExp (Apply fe args _) res = <*> mapM onArg (NE.toList args) <*> transformAppRes res where - onArg (Info ext, e) = (ext,) <$> transformExp e + onArg (Info (ext, am), e) = (ext,am,) <$> transformExp e transformAppExp (Loop sparams pat loopinit form body loc) res = do e1' <- transformExp $ loopInitExp loopinit @@ -529,7 +529,7 @@ transformAppExp (Loop sparams pat loopinit form body loc) res = do (pat_sizes, pat'') <- sizesForPat pat' res' <- transformAppRes res pure $ AppExp (Loop (sparams' ++ pat_sizes) pat'' (LoopInitExplicit e1') form' body' loc) (Info res') -transformAppExp (BinOp (fname, _) (Info t) (e1, d1) (e2, d2) loc) res = do +transformAppExp (BinOp (fname, _) (Info t) (e1, Info (d1, _)) (e2, Info (d2, _)) loc) res = do (AppRes ret ext) <- transformAppRes res fname' <- transformFName loc fname (toStruct t) e1' <- transformExp e1 @@ -564,8 +564,8 @@ transformAppExp (BinOp (fname, _) (Info t) (e1, d1) (e2, d2) loc) res = do where applyOp ret ext fname' x y = mkApply - (mkApply fname' [(unInfo d1, x)] (AppRes ret mempty)) - [(unInfo d2, y)] + (mkApply fname' [(d1, mempty, x)] (AppRes ret mempty)) + [(d2, mempty, y)] (AppRes ret ext) makeVarParam arg = do @@ -651,7 +651,7 @@ transformExp (Lambda {}) = transformExp (OpSection qn t loc) = transformExp $ Var qn t loc transformExp (OpSectionLeft fname (Info t) e arg (Info rettype, Info retext) loc) = do - let (Info (xp, xtype, xargext), Info (yp, ytype)) = arg + let (Info (xp, xtype, xargext, _), Info (yp, ytype)) = arg e' <- transformExp e desugarBinOpSection fname @@ -663,7 +663,7 @@ transformExp (OpSectionLeft fname (Info t) e arg (Info rettype, Info retext) loc (rettype, retext) loc transformExp (OpSectionRight fname (Info t) e arg (Info rettype) loc) = do - let (Info (xp, xtype), Info (yp, ytype, yargext)) = arg + let (Info (xp, xtype), Info (yp, ytype, yargext, _)) = arg e' <- transformExp e desugarBinOpSection fname @@ -735,7 +735,7 @@ desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) ( let apply_left = mkApply op - [(xext, e1)] + [(xext, mempty, e1)] (AppRes (Scalar $ Arrow mempty yp (diet ytype) (toStruct ytype) (RetType [] $ toRes Nonunique t')) []) onDim (Var d typ _) | Named p <- xp, qualLeaf d == p = Var (qualName v1) typ loc @@ -744,7 +744,7 @@ desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) ( rettype' = first onDim rettype body <- scoping (S.fromList [v1, v2]) $ - mkApply apply_left [(yext, e2)] + mkApply apply_left [(yext, mempty, e2)] <$> transformAppRes (AppRes (toStruct rettype') retext) rettype'' <- transformRetTypeSizes (S.fromList [v1, v2]) $ RetType dims rettype' pure . wrap_left . wrap_right $ diff --git a/src/Futhark/Solve/BranchAndBound.hs b/src/Futhark/Solve/BranchAndBound.hs new file mode 100644 index 0000000000..258757113b --- /dev/null +++ b/src/Futhark/Solve/BranchAndBound.hs @@ -0,0 +1,74 @@ +module Futhark.Solve.BranchAndBound (branchAndBound) where + +import Data.Map qualified as M +import Data.Maybe +import Data.Set qualified as S +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V +import Futhark.Solve.LP (LP (..)) +import Futhark.Solve.Matrix +import Futhark.Solve.Simplex + +newtype Bound a = Bound (Maybe a, Maybe a) + deriving (Eq, Ord, Show) + +instance (Ord a) => Semigroup (Bound a) where + Bound (mlb1, mub1) <> Bound (mlb2, mub2) = + Bound (combine max mlb1 mlb2, combine min mub1 mub2) + where + combine _ Nothing b2 = b2 + combine _ b1 Nothing = b1 + combine c (Just b1) (Just b2) = Just $ c b1 b2 + +-- | Solves an LP with the additional constraint that all solutions +-- must be integral. Returns 'Nothing' if infeasible or unbounded. +branchAndBound :: + (Read a, Unbox a, RealFrac a, Show a) => + LP a -> + Maybe (a, Vector Int) +branchAndBound prob@(LP _ a d) = (zopt,) <$> mopt + where + (zopt, mopt) = step (S.singleton mempty) (negate $ read "Infinity") Nothing + step todo zlow opt + | S.null todo = (zlow, opt) + | otherwise = + let (next, rest) = S.deleteFindMin todo + in case simplexLP (mkProblem next) of + Nothing -> step rest zlow opt + Just (z, sol) + | z <= zlow -> step rest zlow opt + | V.all isInt sol -> + step rest z (Just $ V.map round sol) + | otherwise -> + let (idx, frac) = + V.head $ V.filter (not . isInt . snd) $ V.zip (V.generate (V.length sol) id) sol + new_todo = + S.fromList $ + filter + (/= next) + [ M.insertWith (<>) idx (Bound (Nothing, Just $ fromInteger $ floor frac)) next, + M.insertWith (<>) idx (Bound (Just $ fromInteger $ ceiling frac, Nothing)) next + ] + in step (new_todo <> rest) zlow opt + + -- TODO: use isInt x = x == round x + -- requires a better 'rowEchelon' implementation for matrices + isInt x = abs (fromIntegral (round x :: Int) - x) <= 10 ^^ ((-10) :: Int) + mkProblem = + M.foldrWithKey + ( \idx bound acc -> addBound acc idx bound + ) + prob + + addBound lp idx (Bound (mlb, mub)) = + lp + { lpA = a `addRows` new_rows, + lpd = d V.++ V.fromList new_ds + } + where + (new_rows, new_ds) = + unzip $ + catMaybes + [ (V.generate (ncols a) (\i -> if i == idx then (-1) else 0),) <$> (negate <$> mlb), + (V.generate (ncols a) (\i -> if i == idx then 1 else 0),) <$> mub + ] diff --git a/src/Futhark/Solve/GLPK.hs b/src/Futhark/Solve/GLPK.hs new file mode 100644 index 0000000000..5c8f40fcd8 --- /dev/null +++ b/src/Futhark/Solve/GLPK.hs @@ -0,0 +1,60 @@ +module Futhark.Solve.GLPK (glpk) where + +import Control.Monad +import Data.Bifunctor +import Data.LinearProgram +import Data.Map qualified as M +import Data.Maybe +import Data.Set qualified as S +import Futhark.Solve.LP qualified as F +import System.IO.Silently + +linearProgToGLPK :: (Ord v, Num a) => F.LinearProg v a -> LP v a +linearProgToGLPK prog = + LP + { direction = cOptType $ F.optType prog, + objective = cObj $ F.objective prog, + constraints = map cConstraint $ F.constraints prog, + varBounds = bounds, + varTypes = kinds + } + where + cOptType F.Maximize = Max + cOptType F.Minimize = Min + cObj = fst . cLSum + + cLSum (F.LSum m) = + ( M.mapKeys fromJust $ M.filterWithKey (\k _ -> isJust k) m, + fromMaybe 0 (m M.!? Nothing) + ) + + cConstraint (F.Constraint ctype l r) = + let (linfunc, c) = cLSum $ l F.~-~ r + bound = + case ctype of + F.Equal -> Equ (-c) + F.LessEq -> UBound (-c) + in Constr Nothing linfunc bound + + bounds = M.fromList $ (,LBound 0) <$> varList + kinds = M.fromList $ (,IntVar) <$> varList + + varList = S.toList $ F.vars prog + +glpk :: (Ord v, Real a) => F.LinearProg v a -> IO (Maybe (Int, M.Map v Int)) +glpk lp = do + (output, res) <- capture $ glpk' lp + pure $ do + guard $ "PROBLEM HAS NO INTEGER FEASIBLE SOLUTION" `notElem` lines output + res + +glpk' :: (Ord v, Real a) => F.LinearProg v a -> IO (Maybe (Int, M.Map v Int)) +glpk' lp + | F.isConstant (F.objective lp) -- FIXME + = + pure $ pure (0, M.fromList $ map (,0) $ S.toList $ F.vars lp) + | otherwise = do + (_, mres) <- glpSolveVars opts $ linearProgToGLPK lp + pure $ bimap truncate (fmap truncate) <$> mres + where + opts = mipDefaults {msgLev = MsgAll} diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs new file mode 100644 index 0000000000..5011ece9fb --- /dev/null +++ b/src/Futhark/Solve/LP.hs @@ -0,0 +1,336 @@ +module Futhark.Solve.LP + ( LP (..), + LPE (..), + convert, + normalize, + var, + constant, + cval, + bin, + or, + min, + max, + oneIsZero, + (~+~), + (~-~), + (~*~), + (!), + neg, + linearProgToLP, + linearProgToLPE, + LSum (..), + LinearProg (..), + OptType (..), + Constraint (..), + Vars (..), + CType (..), + (~==~), + (~<=~), + (~>=~), + rowEchelonLPE, + isConstant, + ) +where + +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe +import Data.Set (Set) +import Data.Set qualified as S +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V +import Futhark.Solve.Matrix (Matrix (..)) +import Futhark.Solve.Matrix qualified as Matrix +import Futhark.Util.Pretty +import Language.Futhark.Pretty +import Prelude hiding (max, min, or) + +-- | A linear program. 'LP c a d' represents the program +-- +-- > maximize c^T * a +-- > subject to a * x <= d +-- > x >= 0 +-- +-- The matrix 'a' is assumed to have linearly-independent rows. +data LP a = LP + { lpc :: Vector a, + lpA :: Matrix a, + lpd :: Vector a + } + deriving (Eq, Show) + +-- | Equational form of a linear program. 'LPE c a d' represents the +-- program +-- +-- > maximize c^T * a +-- > subject to a * x = d +-- > x >= 0 +-- +-- The matrix 'a' is assumed to have linearly-independent rows. +data LPE a = LPE + { pc :: Vector a, + pA :: Matrix a, + pd :: Vector a + } + deriving (Eq, Show) + +rowEchelonLPE :: (Unbox a, Fractional a, Ord a) => LPE a -> LPE a +rowEchelonLPE (LPE c a d) = + LPE c (Matrix.sliceCols (V.generate (ncols a) id) ad) (Matrix.getCol (ncols a) ad) + where + ad = + Matrix.filterRows + (V.any (Prelude./= 0)) + (Matrix.rowEchelon $ a Matrix.<|> Matrix.fromColVector d) + +-- | Converts an 'LP' into an equivalent 'LPE' by introducing slack +-- variables. +convert :: (Num a, Unbox a) => LP a -> LPE a +convert (LP c a d) = LPE c' a' d + where + a' = a Matrix.<|> Matrix.diagonal (V.replicate (Matrix.nrows a) 1) + c' = c V.++ V.replicate (Matrix.nrows a) 0 + +-- | Linear sum of variables. +newtype LSum v a = LSum {lsum :: Map (Maybe v) a} + deriving (Show, Eq) + +instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where + pretty (LSum m) = + concatWith (surround " + ") + $ map + ( \(k, a) -> + case k of + Nothing -> pretty a + Just k' -> (if a == 1 then mempty else pretty a <> "*") <> prettyName k' + ) + $ M.toList m + +isConstant :: (Ord v) => LSum v a -> Bool +isConstant (LSum m) = M.keysSet m `S.isSubsetOf` S.singleton Nothing + +instance Functor (LSum v) where + fmap f (LSum m) = LSum $ fmap f m + +class Vars a v where + vars :: a -> Set v + +instance (Ord v) => Vars (LSum v a) v where + vars = S.fromList . catMaybes . M.keys . lsum + +-- | Type of constraint +data CType = Equal | LessEq + deriving (Show, Eq) + +instance Pretty CType where + pretty Equal = "==" + pretty LessEq = "<=" + +-- | A constraint for a linear program. +data Constraint v a + = Constraint CType (LSum v a) (LSum v a) + deriving (Show, Eq) + +instance (IsName v, Pretty a, Eq a, Num a) => Pretty (Constraint v a) where + pretty (Constraint t l r) = + pretty l <+> pretty t <+> pretty r + +instance (Ord v) => Vars (Constraint v a) v where + vars (Constraint _ l r) = vars l <> vars r + +data OptType = Maximize | Minimize + deriving (Show, Eq) + +instance Pretty OptType where + pretty Maximize = "maximize" + pretty Minimize = "minimize" + +-- | A linear program. +data LinearProg v a = LinearProg + { optType :: OptType, + objective :: LSum v a, + constraints :: [Constraint v a] + } + deriving (Show, Eq) + +instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LinearProg v a) where + pretty (LinearProg opt obj cs) = + vcat + [ pretty opt, + indent 2 $ pretty obj, + "subject to", + indent 2 $ vcat $ map pretty cs + ] + +instance (Ord v) => Vars (LinearProg v a) v where + vars lp = + vars (objective lp) + <> foldMap vars (constraints lp) + +bigM :: (Num a) => a +bigM = 2 ^ (10 :: Int) + +-- max{x, y} = z +max :: (Num a, Ord v) => v -> LSum v a -> LSum v a -> LSum v a -> [Constraint v a] +max b x y z = + [ z ~>=~ x, + z ~>=~ y, + z ~<=~ x ~+~ bigM ~*~ var b, + z ~<=~ y ~+~ bigM ~*~ (constant 1 ~-~ var b) + ] + +-- min{x, y} = z +min :: (Num a, Ord v) => v -> v -> v -> v -> [Constraint v a] +min b x y z = + [ var z ~<=~ var x, + var z ~<=~ var y, + var z ~>=~ var x ~-~ bigM ~*~ (constant 1 ~-~ var b), + var z ~>=~ var y ~-~ bigM ~*~ var b + ] + +oneIsZero :: (Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] +oneIsZero (b1, x1) (b2, x2) = + mkC b1 x1 + <> mkC b2 x2 + <> [(var b1 ~+~ var b2) ~<=~ constant 1] + where + mkC b x = + [ var x ~<=~ bigM ~*~ var b + ] + +or :: (Num a, Ord v) => v -> v -> Constraint v a -> Constraint v a -> [Constraint v a] +or b1 b2 c1 c2 = + mkC b1 c1 + <> mkC b2 c2 + <> [var b1 ~+~ var b2 ~<=~ constant 1] + where + mkC b (Constraint Equal l r) = + [ l ~<=~ r ~+~ bigM ~*~ (constant 1 ~-~ var b), + l ~>=~ r ~-~ bigM ~*~ (constant 1 ~-~ var b) + ] + mkC b (Constraint LessEq l r) = + [ l ~<=~ r ~+~ bigM ~*~ (constant 1 ~-~ var b) + ] + +bin :: (Num a) => v -> Constraint v a +bin v = Constraint LessEq (var v) (constant 1) + +(~==~) :: LSum v a -> LSum v a -> Constraint v a +l ~==~ r = Constraint Equal l r + +infix 4 ~==~ + +(~<=~) :: LSum v a -> LSum v a -> Constraint v a +l ~<=~ r = Constraint LessEq l r + +infix 4 ~<=~ + +(~>=~) :: (Num a) => LSum v a -> LSum v a -> Constraint v a +l ~>=~ r = Constraint LessEq (neg l) (neg r) + +infix 4 ~>=~ + +normalize :: (Eq a, Num a) => LSum v a -> LSum v a +normalize = LSum . M.filter (/= 0) . lsum + +var :: (Num a) => v -> LSum v a +var v = LSum $ M.singleton (Just v) 1 + +constant :: a -> LSum v a +constant = LSum . M.singleton Nothing + +cval :: (Num a, Ord v) => LSum v a -> a +cval = (! Nothing) + +(~+~) :: (Ord v, Num a) => LSum v a -> LSum v a -> LSum v a +(LSum x) ~+~ (LSum y) = LSum $ M.unionWith (+) x y + +infixl 6 ~+~ + +(~-~) :: (Ord v, Num a) => LSum v a -> LSum v a -> LSum v a +x ~-~ y = x ~+~ neg y + +infixl 6 ~-~ + +(~*~) :: (Num a) => a -> LSum v a -> LSum v a +a ~*~ s = fmap (a *) s + +infixl 7 ~*~ + +(!) :: (Num a, Ord v) => LSum v a -> Maybe v -> a +(LSum m) ! v = fromMaybe 0 (m M.!? v) + +neg :: (Num a) => LSum v a -> LSum v a +neg (LSum x) = LSum $ fmap negate x + +-- | Converts a linear program given with a list of constraints +-- into the standard form. +linearProgToLP :: + forall v a. + (Unbox a, Num a, Ord v) => + LinearProg v a -> + (LP a, Map Int v) +linearProgToLP (LinearProg otype obj cs) = + let c = mkRow $ convertObj otype obj + a = Matrix.fromVectors $ map (mkRow . fst) cs' + d = V.fromList $ map snd cs' + in (LP c a d, idxMap) + where + cs' = foldMap (convertEqCType . splitConstraint) cs + idxMap = + M.fromList $ + zip [0 ..] $ + catMaybes $ + M.keys $ + mconcat $ + map (lsum . fst) cs' + mkRow s = V.generate (M.size idxMap) $ \i -> s ! Just (idxMap M.! i) + + convertEqCType :: (CType, LSum v a, a) -> [(LSum v a, a)] + convertEqCType (Equal, s, a) = [(s, a), (neg s, negate a)] + convertEqCType (LessEq, s, a) = [(s, a)] + + splitConstraint :: Constraint v a -> (CType, LSum v a, a) + splitConstraint (Constraint ctype l r) = + let c = negate $ cval (l ~-~ r) + in (ctype, l ~-~ r ~-~ constant c, c) + + convertObj :: OptType -> LSum v a -> LSum v a + convertObj Maximize s = s + convertObj Minimize s = neg s + +-- | Converts a linear program given with a list of constraints +-- into the equational form. Assumes no <= constraints. +linearProgToLPE :: + forall v a. + (Unbox a, Num a, Ord v) => + LinearProg v a -> + (LPE a, Map Int v) +linearProgToLPE (LinearProg otype obj cs) = + let c = mkRow $ convertObj otype obj + a = Matrix.fromVectors $ map (mkRow . fst) cs' + d = V.fromList $ map snd cs' + in (LPE c a d, idxMap) + where + cs' = map (checkOnlyEqType . splitConstraint) cs + idxMap = + M.fromList $ + zip [0 ..] $ + catMaybes $ + M.keys $ + mconcat $ + map (lsum . fst) cs' + mkRow s = V.generate (M.size idxMap) $ \i -> s ! Just (idxMap M.! i) + + splitConstraint :: Constraint v a -> (CType, LSum v a, a) + splitConstraint (Constraint ctype l r) = + let c = negate $ cval (l ~-~ r) + in (ctype, l ~-~ r ~-~ constant c, c) + + checkOnlyEqType :: (CType, LSum v a, a) -> (LSum v a, a) + checkOnlyEqType (Equal, s, a) = (s, a) + checkOnlyEqType (ctype, _, _) = error $ show ctype + + convertObj :: OptType -> LSum v a -> LSum v a + convertObj Maximize s = s + convertObj Minimize s = neg s diff --git a/src/Futhark/Solve/Matrix.hs b/src/Futhark/Solve/Matrix.hs new file mode 100644 index 0000000000..39ec16a39e --- /dev/null +++ b/src/Futhark/Solve/Matrix.hs @@ -0,0 +1,330 @@ +module Futhark.Solve.Matrix + ( Matrix (..), + toList, + toLists, + fromRowVector, + fromColVector, + fromVectors, + fromLists, + (@), + (!), + sliceCols, + getColM, + getCol, + setCol, + sliceRows, + getRowM, + getRow, + (<|>), + (<->), + addRow, + addRows, + imap, + generate, + identity, + diagonal, + (<.>), + (.*), + (*.), + (.+.), + (.-.), + rowEchelon, + filterRows, + deleteRow, + deleteCol, + ) +where + +import Data.List qualified as L +import Data.Map qualified as M +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V + +-- A matrix represented as a 1D 'Vector'. +data Matrix a = Matrix + { elems :: Vector a, + nrows :: Int, + ncols :: Int + } + deriving (Eq) + +instance (Show a, Unbox a) => Show (Matrix a) where + show = + unlines . map show . toLists + +toList :: (Unbox a) => Matrix a -> [Vector a] +toList m = + map (\r -> V.slice (r * ncols m) (ncols m) (elems m)) [0 .. nrows m - 1] + +toLists :: (Unbox a) => Matrix a -> [[a]] +toLists m = + map (\r -> V.toList $ V.slice (r * ncols m) (ncols m) (elems m)) [0 .. nrows m - 1] + +fromRowVector :: (Unbox a) => Vector a -> Matrix a +fromRowVector v = + Matrix + { elems = v, + nrows = 1, + ncols = V.length v + } + +fromColVector :: (Unbox a) => Vector a -> Matrix a +fromColVector v = + Matrix + { elems = v, + nrows = V.length v, + ncols = 1 + } + +empty :: (Unbox a) => Matrix a +empty = Matrix mempty 0 0 + +fromVectors :: (Unbox a) => [Vector a] -> Matrix a +fromVectors [] = empty +fromVectors vs = + Matrix + { elems = V.concat vs, + nrows = length vs, + ncols = V.length $ head vs + } + +fromLists :: (Unbox a) => [[a]] -> Matrix a +fromLists xss = + Matrix + { elems = V.concat $ map V.fromList xss, + nrows = length xss, + ncols = length $ head xss + } + +class SelectCols a where + select :: Vector Int -> a -> a + (@) :: a -> Vector Int -> a + (@) = flip select + +infix 9 @ + +instance (Unbox a) => SelectCols (Vector a) where + select s v = V.map (v V.!) s + +instance (Unbox a) => SelectCols (Matrix a) where + select = sliceCols + +(!) :: (Unbox a) => Matrix a -> (Int, Int) -> a +m ! (r, c) = elems m V.! (ncols m * r + c) + +sliceCols :: (Unbox a) => Vector Int -> Matrix a -> Matrix a +sliceCols cols m = + Matrix + { elems = + V.generate (nrows m * V.length cols) $ \i -> + let col = cols V.! (i `rem` V.length cols) + row = i `div` V.length cols + in m ! (row, col), + nrows = nrows m, + ncols = V.length cols + } + +getColM :: (Unbox a) => Int -> Matrix a -> Matrix a +getColM col = sliceCols $ V.singleton col + +getCol :: (Unbox a) => Int -> Matrix a -> Vector a +getCol col = elems . getColM col + +setCol :: (Unbox a) => Int -> Vector a -> Matrix a -> Matrix a +setCol c col m = + m + { elems = + V.update_ (elems m) indices col + } + where + indices = V.generate (nrows m) $ + \r -> r * ncols m + c + +sliceRows :: (Unbox a) => Vector Int -> Matrix a -> Matrix a +sliceRows rows m = + Matrix + { elems = + V.generate (ncols m * V.length rows) $ \i -> + let row = rows V.! (i `rem` V.length rows) + col = i `div` V.length rows + in m ! (row, col), + nrows = V.length rows, + ncols = ncols m + } + +getRowM :: (Unbox a) => Int -> Matrix a -> Matrix a +getRowM row = sliceRows $ V.singleton row + +getRow :: (Unbox a) => Int -> Matrix a -> Vector a +getRow row = elems . getRowM row + +(<|>) :: (Unbox a) => Matrix a -> Matrix a -> Matrix a +m1 <|> m2 = + generate f (nrows m1) (ncols m1 + ncols m2) + where + f r c + | c < ncols m1 = m1 ! (r, c) + | otherwise = m2 ! (r, c - ncols m1) + +(<->) :: (Unbox a) => Matrix a -> Matrix a -> Matrix a +m1 <-> m2 = + generate f (nrows m1 + nrows m2) (ncols m1) + where + f r c + | r < nrows m1 = m1 ! (r, c) + | otherwise = m2 ! (r - nrows m1, c) + +addRow :: (Unbox a) => Matrix a -> Vector a -> Matrix a +addRow m v = + m + { elems = elems m V.++ v, + nrows = nrows m + 1 + } + +addRows :: (Unbox a) => Matrix a -> [Vector a] -> Matrix a +addRows = foldl addRow + +imap :: (Unbox a) => (Int -> Int -> a -> a) -> Matrix a -> Matrix a +imap f m = + m + { elems = V.imap g $ elems m + } + where + g i = + let r = i `div` ncols m + c = i `rem` nrows m + in f r c + +generate :: (Unbox a) => (Int -> Int -> a) -> Int -> Int -> Matrix a +generate f rows cols = + Matrix + { elems = + V.generate (rows * cols) $ \i -> + let r = i `div` cols + c = i `rem` cols + in f r c, + nrows = rows, + ncols = cols + } + +identity :: (Unbox a, Num a) => Int -> Matrix a +identity n = generate (\r c -> if r == c then 1 else 0) n n + +diagonal :: (Unbox a, Num a) => Vector a -> Matrix a +diagonal d = generate (\r c -> if r == c then d V.! r else 0) (V.length d) (V.length d) + +(<.>) :: (Unbox a, Num a) => Vector a -> Vector a -> a +v1 <.> v2 = V.sum $ V.zipWith (*) v1 v2 + +infixl 7 <.> + +(*.) :: (Unbox a, Num a) => Matrix a -> Vector a -> Vector a +m *. v = + V.generate (nrows m) $ \r -> + getRow r m <.> v + +infixl 7 *. + +(.*) :: (Unbox a, Num a) => Vector a -> Matrix a -> Vector a +v .* m = + V.generate (ncols m) $ \c -> + v <.> getCol c m + +infixl 7 .* + +(.-.) :: (Unbox a, Num a) => Vector a -> Vector a -> Vector a +(.-.) = V.zipWith (-) + +infixl 6 .-. + +(.+.) :: (Unbox a, Num a) => Vector a -> Vector a -> Vector a +(.+.) = V.zipWith (+) + +infixl 6 .+. + +swapRows :: (Unbox a) => Int -> Int -> Matrix a -> Matrix a +swapRows r1 r2 m = + m + { elems = + elems m `V.update` new + } + where + start1 = ncols m * r1 + start2 = ncols m * r2 + row1 = getRow r1 m + row2 = getRow r2 m + new = + V.imap (\i a -> (i + start1, a)) row2 + V.++ V.imap (\i a -> (i + start2, a)) row1 + +-- todo: fix +update :: (Unbox a) => Matrix a -> Vector ((Int, Int), a) -> Matrix a +update m upds = + generate + ( \i j -> + case M.fromList (V.toList upds) M.!? (i, j) of + Nothing -> m ! (i, j) + Just x -> x + ) + (nrows m) + (ncols m) + +-- This version doesn't maintain integrality of the entries. +rowEchelon :: (Fractional a, Unbox a, Ord a) => Matrix a -> Matrix a +rowEchelon = rowEchelon' 0 0 + where + rowEchelon' h k m@(Matrix _ nr nc) + | h < nr && k < nc = + if m ! (pivot_row, k) == 0 + then rowEchelon' h (k + 1) m + else rowEchelon' (h + 1) (k + 1) clear_rows_below + | otherwise = m + where + pivot_row = + fst $ + L.maximumBy (\(_, x) (_, y) -> x `compare` y) $ + [(r, abs (m ! (r, k))) | r <- [h .. nr - 1]] + m' = swapRows h pivot_row m + clear_rows_below = + update m' $ + V.fromList $ + [((i, k), 0) | i <- [h + 1 .. nr - 1]] + ++ [ ((i, j), m' ! (i, j) - (m' ! (h, j)) * f) + | i <- [h + 1 .. nr - 1], + let f = m' ! (i, k) / m' ! (h, k), + j <- [k + 1 .. nc - 1] + ] + +-- TODO: fix. Something's wrong here, causes huge blow-up. +-- rowEchelon :: (Num a, Unbox a, Ord a) => Matrix a -> Matrix a +-- rowEchelon = rowEchelon' 0 0 +-- where +-- rowEchelon' h k m@(Matrix _ nr nc) +-- | h < nr && k < nc = +-- if m ! (pivot_row, k) == 0 +-- then rowEchelon' h (k + 1) m +-- else rowEchelon' (h + 1) (k + 1) clear_rows_below +-- | otherwise = m +-- where +-- pivot_row = +-- fst $ +-- L.maximumBy (\(_, x) (_, y) -> x `compare` y) $ +-- [(r, abs (m ! (r, k))) | r <- [h .. nr - 1]] +-- m' = swapRows h pivot_row m +-- clear_rows_below = +-- update m' $ +-- V.fromList $ +-- [((i, k), 0) | i <- [h + 1 .. nr - 1]] +-- ++ [ ((i, j), (m' ! (h, k)) * (m' ! (i, j)) - (m' ! (h, j)) * (m' ! (i, k))) +-- | i <- [h + 1 .. nr - 1], +-- j <- [k + 1 .. nc - 1] +-- ] + +filterRows :: (Unbox a) => (Vector a -> Bool) -> Matrix a -> Matrix a +filterRows p = fromVectors . filter p . toList + +deleteRow :: (Unbox a) => Int -> Matrix a -> Matrix a +deleteRow n m = sliceRows (V.generate (nrows m - 1) (\r -> if r < n then r else r + 1)) m + +deleteCol :: (Unbox a) => Int -> Matrix a -> Matrix a +deleteCol n m = sliceCols (V.generate (ncols m - 1) (\c -> if c < n then c else c + 1)) m diff --git a/src/Futhark/Solve/Simplex.hs b/src/Futhark/Solve/Simplex.hs new file mode 100644 index 0000000000..362b300038 --- /dev/null +++ b/src/Futhark/Solve/Simplex.hs @@ -0,0 +1,235 @@ +module Futhark.Solve.Simplex + ( simplex, + simplexLP, + simplexProg, + findBasis, + ) +where + +import Data.List qualified as L +import Data.Map.Strict (Map) +import Data.Map.Strict qualified as M +import Data.Maybe +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V +import Futhark.Solve.LP (LP (..), LPE (..), LinearProg (..), convert, linearProgToLPE, rowEchelonLPE) +import Futhark.Solve.Matrix + +-- | A tableau of an equational linear program @a * x = d@ is +-- +-- > x @ b = p + q * x @ n +-- > --------------------- +-- > z = z' + r^T * x @ n +-- +-- where @z = c^T * x@ and @b@ (@n@) is a vector containing the +-- indices of basic (nonbasic) variables. +-- +-- The basic feasible solution corresponding to the above tableau is +-- given by @x \@ b = p@, @x \@n = 0@ with the value of the objective +-- equal to @z'@. + +-- | Computes @r@ as given in the tableau above. +compR :: + (Num a, Unbox a) => + LPE a -> + Matrix a -> + Vector Int -> + Vector Int -> + Vector a +compR (LPE c a _) invA_B b n = + c @ n .-. c @ b .* invA_B .* a @ n + +-- | @compQEnter prob invA_B b n enter@ computes the @enter@th +-- column of @q@. +compQEnter :: + (Num a, Unbox a) => + LPE a -> + Matrix a -> + Int -> + Vector a +compQEnter (LPE _ a _) invA_B enter = + V.map negate $ invA_B *. getCol enter a + +-- | Computes the objective given an inversion of @a@ and a basis. +compZ :: + (Num a, Unbox a) => + LPE a -> + Matrix a -> + Vector Int -> + a +compZ (LPE c _ d) invA_B b = + c @ b .* invA_B <.> d + +-- | Constructs an auxiliary equational linear program to compute the +-- initial feasible basis; returns the program along with a feasible +-- basis. +mkAux :: (Ord a, Unbox a, Num a) => LPE a -> (LPE a, Vector Int, Vector Int) +mkAux (LPE _ a d) = (LPE c_aux a_aux d_aux, b_aux, n_aux) + where + c_aux = V.replicate (ncols a) 0 V.++ V.replicate (nrows a) (-1) + d_aux = V.map abs d + a_aux = + imap (\r _ e -> if (d V.! r) < 0 then negate e else e) a + <|> identity (nrows a) + b_aux = V.generate (nrows a) (+ ncols a) + n_aux = V.generate (ncols a) id + +fixDegenerateBasis :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + Int -> + LPE a -> + (Matrix a, Vector a, Vector Int, Vector Int) -> + (LPE a, Matrix a, Vector a, Vector Int, Vector Int) +fixDegenerateBasis og_prob col prob (invA_B, p, b, n) + | Just exit_idx <- mexit_idx, + V.null (elim_row exit_idx) = + let prob' = + prob + { pA = deleteRow exit_idx (pA prob), + pd = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) $ + pd prob + } + invA_B' = deleteRow exit_idx $ deleteCol exit_idx invA_B + p' = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) p + b' = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) b + in fixDegenerateBasis og_prob col prob' (invA_B', p', b', n) + | Just exit_idx <- mexit_idx, + (enter, _) <- V.head (elim_row exit_idx) = + let enter_idx = fromJust $ V.findIndex (== enter) n + exit = b V.! exit_idx + in fixDegenerateBasis og_prob col prob $ + pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) + | otherwise = + let prob' = + prob + { pc = pc og_prob, + pA = sliceCols (V.generate col id) $ pA prob, + pd = V.map abs $ pd og_prob + } + in (prob', invA_B, p, V.filter (< col) b, V.filter (< col) n) + where + mexit_idx = + fst <$> V.filter ((>= col) . snd) (V.imap (curry id) b) V.!? 0 + elim_row exit_idx = + V.filter ((/= 0) . snd) $ + V.map (\j -> (j, compQEnter prob invA_B j V.! exit_idx)) $ + V.generate col id + +-- | Finds an initial feasible basis for an equational linear program. +-- Returns 'Nothing' if the LP has no solution. Inverts some +-- equations by multiplying by -1 so it also returns a modified (but +-- equivalent) equational linear program. +findBasis :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + Maybe (LPE a, Matrix a, Vector a, Vector Int, Vector Int) +findBasis prob = do + (invA_B, p, b, n) <- step p_aux (invA_B_aux, d_aux, b_aux, n_aux) + if compZ p_aux invA_B b == 0 + then Just $ fixDegenerateBasis prob (ncols $ pA prob) p_aux (invA_B, p, b, n) + else Nothing + where + (p_aux@(LPE _ _ d_aux), b_aux, n_aux) = mkAux prob + invA_B_aux = identity $ V.length b_aux + +-- | Solves an equational linear program. Returns 'Nothing' if the +-- program is infeasible or unbounded. Otherwise returns the optimal +-- value and the solution. +simplex :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + Maybe (a, Vector a) +simplex lpe = do + (lpe', invA_B, p, b, n) <- findBasis $ rowEchelonLPE lpe + (invA_B', p', b', n') <- step lpe' (invA_B, p, b, n) + let z = compZ lpe' invA_B' b' + sol = + V.map snd $ + V.fromList $ + L.sortOn fst $ + V.toList $ + V.zip (b' V.++ n') (p' V.++ V.replicate (V.length n') 0) + pure (z, sol) + +-- | Solves a linear program. +simplexLP :: + (Unbox a, Ord a, Fractional a, Show a) => + LP a -> + Maybe (a, Vector a) +simplexLP lp = do + (opt, sol) <- simplex lpe + pure (opt, V.take (ncols $ lpA lp) sol) + where + lpe = convert lp + +simplexProg :: + (Unbox a, Ord a, Ord v, Fractional a, Show a) => + LinearProg v a -> + Maybe (a, Map v a) +simplexProg prog = do + (z, sol) <- simplex lpe + pure (z, M.fromList $ zipWith (\i x -> (idxMap M.! i, x)) [0 ..] $ V.toList sol) + where + (lpe, idxMap) = linearProgToLPE prog + +pivot :: + (Unbox a, Fractional a) => + LPE a -> + (Matrix a, Vector a, Vector Int, Vector Int) -> + (Int, Int) -> + (Int, Int) -> + (Matrix a, Vector a, Vector Int, Vector Int) +pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) = + (invA_B', p', b', n') + where + q_enter = compQEnter prob invA_B enter + b' = b V.// [(exit_idx, enter)] + n' = n V.// [(enter_idx, exit)] + e_inv_vec = + V.map + (/ abs (q_enter V.! exit_idx)) + (q_enter V.// [(exit_idx, 1)]) + genF row col = + (if row == exit_idx then 0 else invA_B ! (row, col)) + + (e_inv_vec V.! row) * invA_B ! (exit_idx, col) + invA_B' = generate genF (nrows invA_B) (ncols invA_B) + p' = p V.// [(exit_idx, 0)] .+. V.map (* (p V.! exit_idx)) e_inv_vec + +-- | One step of the simplex algorithm. +step :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + (Matrix a, Vector a, Vector Int, Vector Int) -> + Maybe (Matrix a, Vector a, Vector Int, Vector Int) +step prob (invA_B, p, b, n) + | Just enter_idx <- menter_idx = + let enter = n V.! enter_idx + q_enter = compQEnter prob invA_B enter + pq = + V.map (\(i, p_', q_) -> (i, -(p_' / q_))) $ + V.filter (\(_, _, q_) -> q_ < 0) $ + V.zip3 (V.generate (V.length q_enter) id) p q_enter + in if V.null pq + then Nothing + else + let exit_val = snd $ V.minimumOn snd pq + exit_cands = + V.map fst $ V.filter ((exit_val ==) . snd) pq + (exit_idx, exit) = + V.minimumOn snd $ + V.map (\i -> (i, b V.! i)) exit_cands + in step prob $ pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) + | otherwise = Just (invA_B, p, b, n) + where + r = compR prob invA_B b n + menter_idx = V.findIndex (> 0) r diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 009d9ac4e4..a0f99a9647 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -422,6 +422,11 @@ fromArray :: Value -> (ValueShape, [Value]) fromArray (ValueArray shape as) = (shape, elems as) fromArray v = error $ "Expected array value, but found: " <> show v +fromArrayR :: Int -> Value -> [Value] +fromArrayR 0 v = [v] +fromArrayR 1 v = snd $ fromArray v +fromArrayR n v = concatMap (fromArrayR (n - 1)) $ snd $ fromArray v + apply :: SrcLoc -> Env -> Value -> Value -> EvalM Value apply loc env (ValueFun f) v = stacking loc env (f v) apply _ _ f _ = error $ "Cannot apply non-function: " <> show f @@ -431,6 +436,35 @@ apply2 loc env f x y = stacking loc env $ do f' <- apply noLoc mempty f x apply noLoc mempty f' y +data AutoMapArg + = AutoMapArg [Int64] [Int64] [Int64] + deriving (Eq, Ord, Show) + +applyAM :: + SrcLoc -> + Env -> + (Value, StructType) -> + AutoMapArg -> + Value -> + EvalM Value +applyAM loc env (f, _) (AutoMapArg [] [] []) v = + apply loc env f v +applyAM loc env (f, ft) am@(AutoMapArg repshape mapshape frame) v = do + let v' = repArray repshape v + f' = repArray mapshape f + rank = length frame + vs = fromArrayR rank v' + fs = fromArrayR rank f' + t' <- evalType (eval env) mempty ft + case t' of + Scalar (Arrow _ _ _ _ (RetType _ ret_t)) + | Just rowshape <- sequenceA $ structTypeShape $ toStruct ret_t -> + toArrayR frame rowshape <$> zipWithM (apply loc env) fs vs + _ -> + error $ + "Invalid automap arguments:\n" + ++ unlines [prettyString ft, show f, show v, show am] + matchPat :: Env -> Pat (TypeBase Size u) -> Value -> EvalM Env matchPat env p v = do m <- runMaybeT $ patternMatch env p v @@ -762,13 +796,15 @@ evalFunctionBinding env tparams ps ret fbody = do returned env (retType ret) retext =<< evalFunction env' missing_sizes ps fbody (retType ret) -evalArg :: Env -> Exp -> Maybe VName -> EvalM Value -evalArg env e ext = do +evalArg :: Env -> Exp -> Maybe VName -> AutoMap -> EvalM (Value, AutoMapArg) +evalArg env e ext (AutoMap rshape mshape frame) = do v <- eval env e case ext of Just ext' -> putExtSize ext' v _ -> pure () - pure v + let evalShape = mapM (fmap asInt64 . eval env) . shapeDims + am' <- AutoMapArg <$> evalShape rshape <*> evalShape mshape <*> evalShape frame + pure (v, am') returned :: Env -> TypeBase Size als -> [VName] -> Value -> EvalM Value returned _ _ [] v = pure v @@ -838,22 +874,31 @@ evalAppExp env (LetPat sizes p e body _) = do evalAppExp env (LetFun f (tparams, ps, _, Info ret, fbody) body _) = do binding <- evalFunctionBinding env tparams ps ret fbody eval (env {envTerm = M.insert f binding $ envTerm env}) body -evalAppExp env (BinOp (op, _) op_t (x, Info xext) (y, Info yext) loc) - | baseString (qualLeaf op) == "&&" = do +evalAppExp env (BinOp (op, _) (Info op_t) (x, Info (xext, xam)) (y, Info (yext, yam)) loc) + | baseString (qualLeaf op) == "&&", + noAutoMap = do x' <- asBool <$> eval env x if x' then eval env y else pure $ ValuePrim $ BoolValue False - | baseString (qualLeaf op) == "||" = do + | baseString (qualLeaf op) == "||", + noAutoMap = do x' <- asBool <$> eval env x if x' then pure $ ValuePrim $ BoolValue True else eval env y | otherwise = do - x' <- evalArg env x xext - y' <- evalArg env y yext - op' <- eval env $ Var op op_t loc - apply2 loc env op' x' y' + (x', xam') <- evalArg env x xext xam + (y', yam') <- evalArg env y yext yam + op' <- evalTermVar env op op_t + op'' <- applyAM loc env (op', op_t) xam' x' + applyAM loc env (op'', op_ret) yam' y' + where + op_ret = case op_t of + Scalar (Arrow _ _ _ _ (RetType _ t)) -> + toStruct t + _ -> error $ "Nonsensical binop type: " <> prettyString op_t + noAutoMap = xam == mempty && yam == mempty evalAppExp env (If cond e1 e2 _) = do cond' <- asBool <$> eval env cond if cond' then eval env e1 else eval env e2 @@ -863,9 +908,11 @@ evalAppExp env (Apply f args loc) = do -- type of the functions. args' <- reverse <$> mapM evalArg' (reverse $ NE.toList args) f' <- eval env f - foldM (apply loc env) f' args' + foldM apply' f' args' where - evalArg' (Info ext, x) = evalArg env x ext + ft = expandType env $ typeOf f + apply' f' (v', am') = applyAM loc env (f', ft) am' v' + evalArg' (Info (ext, am), x) = evalArg env x ext am evalAppExp env (Index e is loc) = do is' <- mapM (evalDimIndex env) is arr <- eval env e @@ -1043,16 +1090,21 @@ eval env (Lambda ps body _ (Info (RetType _ rt)) _) = evalFunction env [] ps body rt eval env (OpSection qv (Info t) _) = evalTermVar env qv $ toStruct t -eval env (OpSectionLeft qv _ e (Info (_, _, argext), _) (Info (RetType _ t), _) loc) = do - v <- evalArg env e argext - f <- evalTermVar env qv (toStruct t) - apply loc env f v -eval env (OpSectionRight qv _ e (Info _, Info (_, _, argext)) (Info (RetType _ t)) loc) = do - y <- evalArg env e argext +eval env (OpSectionLeft qv _ e (Info (_, _, argext, am), _) (Info (RetType _ t), _) loc) = do + (v, am') <- evalArg env e argext am + f <- evalTermVar env qv t' + applyAM loc env (f, t') am' v + where + t' = toStruct t +eval env (OpSectionRight qv _ e (Info _, Info (_, _, argext, am)) (Info (RetType _ t)) loc) = do + (y, am') <- evalArg env e argext am pure $ ValueFun $ \x -> do - f <- evalTermVar env qv $ toStruct t - apply2 loc env f x y + f <- evalTermVar env qv t' + f' <- apply loc env f x + applyAM loc env (f', t') am' y + where + t' = toStruct t eval env (IndexSection is _ loc) = do is' <- mapM (evalDimIndex env) is pure $ ValueFun $ evalIndex loc env is' @@ -1619,22 +1671,6 @@ initialCtx = Just $ fun2 stream def s | "reduce_stream" `isPrefixOf` s = Just $ fun3 $ \_ f arg -> stream f arg - def "map" = Just $ - TermPoly Nothing $ \t eval' -> do - t' <- evalType eval' mempty t - pure $ ValueFun $ \f -> pure . ValueFun $ \xs -> - case unfoldFunType t' of - ([_, _], ret_t) - | Just rowshape <- typeRowShape ret_t -> - toArray' rowshape <$> mapM (apply noLoc mempty f) (snd $ fromArray xs) - | otherwise -> - error $ "Bad return type: " <> prettyString ret_t - _ -> - error $ - "Invalid arguments to map intrinsic:\n" - ++ unlines [prettyString t, show f, show xs] - where - typeRowShape = sequenceA . structTypeShape . stripArray 1 def s | "reduce" `isPrefixOf` s = Just $ fun3 $ \f ne xs -> foldM (apply2 noLoc mempty f) ne $ snd $ fromArray xs @@ -2175,7 +2211,7 @@ checkEntryArgs entry args entry_t "Got input of types" indent 2 (stack (map pretty args_ts)) where - (param_ts, _) = unfoldFunType entry_t + param_ts = map snd $ fst $ unfoldFunType entry_t args_ts = map (valueStructType . valueType) args expected | null param_ts = diff --git a/src/Language/Futhark/Interpreter/Values.hs b/src/Language/Futhark/Interpreter/Values.hs index 26372660c3..5526477ade 100644 --- a/src/Language/Futhark/Interpreter/Values.hs +++ b/src/Language/Futhark/Interpreter/Values.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE LambdaCase #-} + -- | The value representation used in the interpreter. -- -- Kept simple and free of unnecessary operational details (in @@ -22,7 +24,9 @@ module Language.Futhark.Interpreter.Values prettyEmptyArray, toArray, toArray', + toArrayR, toTuple, + repArray, -- * Conversion fromDataValue, @@ -31,7 +35,7 @@ where import Data.Array import Data.Bifunctor (Bifunctor (second)) -import Data.List (genericLength) +import Data.List (genericLength, genericReplicate) import Data.Map qualified as M import Data.Maybe import Data.Monoid hiding (Sum) @@ -253,6 +257,17 @@ toArray' rowshape vs = ValueArray shape (listArray (0, length vs - 1) vs) where shape = ShapeDim (genericLength vs) rowshape +-- | Produce multidimensional array from a flat list of values. +toArrayR :: [Int64] -> ValueShape -> [Value m] -> Value m +toArrayR [] _ = \case + [v] -> v + _ -> error "toArrayR: empty shape" +toArrayR [_] elemshape = toArray' elemshape +toArrayR (n : ns) elemshape = + toArray (foldr ShapeDim elemshape (n : ns)) + . map (toArrayR ns elemshape) + . chunk (fromIntegral (product ns)) + arrayLength :: (Integral int) => Array Int (Value m) -> int arrayLength = fromIntegral . (+ 1) . snd . bounds @@ -284,6 +299,13 @@ fromDataValueWith f shape vector where shape' = SVec.tail shape +repArray :: [Int64] -> Value m -> Value m +repArray [] v = v +repArray (n : ns) v = + toArray' (valueShape v') (genericReplicate n v') + where + v' = repArray ns v + -- | Convert a Futhark value in the externally observable data format -- to an interpreter value. fromDataValue :: V.Value -> Value m diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index ceab742ebf..7da47c0385 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -7,6 +7,7 @@ module Language.Futhark.Pretty prettyTuple, leadingOperator, IsName (..), + prettyNameText, prettyNameString, Annot (..), ) @@ -55,9 +56,13 @@ instance IsName Name where prettyName = pretty toName = id +-- | Prettyprint name as text. +prettyNameText :: (IsName v) => v -> T.Text +prettyNameText = docText . prettyName + -- | Prettyprint name as string. Only use this for debugging. prettyNameString :: (IsName v) => v -> String -prettyNameString = T.unpack . docText . prettyName +prettyNameString = T.unpack . prettyNameText -- | Class for type constructors that represent annotations. Used in -- the prettyprinter to either print the original AST, or the computed @@ -153,7 +158,7 @@ instance (Pretty (Shape dim), Pretty u) => Pretty (ScalarTypeBase dim u) where prettyType :: (Pretty (Shape dim), Pretty u) => Int -> TypeBase dim u -> Doc a prettyType _ (Array u shape at) = - pretty u <> pretty shape <> align (prettyScalarType 1 at) + pretty u <> pretty shape <> align (prettyScalarType 2 at) prettyType p (Scalar t) = prettyScalarType p t @@ -229,7 +234,13 @@ letBody body@(AppExp LetFun {} _) = pretty body letBody body = "in" <+> align (pretty body) prettyAppExp :: (Eq vn, IsName vn, Annot f) => Int -> AppExpBase f vn -> Doc a -prettyAppExp p (BinOp (bop, _) _ (x, _) (y, _) _) = prettyBinOp p bop x y +prettyAppExp p (BinOp (bop, _) _ (x, xi) (y, yi) _) = + case (unAnnot xi, unAnnot yi) of + (Just (_, xam), Just (_, yam)) + | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 3 -> + -- fix + parens $ align $ prettyBinOp p bop x y "Δ" <+> pretty xam "Δ" <+> pretty yam + _ -> prettyBinOp p bop x y prettyAppExp _ (Match e cs _) = "match" <+> pretty e (stack . map pretty) (NE.toList cs) prettyAppExp _ (Loop sizeparams pat initexp form loopbody _) = "loop" @@ -306,11 +317,21 @@ prettyAppExp _ (If c t f _) = prettyAppExp p (Apply f args _) = parensIf (p >= 10) $ prettyExp 0 f - <+> hsep (map (prettyExp 10 . snd) $ NE.toList args) + <+> hsep (map prettyArg $ NE.toList args) + where + prettyArg (i, e) = + case unAnnot i of + Just (_, am) + | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 3 -> + parens (prettyExp 10 e <+> "Δ" <+> pretty am) + _ -> prettyExp 10 e instance (Eq vn, IsName vn, Annot f) => Pretty (AppExpBase f vn) where pretty = prettyAppExp (-1) +instance Pretty AutoMap where + pretty (AutoMap r m f) = encloseSep lparen rparen comma $ map pretty [r, m, f] + prettyInst :: (Annot f, Pretty t) => f t -> Doc a prettyInst t = case unAnnot t of diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 070dfc733f..426b6b8f8a 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -34,6 +34,9 @@ module Language.Futhark.Prop subExps, similarExps, sameExp, + frameOf, + shapePrefix, + typeShapePrefix, -- * Queries on patterns and params patIdents, @@ -52,9 +55,11 @@ module Language.Futhark.Prop arrayShape, orderZero, unfoldFunType, + unfoldFunTypeWithRet, foldFunType, typeVars, isAccType, + recordField, -- * Operations on types peelArray, @@ -249,6 +254,14 @@ diet (Array d _ _) = d diet (Scalar (TypeVar d _ _)) = d diet (Scalar (Sum cs)) = foldl max Observe $ foldMap (map diet) cs +-- | Look up this record field if it exists. +recordField :: [Name] -> TypeBase dim u -> Maybe (TypeBase dim u) +recordField [] t = Just t +recordField (f : fs) (Scalar (Record fts)) + | Just ft <- M.lookup f fts = + recordField fs ft +recordField _ _ = Nothing + -- | Convert any type to one that has rank information, no alias -- information, and no embedded names. toStructural :: @@ -317,7 +330,9 @@ arrayOfWithAliases :: arrayOfWithAliases u shape2 (Array _ shape1 et) = Array u (shape2 <> shape1) et arrayOfWithAliases u shape (Scalar t) = - Array u shape (second (const mempty) t) + if shapeRank shape == 0 + then Scalar t `setUniqueness` u + else Array u shape (second (const mempty) t) -- | @stripArray n t@ removes the @n@ outermost layers of the array. -- Essentially, it is the type of indexing an array of type @t@ with @@ -495,7 +510,7 @@ typeOf (Attr _ e _) = typeOf e typeOf (AppExp _ (Info res)) = appResType res -- | The type of a function with the given parameters and return type. -funType :: [Pat ParamType] -> ResRetType -> StructType +funType :: [Pat (TypeBase d Diet)] -> RetTypeBase d Uniqueness -> TypeBase d NoUniqueness funType params ret = let RetType _ t = foldr (arrow . patternParam) ret params in toStruct t @@ -505,7 +520,7 @@ funType params ret = -- | @foldFunType ts ret@ creates a function type ('Arrow') that takes -- @ts@ as parameters and returns @ret@. -foldFunType :: [ParamType] -> ResRetType -> StructType +foldFunType :: [TypeBase d Diet] -> RetTypeBase d Uniqueness -> TypeBase d NoUniqueness foldFunType ps ret = let RetType _ t = foldr arrow ret ps in toStruct t @@ -515,12 +530,24 @@ foldFunType ps ret = -- | Extract the parameter types and return type from a type. -- If the type is not an arrow type, the list of parameter types is empty. -unfoldFunType :: TypeBase dim as -> ([TypeBase dim Diet], TypeBase dim NoUniqueness) -unfoldFunType (Scalar (Arrow _ _ d t1 (RetType _ t2))) = +unfoldFunType :: TypeBase dim as -> ([(PName, TypeBase dim Diet)], TypeBase dim NoUniqueness) +unfoldFunType (Scalar (Arrow _ p d t1 (RetType _ t2))) = let (ps, r) = unfoldFunType t2 - in (second (const d) t1 : ps, r) + in ((p, second (const d) t1) : ps, r) unfoldFunType t = ([], toStruct t) +-- | Extract the parameter types and 'RetTypeBase' from a function type. +-- If the type is not an arrow type, returns 'Nothing'. +unfoldFunTypeWithRet :: + TypeBase dim as -> + Maybe ([(PName, TypeBase dim Diet)], RetTypeBase dim Uniqueness) +unfoldFunTypeWithRet (Scalar (Arrow _ p d t1 (RetType _ t2@(Scalar Arrow {})))) = do + (ps, r) <- unfoldFunTypeWithRet t2 + pure ((p, second (const d) t1) : ps, r) +unfoldFunTypeWithRet (Scalar (Arrow _ p d t1 r@RetType {})) = + Just ([(p, second (const d) t1)], r) +unfoldFunTypeWithRet _ = Nothing + -- | The type scheme of a value binding, comprising the type -- parameters and the actual type. valBindTypeScheme :: ValBindBase Info VName -> ([TypeParamBase VName], StructType) @@ -611,7 +638,7 @@ patternStructType = toStruct . patternType -- | When viewed as a function parameter, does this pattern correspond -- to a named parameter of some type? -patternParam :: Pat ParamType -> (PName, Diet, StructType) +patternParam :: Pat (TypeBase d Diet) -> (PName, Diet, TypeBase d NoUniqueness) patternParam (PatParens p _) = patternParam p patternParam (PatAttr _ p _) = @@ -684,8 +711,8 @@ mkBinOp op t x y = ( BinOp (qualName (intrinsicVar op), mempty) (Info t) - (x, Info Nothing) - (y, Info Nothing) + (x, Info (Nothing, mempty)) + (y, Info (Nothing, mempty)) mempty ) (Info $ AppRes t []) @@ -840,16 +867,6 @@ intrinsics = $ array_a Unique $ shape [m, k, l] ), - ( "map", - IntrinsicPolyFun - [tp_a, tp_b, sp_n] - [ Scalar (t_a mempty) `arr` Scalar (t_b Nonunique), - array_a Observe $ shape [n] - ] - $ RetType [] - $ array_b Unique - $ shape [n] - ), ( "reduce", IntrinsicPolyFun [tp_a, sp_n] @@ -1478,6 +1495,23 @@ sameExp e1 e2 all (uncurry sameExp) es | otherwise = False +frameOf :: Exp -> Shape Size +frameOf (AppExp (Apply _ args _) _) = + ((\(_, am) -> autoFrame am) . unInfo . fst) $ NE.last args +frameOf _ = mempty + +-- | @s1 `shapePrefix` s2@ assumes @s1 = prefix <> s2@ and +-- returns @prefix@. +shapePrefix :: Shape dim -> Shape dim -> Shape dim +shapePrefix (Shape ss1) (Shape ss2) = + Shape $ take (length ss1 - length ss2) ss1 + +typeShapePrefix :: TypeBase dim as1 -> TypeBase dim as2 -> Shape dim +typeShapePrefix (Array _ s _) Scalar {} = s +typeShapePrefix (Array _ s1 _) (Array _ s2 _) = + s1 `shapePrefix` s2 +typeShapePrefix _ _ = mempty + -- | An identifier with type- and aliasing information. type Ident = IdentBase Info VName diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index 22bb54366c..7fea0194af 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -23,6 +23,10 @@ module Language.Futhark.Syntax Shape (..), shapeRank, stripDims, + AutoMap (..), + autoRepRank, + autoMapRank, + autoFrameRank, TypeBase (..), TypeArg (..), SizeExp (..), @@ -263,6 +267,28 @@ stripDims i (Shape l) | i < length l = Just $ Shape $ drop i l | otherwise = Nothing +data AutoMap = AutoMap + { autoRep :: Shape Size, + autoMap :: Shape Size, + autoFrame :: Shape Size + } + deriving (Eq, Show, Ord) + +autoRepRank :: AutoMap -> Int +autoRepRank = shapeRank . autoRep + +autoMapRank :: AutoMap -> Int +autoMapRank = shapeRank . autoMap + +autoFrameRank :: AutoMap -> Int +autoFrameRank = shapeRank . autoFrame + +instance Semigroup AutoMap where + (AutoMap r1 m1 f1) <> (AutoMap r2 m2 f2) = AutoMap (r1 <> r2) (m1 <> m2) (f1 <> f2) + +instance Monoid AutoMap where + mempty = AutoMap mempty mempty mempty + -- | The name (if any) of a function parameter. The 'Eq' and 'Ord' -- instances always compare values of this type equal. data PName = Named VName | Unnamed @@ -699,7 +725,7 @@ data AppExpBase f vn -- identical). Apply (ExpBase f vn) - (NE.NonEmpty (f (Maybe VName), ExpBase f vn)) + (NE.NonEmpty (f (Maybe VName, AutoMap), ExpBase f vn)) SrcLoc | Range (ExpBase f vn) @@ -733,8 +759,8 @@ data AppExpBase f vn | BinOp (QualName vn, SrcLoc) (f StructType) - (ExpBase f vn, f (Maybe VName)) - (ExpBase f vn, f (Maybe VName)) + (ExpBase f vn, f (Maybe VName, AutoMap)) + (ExpBase f vn, f (Maybe VName, AutoMap)) SrcLoc | LetWith (IdentBase f vn StructType) @@ -843,7 +869,7 @@ data ExpBase f vn (QualName vn) (f StructType) (ExpBase f vn) - (f (PName, ParamType, Maybe VName), f (PName, ParamType)) + (f (PName, ParamType, Maybe VName, AutoMap), f (PName, ParamType)) (f ResRetType, f [VName]) SrcLoc | -- | @+2@; first type is operand, second is result. @@ -851,7 +877,7 @@ data ExpBase f vn (QualName vn) (f StructType) (ExpBase f vn) - (f (PName, ParamType), f (PName, ParamType, Maybe VName)) + (f (PName, ParamType), f (PName, ParamType, Maybe VName, AutoMap)) (f ResRetType) SrcLoc | -- | Field projection as a section: @(.x.y.z)@. @@ -1356,7 +1382,7 @@ deriving instance Show (ProgBase Info VName) deriving instance Show (ProgBase NoInfo Name) -- | Construct an 'Apply' node, with type information. -mkApply :: ExpBase Info vn -> [(Maybe VName, ExpBase Info vn)] -> AppRes -> ExpBase Info vn +mkApply :: ExpBase Info vn -> [(Maybe VName, AutoMap, ExpBase Info vn)] -> AppRes -> ExpBase Info vn mkApply f args (AppRes t ext) | Just args' <- NE.nonEmpty $ map onArg args = case f of @@ -1368,7 +1394,7 @@ mkApply f args (AppRes t ext) AppExp (Apply f args' (srcspan f $ snd $ NE.last args')) (Info (AppRes t ext)) | otherwise = f where - onArg (v, x) = (Info v, x) + onArg (v, am, x) = (Info (v, am), x) -- | Construct an 'Apply' node, without type information. mkApplyUT :: ExpBase NoInfo vn -> ExpBase NoInfo vn -> ExpBase NoInfo vn diff --git a/src/Language/Futhark/Traversals.hs b/src/Language/Futhark/Traversals.hs index cea5be0d3d..ec8304904b 100644 --- a/src/Language/Futhark/Traversals.hs +++ b/src/Language/Futhark/Traversals.hs @@ -61,6 +61,13 @@ class ASTMappable x where -- into subexpressions. The mapping is done left-to-right. astMap :: (Monad m) => ASTMapper m -> x -> m x +mapOnAutoMap :: (Monad m) => ASTMapper m -> AutoMap -> m AutoMap +mapOnAutoMap tv (AutoMap r m f) = + AutoMap + <$> traverse (mapOnExp tv) r + <*> traverse (mapOnExp tv) m + <*> traverse (mapOnExp tv) f + instance ASTMappable (AppExpBase Info VName) where astMap tv (Range start next end loc) = Range @@ -74,7 +81,7 @@ instance ASTMappable (AppExpBase Info VName) where Match <$> mapOnExp tv e <*> astMap tv cases <*> pure loc astMap tv (Apply f args loc) = do f' <- mapOnExp tv f - args' <- traverse (traverse $ mapOnExp tv) args + args' <- traverse onArg args -- Safe to disregard return type because existentials cannot be -- instantiated here, as the return is necessarily a function. pure $ case f' of @@ -82,6 +89,9 @@ instance ASTMappable (AppExpBase Info VName) where Apply f_inner (args_inner <> args') loc _ -> Apply f' args' loc + where + onArg (Info (ext, am), e) = + (,) <$> (Info . (ext,) <$> mapOnAutoMap tv am) <*> mapOnExp tv e astMap tv (LetPat sizes pat e body loc) = LetPat sizes <$> astMap tv pat <*> mapOnExp tv e <*> mapOnExp tv body <*> pure loc astMap tv (LetFun name (tparams, params, ret, t, e) body loc) = @@ -102,13 +112,16 @@ instance ASTMappable (AppExpBase Info VName) where <*> mapOnExp tv vexp <*> mapOnExp tv body <*> pure loc - astMap tv (BinOp (fname, fname_loc) t (x, xext) (y, yext) loc) = + astMap tv (BinOp (fname, fname_loc) t x y loc) = BinOp <$> ((,) <$> mapOnName tv fname <*> pure fname_loc) <*> traverse (mapOnStructType tv) t - <*> ((,) <$> mapOnExp tv x <*> pure xext) - <*> ((,) <$> mapOnExp tv y <*> pure yext) + <*> onArg x + <*> onArg y <*> pure loc + where + onArg (e, Info (ext, am)) = + (,) <$> mapOnExp tv e <*> (Info . (ext,) <$> mapOnAutoMap tv am) astMap tv (Loop sparams mergepat loopinit form loopbody loc) = Loop sparams <$> astMap tv mergepat @@ -187,25 +200,25 @@ instance ASTMappable (ExpBase Info VName) where <$> mapOnName tv name <*> traverse (mapOnStructType tv) t <*> pure loc - astMap tv (OpSectionLeft name t arg (Info (pa, t1a, argext), Info (pb, t1b)) (ret, retext) loc) = + astMap tv (OpSectionLeft name t arg (Info (pa, t1a, argext, am), Info (pb, t1b)) (ret, retext) loc) = OpSectionLeft <$> mapOnName tv name <*> traverse (mapOnStructType tv) t <*> mapOnExp tv arg <*> ( (,) - <$> (Info <$> ((pa,,) <$> mapOnParamType tv t1a <*> pure argext)) + <$> (Info <$> ((pa,,,) <$> mapOnParamType tv t1a <*> pure argext <*> pure am)) <*> (Info <$> ((pb,) <$> mapOnParamType tv t1b)) ) <*> ((,) <$> traverse (mapOnResRetType tv) ret <*> pure retext) <*> pure loc - astMap tv (OpSectionRight name t arg (Info (pa, t1a), Info (pb, t1b, argext)) t2 loc) = + astMap tv (OpSectionRight name t arg (Info (pa, t1a), Info (pb, t1b, argext, am)) t2 loc) = OpSectionRight <$> mapOnName tv name <*> traverse (mapOnStructType tv) t <*> mapOnExp tv arg <*> ( (,) <$> (Info <$> ((pa,) <$> mapOnParamType tv t1a)) - <*> (Info <$> ((pb,,) <$> mapOnParamType tv t1b <*> pure argext)) + <*> (Info <$> ((pb,,,) <$> mapOnParamType tv t1b <*> pure argext <*> pure am)) ) <*> traverse (mapOnResRetType tv) t2 <*> pure loc diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 38fbb13147..f87e280330 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -27,6 +27,7 @@ import Data.Maybe import Data.Ord import Data.Set qualified as S import Futhark.FreshNames hiding (newName) +import Futhark.Util (debugTraceM) import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Semantic @@ -689,7 +690,7 @@ checkEntryPoint loc tparams params maybe_tdecl rettype where (RetType _ rettype_t) = rettype (rettype_params, rettype') = unfoldFunType rettype_t - param_ts = map patternType params ++ rettype_params + param_ts = map patternType params ++ map snd rettype_params checkValBind :: ValBindBase NoInfo Name -> TypeM (Env, ValBind) checkValBind vb = do @@ -707,11 +708,14 @@ checkValBind vb = do checkFunDef (fname, maybe_tdecl, tparams, params, body, loc) let entry' = Info (entryPoint params' maybe_tdecl' rettype) <$ entry + vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc + + debugTraceM 3 $ unlines ["# Inferred:", prettyString vb'] + case entry' of Just _ -> checkEntryPoint loc tparams' params' maybe_tdecl' rettype _ -> pure () - let vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc pure ( mempty { envVtable = diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs new file mode 100644 index 0000000000..c8ceb92763 --- /dev/null +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -0,0 +1,782 @@ +module Language.Futhark.TypeChecker.Constraints + ( Reason (..), + SVar, + SComp (..), + Type, + toType, + Ct (..), + Constraints, + TyVarInfo (..), + TyVar, + TyVars, + TyParams, + Solution, + UnconTyVar, + solve, + ) +where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.State +import Data.Bifunctor +import Data.List qualified as L +import Data.Loc +import Data.Map qualified as M +import Data.Maybe +import Data.Set qualified as S +import Futhark.Util.Pretty +import Language.Futhark +import Language.Futhark.TypeChecker.Error +import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..), aNote) +import Language.Futhark.TypeChecker.Types (substTyVars) + +type SVar = VName + +-- | A shape component. `SDim` is a single dimension of unspecified +-- size, `SVar` is a shape variable. A list of shape components should +-- then be understood as concatenation of shapes (meaning you can't +-- just take the length to determine the rank of the array). +data SComp + = SDim + | SVar SVar + deriving (Eq, Ord, Show) + +instance Pretty SComp where + pretty SDim = "[]" + pretty (SVar x) = brackets $ prettyName x + +instance Pretty (Shape SComp) where + pretty = mconcat . map pretty . shapeDims + +-- | The type representation used by the constraint solver. Agnostic +-- to sizes. +type Type = TypeBase SComp NoUniqueness + +-- | Careful when using this on something that already has an SComp +-- size: it will throw away information by converting them to SDim. +toType :: TypeBase Size u -> TypeBase SComp u +toType = first (const SDim) + +-- | The reason for a type constraint. Used to generate type error +-- messages. The expected type is always the first one. +data Reason + = -- | No particular reason. + Reason Loc + | -- | Arising from pattern match. + ReasonPatMatch Loc (PatBase NoInfo VName ParamType) Type + | -- | Arising from explicit ascription. + ReasonAscription Loc Type Type + | ReasonRetType Loc Type Type + | ReasonApply Loc (Maybe (QualName VName)) Exp Type Type + | ReasonBranches Loc Type Type + deriving (Eq, Show) + +instance Located Reason where + locOf (Reason l) = l + locOf (ReasonPatMatch l _ _) = l + locOf (ReasonAscription l _ _) = l + locOf (ReasonRetType l _ _) = l + locOf (ReasonApply l _ _ _ _) = l + locOf (ReasonBranches l _ _) = l + +data Ct + = CtEq Reason Type Type + | CtAM Reason SVar SVar (Shape SComp) + deriving (Show) + +ctReason :: Ct -> Reason +ctReason (CtEq r _ _) = r +ctReason (CtAM r _ _ _) = r + +instance Located Ct where + locOf = locOf . ctReason + +instance Pretty Ct where + pretty (CtEq _ t1 t2) = pretty t1 <+> "~" <+> pretty t2 + pretty (CtAM _ r m _) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" + +type Constraints = [Ct] + +-- | Information about a flexible type variable. Every type variable +-- is associated with a location, which is the original syntax element +-- that it is the type of. +data TyVarInfo + = -- | Can be substituted with anything. + TyVarFree Loc Liftedness + | -- | Can only be substituted with these primitive types. + TyVarPrim Loc [PrimType] + | -- | Must be a record with these fields. + TyVarRecord Loc (M.Map Name Type) + | -- | Must be a sum type with these fields. + TyVarSum Loc (M.Map Name [Type]) + | -- | Must be a type that supports equality. + TyVarEql Loc + deriving (Show, Eq) + +instance Pretty TyVarInfo where + pretty (TyVarFree _ l) = "free" <+> pretty l + pretty (TyVarPrim _ pts) = "∈" <+> pretty pts + pretty (TyVarRecord _ fs) = pretty $ Scalar $ Record fs + pretty (TyVarSum _ cs) = pretty $ Scalar $ Sum cs + pretty (TyVarEql _) = "equality" + +instance Located TyVarInfo where + locOf (TyVarFree loc _) = loc + locOf (TyVarPrim loc _) = loc + locOf (TyVarRecord loc _) = loc + locOf (TyVarSum loc _) = loc + locOf (TyVarEql loc) = loc + +type TyVar = VName + +-- | The level at which a type variable is bound. Higher means +-- deeper. We can only unify a type variable at level @i@ with a type +-- @t@ if all type names that occur in @t@ are at most at level @i@. +type Level = Int + +-- | If a VName is not in this map, it should be in the 'TyParams' - +-- the exception is abstract types, which are just missing (and +-- assumed to have smallest possible level). +type TyVars = M.Map TyVar (Level, TyVarInfo) + +-- | Explicit type parameters. +type TyParams = M.Map TyVar (Level, Liftedness, Loc) + +data TyVarSol + = -- | Has been substituted with this. + TyVarSol Type + | -- | Is an explicit (rigid) type parameter in the source program. + TyVarParam Level Liftedness Loc + | -- | Not substituted yet; has this constraint. + TyVarUnsol TyVarInfo + deriving (Show) + +newtype SolverState = SolverState + { -- | Left means linked to this other type variable. + solverTyVars :: M.Map TyVar (Either VName TyVarSol) + } + +initialState :: TyParams -> TyVars -> SolverState +initialState typarams tyvars = SolverState $ M.map g typarams <> M.map f tyvars + where + f (_lvl, info) = Right $ TyVarUnsol info + g (lvl, l, loc) = Right $ TyVarParam lvl l loc + +substTyVar :: (Monoid u) => M.Map TyVar (Either VName TyVarSol) -> VName -> Maybe (TypeBase SComp u) +substTyVar m v = + case M.lookup v m of + Just (Left v') -> substTyVar m v' + Just (Right (TyVarSol t')) -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' + Just (Right TyVarParam {}) -> Nothing + Just (Right (TyVarUnsol {})) -> Nothing + Nothing -> Nothing + +maybeLookupTyVar :: TyVar -> SolveM (Maybe TyVarSol) +maybeLookupTyVar orig = do + tyvars <- gets solverTyVars + let f v = case M.lookup v tyvars of + Nothing -> pure Nothing + Just (Left v') -> f v' + Just (Right info) -> pure $ Just info + f orig + +lookupTyVar :: TyVar -> SolveM (Either TyVarInfo Type) +lookupTyVar orig = + maybe bad unpack <$> maybeLookupTyVar orig + where + bad = error $ "Unknown tyvar: " <> prettyNameString orig + unpack (TyVarParam {}) = error $ "Is a type param: " <> prettyNameString orig + unpack (TyVarSol t) = Right t + unpack (TyVarUnsol info) = Left info + +-- | Variable must be flexible. +lookupTyVarInfo :: TyVar -> SolveM TyVarInfo +lookupTyVarInfo v = do + r <- lookupTyVar v + case r of + Left info -> pure info + Right _ -> error $ "Tyvar is nonflexible: " <> prettyNameString v + +setLink :: TyVar -> VName -> SolveM () +setLink v info = modify $ \s -> s {solverTyVars = M.insert v (Left info) $ solverTyVars s} + +setInfo :: TyVar -> TyVarSol -> SolveM () +setInfo v info = modify $ \s -> s {solverTyVars = M.insert v (Right info) $ solverTyVars s} + +-- | A solution maps a type variable to its substitution. This +-- substitution is complete, in the sense there are no right-hand +-- sides that contain a type variable. +type Solution = M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) + +-- | An unconstrained type variable comprises a name and (ironically) +-- a constraint on how it can be instantiated. +type UnconTyVar = (VName, Liftedness) + +solution :: SolverState -> ([UnconTyVar], Solution) +solution s = + ( mapMaybe unconstrained $ M.toList $ solverTyVars s, + M.mapMaybe mkSubst $ solverTyVars s + ) + where + mkSubst (Right (TyVarSol t)) = + Just $ Right $ first (const ()) $ substTyVars (substTyVar (solverTyVars s)) t + mkSubst (Left v') = + Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ + mkSubst =<< M.lookup v' (solverTyVars s) + mkSubst (Right (TyVarUnsol (TyVarPrim _ pts))) = Just $ Left pts + mkSubst _ = Nothing + + unconstrained (v, Right (TyVarUnsol (TyVarFree _ l))) = Just (v, l) + unconstrained _ = Nothing + +newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} + deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError TypeError) + +-- Try to substitute as much information as we have. +enrichType :: Type -> SolveM Type +enrichType t = do + s <- get + pure $ substTyVars (substTyVar (solverTyVars s)) t + +typeError :: Loc -> Notes -> Doc () -> SolveM () +typeError loc notes msg = + throwError $ TypeError loc notes msg + +occursCheck :: Reason -> VName -> Type -> SolveM () +occursCheck reason v tp = do + vars <- gets solverTyVars + let tp' = substTyVars (substTyVar vars) tp + when (v `S.member` typeVars tp') . typeError (locOf reason) mempty $ + "Occurs check: cannot instantiate" + <+> prettyName v + <+> "with" + <+> pretty tp + <> "." + +unifySharedConstructors :: + Reason -> + BreadCrumbs -> + M.Map Name [Type] -> + M.Map Name [Type] -> + SolveM () +unifySharedConstructors reason bcs cs1 cs2 = + forM_ (M.toList $ M.intersectionWith (,) cs1 cs2) $ \(c, (ts1, ts2)) -> + if length ts1 == length ts2 + then zipWithM_ (solveEq reason $ matchingConstructor c <> bcs) ts1 ts2 + else + typeError (locOf reason) mempty $ + "Cannot unify type with constructor" + indent 2 (pretty (Sum (M.singleton c ts1))) + "with type of constructor" + indent 2 (pretty (Sum (M.singleton c ts2))) + "because they differ in arity." + +unifySharedFields :: + Reason -> + BreadCrumbs -> + M.Map Name Type -> + M.Map Name Type -> + SolveM () +unifySharedFields reason bcs fs1 fs2 = + forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(f, (ts1, ts2)) -> + solveEq reason (matchingField f <> bcs) ts1 ts2 + +mustSupportEql :: Reason -> Type -> SolveM () +mustSupportEql _reason _t = pure () + +scopeViolation :: Reason -> VName -> Type -> VName -> SolveM () +scopeViolation reason v1 ty v2 = + typeError (locOf reason) mempty $ + "Cannot unify type" + indent 2 (pretty ty) + "with" + <+> dquotes (prettyName v1) + <+> "(scope violation)." + "This is because" + <+> dquotes (prettyName v2) + <+> "is rigidly bound in a deeper scope." + +cannotUnify :: + Reason -> + Notes -> + BreadCrumbs -> + Type -> + Type -> + SolveM () +cannotUnify reason notes bcs t1 t2 = do + t1' <- enrichType t1 + t2' <- enrichType t2 + case reason of + Reason loc -> + typeError loc notes . stack $ + [ "Cannot unify", + indent 2 (pretty t1'), + "with", + indent 2 (pretty t2') + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonPatMatch loc pat value_t -> + typeError loc notes . stack $ + [ "Pattern", + indent 2 $ align $ pretty pat, + "cannot match value of type", + indent 2 $ align $ pretty value_t + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonAscription loc expected actual -> + typeError loc notes . stack $ + [ "Expression does not have expected type from type ascription.", + "Expected:" <+> align (pretty expected), + "Actual: " <+> align (pretty actual) + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonRetType loc expected actual -> do + expected' <- enrichType expected + actual' <- enrichType actual + typeError loc notes . stack $ + [ "Function body does not have expected type.", + "Expected:" <+> align (pretty expected'), + "Actual: " <+> align (pretty actual') + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonApply loc f e expected actual -> do + expected' <- enrichType expected + actual' <- enrichType actual + typeError loc notes . stack $ + [ header, + "Expected:" <+> align (pretty expected'), + "Actual: " <+> align (pretty actual') + ] + where + header = + case f of + Nothing -> + "Cannot apply function to" + <+> dquotes (shorten $ group $ pretty e) + <> " (invalid type)." + Just fname -> + "Cannot apply" + <+> dquotes (pretty fname) + <+> "to" + <+> dquotes (align $ shorten $ group $ pretty e) + <> " (invalid type)." + ReasonBranches loc former latter -> do + former' <- enrichType former + latter' <- enrichType latter + typeError loc notes . stack $ + [ "Branches differ in type.", + "Former:" <+> pretty former', + "Latter:" <+> pretty latter' + ] + +-- Precondition: 'v' is currently flexible. +subTyVar :: Reason -> BreadCrumbs -> VName -> Type -> SolveM () +subTyVar reason bcs v t = do + occursCheck reason v t + v_info <- gets $ M.lookup v . solverTyVars + + -- Set a solution for v, then update info for t in case v has any + -- odd constraints. + setInfo v (TyVarSol t) + + case (v_info, t) of + (Just (Right (TyVarUnsol TyVarFree {})), _) -> + pure () + ( Just (Right (TyVarUnsol (TyVarPrim _ v_pts))), + _ + ) -> + if t `elem` map (Scalar . Prim) v_pts + then pure () + else + typeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with" + indent 2 (pretty t) + ( Just (Right (TyVarUnsol (TyVarSum _ cs1))), + Scalar (Sum cs2) + ) -> + if all (`elem` M.keys cs2) (M.keys cs1) + then unifySharedConstructors reason bcs cs1 cs2 + else + typeError (locOf reason) mempty $ + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty (Sum cs2)) + ( Just (Right (TyVarUnsol (TyVarSum _ cs1))), + _ + ) -> + typeError (locOf reason) mempty $ + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty t) + ( Just (Right (TyVarUnsol (TyVarRecord _ fs1))), + Scalar (Record fs2) + ) -> + if all (`elem` M.keys fs2) (M.keys fs1) + then unifySharedFields reason bcs fs1 fs2 + else + typeError (locOf reason) mempty $ + "Cannot unify record type with fields" + indent 2 (pretty (Record fs1)) + "with record type" + indent 2 (pretty (Record fs2)) + ( Just (Right (TyVarUnsol (TyVarRecord _ fs1))), + _ + ) -> + typeError (locOf reason) mempty $ + "Cannot unify record type with fields" + indent 2 (pretty (Record fs1)) + "with type" + indent 2 (pretty t) + (Just (Right (TyVarUnsol (TyVarEql _))), _) -> + mustSupportEql reason t + -- + -- Internal error cases + (Just (Right TyVarSol {}), _) -> + error $ "Type variable already solved: " <> prettyNameString v + (Just (Right TyVarParam {}), _) -> + error $ "Cannot substitute type parameter: " <> prettyNameString v + (Just Left {}, _) -> + error $ "Type variable already linked: " <> prettyNameString v + (Nothing, _) -> + error $ "subTyVar: Nothing v: " <> prettyNameString v + +-- Precondition: 'v' and 't' are both currently flexible. +-- +-- The purpose of this function is to combine the partial knowledge we +-- may have about these two type variables. +unionTyVars :: Reason -> BreadCrumbs -> VName -> VName -> SolveM () +unionTyVars reason bcs v t = do + v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars + t_info <- lookupTyVarInfo t + + -- Insert the link from v to t, and then update the info of t based + -- on the existing info of v and t. + setLink v t + + case (v_info, t_info) of + ( TyVarUnsol (TyVarFree _ v_l), + TyVarFree t_loc t_l + ) + | v_l /= t_l -> + setInfo t $ TyVarUnsol $ TyVarFree t_loc (min v_l t_l) + -- When either is completely unconstrained. + (TyVarUnsol TyVarFree {}, _) -> + pure () + ( TyVarUnsol info, + TyVarFree {} + ) -> + setInfo t (TyVarUnsol info) + -- + -- TyVarPrim cases + ( TyVarUnsol info@TyVarPrim {}, + TyVarEql {} + ) -> + setInfo t (TyVarUnsol info) + ( TyVarUnsol (TyVarPrim _ v_pts), + TyVarPrim t_loc t_pts + ) -> + let pts = L.intersect v_pts t_pts + in if null pts + then + typeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be one of" + indent 2 (pretty t_pts) + else setInfo t (TyVarUnsol (TyVarPrim t_loc pts)) + ( TyVarUnsol (TyVarPrim _ v_pts), + TyVarRecord {} + ) -> + typeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be record." + ( TyVarUnsol (TyVarPrim _ v_pts), + TyVarSum {} + ) -> + typeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be sum." + -- + -- TyVarSum cases + ( TyVarUnsol (TyVarSum _ cs1), + TyVarSum loc cs2 + ) -> do + unifySharedConstructors reason bcs cs1 cs2 + let cs3 = cs1 <> cs2 + setInfo t (TyVarUnsol (TyVarSum loc cs3)) + ( TyVarUnsol TyVarSum {}, + TyVarPrim _ pts + ) -> + typeError (locOf reason) mempty $ + "A sum type cannot be one of" + indent 2 (pretty pts) + ( TyVarUnsol (TyVarSum _ cs1), + TyVarRecord _ fs + ) -> + typeError (locOf reason) mempty $ + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty (Scalar (Record fs))) + ( TyVarUnsol (TyVarSum _ cs1), + TyVarEql _ + ) -> + mapM_ (mapM_ (mustSupportEql reason)) cs1 + -- + -- TyVarRecord cases + ( TyVarUnsol (TyVarRecord _ fs1), + TyVarRecord loc fs2 + ) -> do + unifySharedFields reason bcs fs1 fs2 + let fs3 = fs1 <> fs2 + setInfo t (TyVarUnsol (TyVarRecord loc fs3)) + ( TyVarUnsol TyVarRecord {}, + TyVarPrim _ pts + ) -> + typeError (locOf reason) mempty $ + "A record type cannot be one of" + indent 2 (pretty pts) + ( TyVarUnsol (TyVarRecord _ fs1), + TyVarSum _ cs + ) -> + typeError (locOf reason) mempty $ + "Cannot unify record type" + indent 2 (pretty (Record fs1)) + "with type" + indent 2 (pretty (Scalar (Sum cs))) + ( TyVarUnsol (TyVarRecord _ fs1), + TyVarEql _ + ) -> + mapM_ (mustSupportEql reason) fs1 + -- + -- TyVarEql cases + (TyVarUnsol (TyVarEql _), TyVarPrim {}) -> + pure () + (TyVarUnsol (TyVarEql _), TyVarEql {}) -> + pure () + (TyVarUnsol (TyVarEql _), TyVarRecord _ fs) -> + mustSupportEql reason $ Scalar $ Record fs + (TyVarUnsol (TyVarEql _), TyVarSum _ cs) -> + mustSupportEql reason $ Scalar $ Sum cs + -- + -- Internal error cases + (TyVarSol {}, _) -> + alreadySolved + (TyVarParam {}, _) -> + isParam + where + unknown = error $ "unionTyVars: Nothing v: " <> prettyNameString v + alreadyLinked = error $ "Type variable already linked: " <> prettyNameString v + alreadySolved = error $ "Type variable already solved: " <> prettyNameString v + isParam = error $ "Type name is a type parameter: " <> prettyNameString v + +-- Unify at the root, emitting new equalities that must hold. +unify :: Type -> Type -> Either (Doc a) [(BreadCrumbs, (Type, Type))] +unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) + | pt1 == pt2 = Right [] +unify + (Scalar (TypeVar _ (QualName _ v1) targs1)) + (Scalar (TypeVar _ (QualName _ v2) targs2)) + | v1 == v2 = + Right $ mapMaybe f $ zip targs1 targs2 + where + f (TypeArgType t1, TypeArgType t2) = Just (mempty, (t1, t2)) + f _ = Nothing +unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = + Right [(mempty, (t1a, t2a)), (mempty, (t1r', t2r'))] + where + t1r' = t1r `setUniqueness` NoUniqueness + t2r' = t2r `setUniqueness` NoUniqueness +unify (Scalar (Record fs1)) (Scalar (Record fs2)) + | M.keys fs1 == M.keys fs2 = + Right $ + map (first matchingField) $ + M.toList $ + M.intersectionWith (,) fs1 fs2 + | Just n1 <- length <$> areTupleFields fs1, + Just n2 <- length <$> areTupleFields fs2, + n1 /= n2 = + Left $ + "Tuples have" + <+> pretty n1 + <+> "and" + <+> pretty n2 + <+> "elements respectively." + | otherwise = + let missing = + filter (`notElem` M.keys fs1) (M.keys fs2) + <> filter (`notElem` M.keys fs2) (M.keys fs1) + in Left $ + "Unshared fields:" <+> commasep (map pretty missing) <> "." +unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) + | M.keys cs1 == M.keys cs2 = + fmap concat . forM cs' $ \(c, (ts1, ts2)) -> do + if length ts1 == length ts2 + then Right $ zipWith (curry (matchingConstructor c,)) ts1 ts2 + else Left mempty + where + cs' = M.toList $ M.intersectionWith (,) cs1 cs2 +unify t1 t2 + | Just t1' <- peelArray 1 t1, + Just t2' <- peelArray 1 t2 = + Right [(mempty, (t1', t2'))] +unify _ _ = Left mempty + +solveEq :: Reason -> BreadCrumbs -> Type -> Type -> SolveM () +solveEq reason obcs orig_t1 orig_t2 = do + solveCt' (obcs, (orig_t1, orig_t2)) + where + solveCt' (bcs, (t1, t2)) = do + tyvars <- gets solverTyVars + let flexible v = case M.lookup v tyvars of + Just (Left v') -> flexible v' + Just (Right (TyVarUnsol _)) -> True + Just (Right TyVarSol {}) -> False + Just (Right TyVarParam {}) -> False + Nothing -> False + sub t@(Scalar (TypeVar u (QualName [] v) [])) = + case M.lookup v tyvars of + Just (Left v') -> sub $ Scalar (TypeVar u (QualName [] v') []) + Just (Right (TyVarSol t')) -> sub t' + _ -> t + sub t = t + case (sub t1, sub t2) of + ( t1'@(Scalar (TypeVar _ (QualName [] v1) [])), + t2'@(Scalar (TypeVar _ (QualName [] v2) [])) + ) + | v1 == v2 -> pure () + | otherwise -> + case (flexible v1, flexible v2) of + (False, False) -> cannotUnify reason mempty bcs t1 t2 + (True, False) -> subTyVar reason bcs v1 t2' + (False, True) -> subTyVar reason bcs v2 t1' + (True, True) -> unionTyVars reason bcs v1 v2 + (Scalar (TypeVar _ (QualName [] v1) []), t2') + | flexible v1 -> subTyVar reason bcs v1 t2' + (t1', Scalar (TypeVar _ (QualName [] v2) [])) + | flexible v2 -> subTyVar reason bcs v2 t1' + (t1', t2') -> case unify t1' t2' of + Left details -> cannotUnify reason (aNote details) bcs t1' t2' + Right eqs -> mapM_ solveCt' eqs + +solveCt :: Ct -> SolveM () +solveCt ct = + case ct of + CtEq reason t1 t2 -> solveEq reason mempty t1 t2 + CtAM {} -> pure () -- Good vibes only. + +scopeCheck :: Reason -> TyVar -> Int -> Type -> SolveM () +scopeCheck reason v v_lvl ty = do + mapM_ check $ typeVars ty + where + check ty_v = do + ty_v_info <- gets $ M.lookup ty_v . solverTyVars + case ty_v_info of + Just (Right (TyVarParam ty_v_lvl _ _)) + | ty_v_lvl > v_lvl -> scopeViolation reason v ty ty_v + _ -> pure () + +-- If a type variable has a liftedness constraint, we propagate that +-- constraint to its solution. The actual checking for correct usage +-- is done later. +liftednessCheck :: Liftedness -> Type -> SolveM () +liftednessCheck l (Scalar (TypeVar _ (QualName [] v) _)) = do + v_info <- maybeLookupTyVar v + case v_info of + Nothing -> + -- Is an opaque type. + pure () + Just (TyVarSol v_ty) -> + liftednessCheck l v_ty + Just TyVarParam {} -> pure () + Just (TyVarUnsol (TyVarFree loc v_l)) + | l /= v_l -> + setInfo v $ TyVarUnsol $ TyVarFree loc (min l v_l) + Just TyVarUnsol {} -> pure () +liftednessCheck _ (Scalar Prim {}) = pure () +liftednessCheck Lifted _ = pure () +liftednessCheck _ Array {} = pure () +liftednessCheck _ (Scalar Arrow {}) = pure () +liftednessCheck l (Scalar (Record fs)) = + mapM_ (liftednessCheck l) fs +liftednessCheck l (Scalar (Sum cs)) = + mapM_ (mapM_ $ liftednessCheck l) cs +liftednessCheck _ (Scalar TypeVar {}) = pure () + +solveTyVar :: (VName, (Level, TyVarInfo)) -> SolveM () +solveTyVar (tv, (_, TyVarRecord loc fs1)) = do + tv_t <- lookupTyVar tv + case tv_t of + Left _ -> + typeError loc mempty $ + "Type" + <+> prettyName tv + <+> "is ambiguous." + "Must be a record with fields" + indent 2 (pretty (Scalar (Record fs1))) + Right _ -> + pure () +solveTyVar (tv, (_, TyVarSum loc cs1)) = do + tv_t <- lookupTyVar tv + case tv_t of + Left _ -> + typeError loc mempty $ + "Type is ambiguous." + "Must be a sum type with constructors" + indent 2 (pretty (Scalar (Sum cs1))) + Right _ -> pure () +solveTyVar (tv, (_, TyVarEql loc)) = do + tv_t <- lookupTyVar tv + case tv_t of + Left TyVarEql {} -> + typeError loc mempty $ + "Type is ambiguous (must be equality type)" + "Add a type annotation to disambiguate the type." + Left _ -> pure () + Right ty + | orderZero ty -> pure () + | otherwise -> + typeError loc mempty $ + "Type" + indent 2 (align (pretty ty)) + "does not support equality (may contain function)." +solveTyVar (tv, (lvl, TyVarFree loc l)) = do + tv_t <- lookupTyVar tv + case tv_t of + Right ty -> do + scopeCheck (Reason loc) tv lvl ty + liftednessCheck l ty + _ -> pure () +solveTyVar (tv, (_, TyVarPrim loc pts)) = do + tv_t <- lookupTyVar tv + case tv_t of + Right ty + | ty `elem` map (Scalar . Prim) pts -> pure () + | otherwise -> + typeError loc mempty $ + "Numeric constant inferred to be of type" + indent 2 (align (pretty ty)) + "which is not possible." + _ -> pure () + +solve :: + Constraints -> + TyParams -> + TyVars -> + Either TypeError ([UnconTyVar], Solution) +solve constraints typarams tyvars = + second solution + . runExcept + . flip execStateT (initialState typarams tyvars) + . runSolveM + $ do + mapM_ solveCt constraints + mapM_ solveTyVar (M.toList tyvars) +{-# NOINLINE solve #-} diff --git a/src/Language/Futhark/TypeChecker/Consumption.hs b/src/Language/Futhark/TypeChecker/Consumption.hs index 672b4135fa..2b634e0c52 100644 --- a/src/Language/Futhark/TypeChecker/Consumption.hs +++ b/src/Language/Futhark/TypeChecker/Consumption.hs @@ -486,9 +486,10 @@ consumeAsNeeded loc (Scalar (Record fs1)) (Scalar (Record fs2)) = consumeAsNeeded loc pt t = when (diet pt == Consume) $ consumeAliases loc $ aliases t -checkArg :: [(Exp, TypeAliases)] -> ParamType -> Exp -> CheckM (Exp, TypeAliases) -checkArg prev p_t e = do - ((e', e_als), e_cons) <- contain $ checkExp e +checkArg :: [(Exp, TypeAliases)] -> ParamType -> AutoMap -> Exp -> CheckM (Exp, TypeAliases) +checkArg prev p_t am e = do + ((e', e_als), e_cons) <- + contain $ if autoRep am /= mempty then noAliases e else checkExp e consumed e_cons let e_t = typeOf e' when (e_cons /= mempty && not (orderZero e_t)) $ @@ -542,9 +543,11 @@ returnType appres (Scalar (Arrow _ v pd t1 (RetType dims t2))) Observe arg = returnType appres (Scalar (Sum cs)) d arg = Scalar $ Sum $ (fmap . fmap) (\et -> returnType appres et d arg) cs -applyArg :: TypeAliases -> TypeAliases -> TypeAliases -applyArg (Scalar (Arrow closure_als _ d _ (RetType _ rettype))) arg_als = - returnType closure_als rettype d arg_als +applyArg :: TypeAliases -> (AutoMap, TypeAliases) -> TypeAliases +applyArg (Scalar (Arrow closure_als _ d _ (RetType _ rettype))) (am, arg_als) = + if autoMap am /= mempty + then second (const mempty) rettype + else returnType closure_als rettype d arg_als applyArg t _ = error $ "applyArg: " <> show t boundFreeInExp :: Exp -> CheckM (M.Map VName TypeAliases) @@ -669,9 +672,9 @@ checkLoop loop_loc (param, arg, form, body) = do let param_t = patternType param' ((arg', arg_als), arg_cons) <- case arg of LoopInitImplicit (Info e) -> - contain $ first (LoopInitImplicit . Info) <$> checkArg [] param_t e + contain $ first (LoopInitImplicit . Info) <$> checkArg [] param_t mempty e LoopInitExplicit e -> - contain $ first LoopInitExplicit <$> checkArg [] param_t e + contain $ first LoopInitExplicit <$> checkArg [] param_t mempty e consumed arg_cons free_bound <- boundFreeInExp body @@ -692,7 +695,7 @@ checkLoop loop_loc (param, arg, form, body) = do `setAliases` S.singleton (AliasFree v) pure ( (param', arg', form', body'), - applyArg loopt arg_als `combineAliases` body_als + applyArg loopt (mempty, arg_als) `combineAliases` body_als ) checkFuncall :: @@ -700,7 +703,7 @@ checkFuncall :: SrcLoc -> Maybe (QualName VName) -> TypeAliases -> - f TypeAliases -> + f (AutoMap, TypeAliases) -> CheckM TypeAliases checkFuncall loc fname f_als arg_als = do v <- VName "internal_app_result" <$> incCounter @@ -714,15 +717,17 @@ checkExp :: Exp -> CheckM (Exp, TypeAliases) checkExp (AppExp (Apply f args loc) appres) = do (f', f_als) <- checkExp f (args', args_als) <- NE.unzip <$> checkArgs (toRes Nonunique f_als) args - res_als <- checkFuncall loc (fname f) f_als args_als + res_als <- + checkFuncall loc (fname f) f_als $ + NE.zip (fmap (snd . unInfo . fst) args') args_als pure (AppExp (Apply f' args' loc) appres, res_als) where fname (Var v _ _) = Just v fname (AppExp (Apply e _ _) _) = fname e fname _ = Nothing - checkArg' prev d (Info p, e) = do - (e', e_als) <- checkArg prev (second (const d) (typeOf e)) e - pure ((Info p, e'), e_als) + checkArg' prev d (Info (p, am), e) = do + (e', e_als) <- checkArg prev (second (const d) (typeOf e)) am e + pure ((Info (p, am), e'), e_als) checkArgs (Scalar (Arrow _ _ d _ (RetType _ rt))) (x NE.:| args') = do -- Note Futhark uses right-to-left evaluation of applications. @@ -813,10 +818,10 @@ checkExp (AppExp (LetFun fname (typarams, params, te, Info (RetType ext ret), fu -- checkExp (AppExp (BinOp (op, oploc) opt (x, xp) (y, yp) loc) appres) = do op_als <- observeVar (locOf oploc) (qualLeaf op) (unInfo opt) - let at1 : at2 : _ = fst $ unfoldFunType op_als - (x', x_als) <- checkArg [] at1 x - (y', y_als) <- checkArg [(x', x_als)] at2 y - res_als <- checkFuncall loc (Just op) op_als [x_als, y_als] + let (_, at1) : (_, at2) : _ = fst $ unfoldFunType op_als + (x', x_als) <- checkArg [] at1 mempty x + (y', y_als) <- checkArg [(x', x_als)] at2 mempty y + res_als <- checkFuncall loc (Just op) op_als [(mempty, x_als), (mempty, y_als)] pure ( AppExp (BinOp (op, oploc) opt (x', xp) (y', yp) loc) appres, res_als diff --git a/src/Language/Futhark/TypeChecker/Error.hs b/src/Language/Futhark/TypeChecker/Error.hs new file mode 100644 index 0000000000..d4fbc70aad --- /dev/null +++ b/src/Language/Futhark/TypeChecker/Error.hs @@ -0,0 +1,79 @@ +-- | Fundamental facilities for constructing type error messages. +module Language.Futhark.TypeChecker.Error + ( -- * Breadcrumbs + BreadCrumbs, + hasNoBreadCrumbs, + matchingField, + matchingConstructor, + matchingTypes, + matching, + ) +where + +import Futhark.Util.Pretty +import Language.Futhark + +-- | A piece of information that describes what process the type +-- checker currently performing. This is used to give better error +-- messages for unification errors. +data BreadCrumb + = MatchingTypes StructType StructType + | MatchingFields [Name] + | MatchingConstructor Name + | Matching (Doc ()) + +instance Pretty BreadCrumb where + pretty (MatchingTypes t1 t2) = + "When matching type" + indent 2 (pretty t1) + "with" + indent 2 (pretty t2) + pretty (MatchingFields fields) = + "When matching types of record field" + <+> dquotes (mconcat $ punctuate "." $ map pretty fields) + <> dot + pretty (MatchingConstructor c) = + "When matching types of constructor" <+> dquotes (pretty c) <> dot + pretty (Matching s) = + unAnnotate s + +-- | Unification failures can occur deep down inside complicated types +-- (consider nested records). We leave breadcrumbs behind us so we can +-- report the path we took to find the mismatch. When combining +-- breadcrumbs with the 'Semigroup' instance, put the innermost +-- breadcrumbs to the left. +newtype BreadCrumbs = BreadCrumbs [BreadCrumb] + +instance Semigroup BreadCrumbs where + BreadCrumbs (MatchingFields xs : bcs1) <> BreadCrumbs (MatchingFields ys : bcs2) = + BreadCrumbs $ MatchingFields (ys <> xs) : bcs1 <> bcs2 + BreadCrumbs bcs1 <> BreadCrumbs bcs2 = + BreadCrumbs $ bcs1 <> bcs2 + +instance Monoid BreadCrumbs where + mempty = BreadCrumbs [] + +-- | Is the path empty? +hasNoBreadCrumbs :: BreadCrumbs -> Bool +hasNoBreadCrumbs (BreadCrumbs []) = True +hasNoBreadCrumbs _ = False + +-- | Matching a record field. +matchingField :: Name -> BreadCrumbs +matchingField f = BreadCrumbs [MatchingFields [f]] + +-- | Matching two types. +matchingTypes :: StructType -> StructType -> BreadCrumbs +matchingTypes t1 t2 = BreadCrumbs [MatchingTypes t1 t2] + +-- | Matching a constructor. +matchingConstructor :: Name -> BreadCrumbs +matchingConstructor c = BreadCrumbs [MatchingConstructor c] + +-- | Matching anything. +matching :: Doc () -> BreadCrumbs +matching d = BreadCrumbs [Matching d] + +instance Pretty BreadCrumbs where + pretty (BreadCrumbs []) = mempty + pretty (BreadCrumbs bcs) = line <> stack (map pretty bcs) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs new file mode 100644 index 0000000000..24254d7392 --- /dev/null +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -0,0 +1,508 @@ +module Language.Futhark.TypeChecker.Rank + ( rankAnalysis, + rankAnalysis1, + ) +where + +import Control.Monad +import Control.Monad.Reader +import Control.Monad.State +import Data.Bifunctor +import Data.Functor.Identity +import Data.List qualified as L +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe +import Futhark.IR.Pretty () +import Futhark.Solve.GLPK +import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) +import Futhark.Solve.LP qualified as LP +import Futhark.Util (debugTraceM) +import Futhark.Util.Pretty +import Language.Futhark hiding (ScalarType) +import Language.Futhark.Traversals +import Language.Futhark.TypeChecker.Constraints +import Language.Futhark.TypeChecker.Monad +import System.IO.Unsafe + +type LSum = LP.LSum VName Int + +type Constraint = LP.Constraint VName Int + +type LinearProg = LP.LinearProg VName Int + +type ScalarType = ScalarTypeBase SComp NoUniqueness + +class Rank a where + rank :: a -> LSum + +instance Rank VName where + rank = var + +instance Rank SComp where + rank SDim = constant 1 + rank (SVar v) = var v + +instance Rank (Shape SComp) where + rank = foldr (\d r -> rank d ~+~ r) (constant 0) . shapeDims + +instance Rank ScalarType where + rank Prim {} = constant 0 + rank (TypeVar _ (QualName [] v) []) = var v + rank (TypeVar {}) = constant 0 + rank (Arrow {}) = constant 0 + rank (Record {}) = constant 0 + rank (Sum {}) = constant 0 + +instance Rank Type where + rank (Scalar t) = rank t + rank (Array _ shape t) = rank shape ~+~ rank t + +distribAndSplitArrows :: Ct -> [Ct] +distribAndSplitArrows (CtEq r t1 t2) = + splitArrows $ CtEq r (distribute t1) (distribute t2) + where + distribute :: TypeBase dim as -> TypeBase dim as + distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = + Scalar $ + Arrow + u + Unnamed + mempty + (arrayOf s ta) + (RetType rd $ distribute $ arrayOfWithAliases Nonunique s tr) + distribute t = t + + splitArrows + ( CtEq + reason + (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) + (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) + ) = + splitArrows (CtEq reason t1a t2a) ++ splitArrows (CtEq reason t1r' t2r') + where + t1r' = t1r `setUniqueness` NoUniqueness + t2r' = t2r `setUniqueness` NoUniqueness + splitArrows c = [c] +distribAndSplitArrows ct = [ct] + +distribAndSplitCnstrs :: Ct -> [Ct] +distribAndSplitCnstrs ct@(CtEq r t1 t2) = + ct : splitCnstrs (CtEq r (distribute1 t1) (distribute1 t2)) + where + distribute1 :: TypeBase dim as -> TypeBase dim as + distribute1 (Array u s (Record ts1)) = + Scalar $ Record $ fmap (arrayOfWithAliases u s) ts1 + distribute1 (Array u s (Sum cs)) = + Scalar $ Sum $ (fmap . fmap) (arrayOfWithAliases u s) cs + distribute1 t = t + + -- FIXME. Should check for key set equality here. + splitCnstrs (CtEq reason (Scalar (Record ts1)) (Scalar (Record ts2))) = + concat $ zipWith (\x y -> distribAndSplitCnstrs $ CtEq reason x y) (M.elems ts1) (M.elems ts2) + splitCnstrs (CtEq reason (Scalar (Sum cs1)) (Scalar (Sum cs2))) = + concat $ concat $ (zipWith . zipWith) (\x y -> distribAndSplitCnstrs $ CtEq reason x y) (M.elems cs1) (M.elems cs2) + splitCnstrs _ = [] +distribAndSplitCnstrs ct = [ct] + +data RankState = RankState + { rankBinVars :: Map VName VName, + rankCounter :: !Int, + rankConstraints :: [Constraint], + rankObj :: LSum + } + +newtype RankM a = RankM {runRankM :: State RankState a} + deriving (Functor, Applicative, Monad, MonadState RankState) + +incCounter :: RankM Int +incCounter = do + s <- get + put s {rankCounter = rankCounter s + 1} + pure $ rankCounter s + +binVar :: VName -> RankM VName +binVar sv = do + mbv <- gets ((M.!? sv) . rankBinVars) + case mbv of + Nothing -> do + bv <- VName ("b_" <> baseName sv) <$> incCounter + modify $ \s -> + s + { rankBinVars = M.insert sv bv $ rankBinVars s, + rankConstraints = rankConstraints s ++ [bin bv, var bv ~<=~ var sv] + } + pure bv + Just bv -> pure bv + +addConstraints :: [Constraint] -> RankM () +addConstraints cs = + modify $ \s -> s {rankConstraints = rankConstraints s ++ cs} + +addConstraint :: Constraint -> RankM () +addConstraint = addConstraints . pure + +addObj :: SVar -> RankM () +addObj sv = + modify $ \s -> s {rankObj = rankObj s ~+~ var sv} + +addCt :: Ct -> RankM () +addCt (CtEq _ t1 t2) = addConstraint $ rank t1 ~==~ rank t2 +addCt (CtAM _ r m f) = do + b_r <- binVar r + b_m <- binVar m + b_max <- VName "c_max" <$> incCounter + tr <- VName ("T_" <> baseName r) <$> incCounter + addConstraints [bin b_max, var b_max ~<=~ var tr] + addConstraints $ oneIsZero (b_r, r) (b_m, m) + addConstraints $ LP.max b_max (constant 0) (rank r ~-~ rank f) (var tr) + addObj m + addObj tr + +addTyVarInfo :: TyVar -> (Int, TyVarInfo) -> RankM () +addTyVarInfo _ (_, TyVarFree {}) = pure () +addTyVarInfo tv (_, TyVarPrim {}) = + addConstraint $ rank tv ~==~ constant 0 +addTyVarInfo tv (_, TyVarRecord {}) = + addConstraint $ rank tv ~==~ constant 0 +addTyVarInfo tv (_, TyVarSum {}) = + addConstraint $ rank tv ~==~ constant 0 +addTyVarInfo tv (_, TyVarEql {}) = + addConstraint $ rank tv ~==~ constant 0 + +mkLinearProg :: [Ct] -> TyVars -> LinearProg +mkLinearProg cs tyVars = + LP.LinearProg + { optType = Minimize, + objective = rankObj finalState, + -- let shape_vars = M.keys $ rankBinVars finalState + -- in foldr (\sv s -> var sv ~+~ s) (constant 0) shape_vars, + constraints = rankConstraints finalState + } + where + initState = + RankState + { rankBinVars = mempty, + rankCounter = 0, + rankConstraints = mempty, + rankObj = constant 0 + } + buildLP = do + mapM_ addCt cs + mapM_ (uncurry addTyVarInfo) $ M.toList tyVars + finalState = flip execState initState $ runRankM buildLP + +ambigCheckLinearProg :: LinearProg -> (Int, Map VName Int) -> LinearProg +ambigCheckLinearProg prog (opt, ranks) = + prog + { constraints = + constraints prog + -- https://yetanothermathprogrammingconsultant.blogspot.com/2011/10/integer-cuts.html + ++ [ lsum (var <$> M.keys one_bins) + ~-~ lsum (var <$> M.keys zero_bins) + ~<=~ constant (fromIntegral $ length one_bins) + ~-~ constant 1, + objective prog ~==~ constant (fromIntegral opt) + ] + } + where + -- We really need to track which variables are binary in the LinearProg + is_bin_var = ("b_" `L.isPrefixOf`) . baseString + one_bins = M.filterWithKey (\k v -> is_bin_var k && v == 1) ranks + zero_bins = M.filterWithKey (\k v -> is_bin_var k && v == 0) ranks + lsum = foldr (~+~) (constant 0) + +enumerateRankSols :: LinearProg -> [Map VName Int] +enumerateRankSols prog = + take 5 $ + takeSolns $ + iterate next_sol $ + (prog,) <$> run_glpk prog + where + run_glpk = unsafePerformIO . glpk + next_sol m = do + (prog', sol') <- m + guard (fst sol' /= 0) + let prog'' = ambigCheckLinearProg prog' sol' + sol'' <- run_glpk prog'' + pure (prog'', sol'') + takeSolns [] = [] + takeSolns (Nothing : _) = [] + takeSolns (Just (_, (_, r)) : xs) = r : takeSolns xs + +solveRankILP :: (MonadTypeChecker m) => SrcLoc -> LinearProg -> m [Map VName Int] +solveRankILP loc prog = do + debugTraceM 3 $ + unlines + [ "## solveRankILP", + prettyString prog + ] + case enumerateRankSols prog of + [] -> typeError loc mempty "Rank ILP cannot be solved." + rs -> do + debugTraceM 3 "## rank maps" + forM_ (zip [0 :: Int ..] rs) $ \(i, r) -> + debugTraceM 3 $ + unlines $ + "\n## rank map " <> prettyString i + : map prettyString (M.toList r) + pure rs + +rankAnalysis1 :: + (MonadTypeChecker m) => + SrcLoc -> + [Ct] -> + TyVars -> + M.Map TyVar Type -> + [Pat ParamType] -> + Exp -> + Maybe (TypeExp Exp VName) -> + m + ( ([Ct], M.Map TyVar Type, TyVars), + [Pat ParamType], + Exp, + Maybe (TypeExp Exp VName) + ) +rankAnalysis1 loc cs tyVars artificial params body retdecl = do + solutions <- rankAnalysis loc cs tyVars artificial params body retdecl + case solutions of + [sol] -> pure sol + sols -> do + let (_, _, bodies', _) = L.unzip4 sols + typeError loc mempty $ + stack $ + [ "Rank ILP is ambiguous.", + "Choices:" + ] + ++ map pretty bodies' + +rankAnalysis :: + (MonadTypeChecker m) => + SrcLoc -> + [Ct] -> + TyVars -> + M.Map TyVar Type -> + [Pat ParamType] -> + Exp -> + Maybe (TypeExp Exp VName) -> + m + [ ( ([Ct], M.Map TyVar Type, TyVars), + [Pat ParamType], + Exp, + Maybe (TypeExp Exp VName) + ) + ] +rankAnalysis _ [] tyVars artificial params body retdecl = + pure [(([], artificial, tyVars), params, body, retdecl)] +rankAnalysis loc cs tyVars artificial params body retdecl = do + debugTraceM 3 $ + unlines + [ "##rankAnalysis", + "cs:", + unlines $ map prettyString cs, + "cs':", + unlines $ map prettyString cs' + ] + rank_maps <- solveRankILP loc (mkLinearProg cs' tyVars) + cts_tyvars' <- mapM (substRankInfo cs artificial tyVars) rank_maps + let bodys = map (`updAM` body) rank_maps + params' = map ((`map` params) . updAMPat) rank_maps + retdecls = map ((<$> retdecl) . updAMTypeExp) rank_maps + pure $ L.zip4 cts_tyvars' params' bodys retdecls + where + cs' = + foldMap distribAndSplitCnstrs $ + foldMap distribAndSplitArrows cs + +type RankMap = M.Map VName Int + +substRankInfo :: + (MonadTypeChecker m) => + [Ct] -> + M.Map VName Type -> + TyVars -> + RankMap -> + m ([Ct], M.Map VName Type, TyVars) +substRankInfo cs artificial tyVars rankmap = do + ((cs', artificial', tyVars'), new_cs, new_tyVars) <- + runSubstT tyVars rankmap $ + (,,) <$> substRanks (filter (not . isCtAM) cs) <*> traverse substRanks artificial <*> traverse substRanks tyVars + pure (cs' <> new_cs, artificial', new_tyVars <> tyVars') + where + isCtAM (CtAM {}) = True + isCtAM _ = False + +runSubstT :: (MonadTypeChecker m) => TyVars -> RankMap -> SubstT m a -> m (a, [Ct], TyVars) +runSubstT tyVars rankmap (SubstT m) = do + let env = + SubstEnv + { envTyVars = tyVars, + envRanks = rankmap + } + + s = + SubstState + { substTyVars = mempty, + substNewVars = mempty, + substNewCts = mempty + } + (a, s') <- runReaderT (runStateT m s) env + pure (a, substNewCts s', substTyVars s') + +newtype SubstT m a = SubstT (StateT SubstState (ReaderT SubstEnv m) a) + deriving + ( Functor, + Applicative, + Monad, + MonadState SubstState, + MonadReader SubstEnv + ) + +data SubstEnv = SubstEnv + { envTyVars :: TyVars, + envRanks :: RankMap + } + +data SubstState = SubstState + { substTyVars :: TyVars, + substNewVars :: Map TyVar TyVar, + substNewCts :: [Ct] + } + +instance MonadTrans SubstT where + lift = SubstT . lift . lift + +newTyVar :: (MonadTypeChecker m) => TyVar -> SubstT m TyVar +newTyVar t = do + t' <- lift $ newTypeName (baseName t) + shape <- rankToShape t + loc <- asks ((locOf . snd . fromJust . (M.!? t)) . envTyVars) + modify $ \s -> + s + { substNewVars = M.insert t t' $ substNewVars s, + substNewCts = + substNewCts s + ++ [ CtEq + (Reason loc) + (Scalar (TypeVar mempty (QualName [] t) [])) + (arrayOf shape (Scalar (TypeVar mempty (QualName [] t') []))) + ] + } + pure t' + +rankToShape :: (Monad m) => VName -> SubstT m (Shape SComp) +rankToShape x = do + rs <- asks envRanks + pure $ Shape $ replicate (fromJust $ rs M.!? x) SDim + +addRankInfo :: (MonadTypeChecker m) => TyVar -> SubstT m () +addRankInfo t = do + rs <- asks envRanks + if fromMaybe 0 (rs M.!? t) == 0 + then pure () + else do + new_vars <- gets substNewVars + maybe new_var (const $ pure ()) $ new_vars M.!? t + where + new_var = do + t' <- newTyVar t + old_tyvars <- asks envTyVars + let (level, tvinfo) = fromJust $ old_tyvars M.!? t + l = case tvinfo of + TyVarFree _ tvinfo_l -> tvinfo_l + _ -> Unlifted + modify $ \s -> s {substTyVars = M.insert t' (level, tvinfo) $ substTyVars s} + modify $ \s -> s {substTyVars = M.insert t (level, TyVarFree (locOf tvinfo) l) $ substTyVars s} + +class SubstRanks a where + substRanks :: (MonadTypeChecker m) => a -> SubstT m a + +instance (SubstRanks a) => SubstRanks [a] where + substRanks = mapM substRanks + +instance SubstRanks (Shape SComp) where + substRanks = foldM (\s d -> (s <>) <$> instDim d) mempty + where + instDim SDim = pure $ Shape $ pure SDim + instDim (SVar x) = rankToShape x + +instance SubstRanks (TypeBase SComp u) where + substRanks t@(Scalar (TypeVar _ (QualName [] x) [])) = + addRankInfo x >> pure t + substRanks (Scalar (Arrow u p d ta (RetType retdims tr))) = do + ta' <- substRanks ta + tr' <- substRanks tr + pure $ Scalar (Arrow u p d ta' (RetType retdims tr')) + substRanks (Scalar (Record fs)) = + Scalar . Record <$> traverse substRanks fs + substRanks (Scalar (Sum cs)) = + Scalar . Sum <$> (traverse . traverse) substRanks cs + substRanks (Array u shape t) = do + shape' <- substRanks shape + t' <- substRanks $ Scalar t + pure $ arrayOfWithAliases u shape' t' + substRanks t = pure t + +instance SubstRanks Ct where + substRanks (CtEq r t1 t2) = CtEq r <$> substRanks t1 <*> substRanks t2 + substRanks _ = error "" + +instance SubstRanks TyVarInfo where + substRanks tv@TyVarFree {} = pure tv + substRanks tv@TyVarPrim {} = pure tv + substRanks (TyVarRecord loc fs) = + TyVarRecord loc <$> traverse substRanks fs + substRanks (TyVarSum loc cs) = + TyVarSum loc <$> (traverse . traverse) substRanks cs + substRanks tv@TyVarEql {} = pure tv + +instance SubstRanks (Int, TyVarInfo) where + substRanks (lvl, tv) = (lvl,) <$> substRanks tv + +updAM :: RankMap -> Exp -> Exp +updAM rank_map e = + case e of + AppExp (Apply f args loc) res -> + let f' = updAM rank_map f + args' = fmap (bimap (fmap $ second upd) (updAM rank_map)) args + in AppExp (Apply f' args' loc) res + AppExp (BinOp op t (x, Info (xv, xam)) (y, Info (yv, yam)) loc) res -> + AppExp (BinOp op t (updAM rank_map x, Info (xv, upd xam)) (updAM rank_map y, Info (yv, upd yam)) loc) res + OpSectionRight name t arg (Info (pa, t1a), Info (pb, t1b, argext, am)) t2 loc -> + OpSectionRight + name + t + (updAM rank_map arg) + (Info (pa, t1a), Info (pb, t1b, argext, upd am)) + t2 + loc + OpSectionLeft name t arg (Info (pa, t1a, argext, am), Info (pb, t1b)) (ret, retext) loc -> + OpSectionLeft + name + t + (updAM rank_map arg) + (Info (pa, t1a, argext, upd am), Info (pb, t1b)) + (ret, retext) + loc + _ -> runIdentity $ astMap mapper e + where + dimToRank (Var (QualName [] x) _ _) = + replicate (rank_map M.! x) (TupLit mempty mempty) + dimToRank e' = error $ prettyString e' + shapeToRank = Shape . foldMap dimToRank + upd (AutoMap r m f) = + AutoMap (shapeToRank r) (shapeToRank m) (shapeToRank f) + mapper = identityMapper {mapOnExp = pure . updAM rank_map} + +updAMPat :: RankMap -> Pat ParamType -> Pat ParamType +updAMPat rank_map p = runIdentity $ astMap m p + where + m = identityMapper {mapOnExp = pure . updAM rank_map} + +updAMTypeExp :: + RankMap -> + TypeExp Exp VName -> + TypeExp Exp VName +updAMTypeExp rank_map te = runIdentity $ astMap m te + where + m = identityMapper {mapOnExp = pure . updAM rank_map} diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 92aec1b91c..48f82ad5ea 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -26,7 +26,7 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S -import Futhark.Util (mapAccumLM, nubOrd) +import Futhark.Util (debugTraceM, mapAccumLM, nubOrd) import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Primitive (intByteSize) @@ -37,6 +37,7 @@ import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) import Language.Futhark.TypeChecker.Terms.Loop import Language.Futhark.TypeChecker.Terms.Monad import Language.Futhark.TypeChecker.Terms.Pat +import Language.Futhark.TypeChecker.Terms2 qualified as Terms2 import Language.Futhark.TypeChecker.Types import Language.Futhark.TypeChecker.Unify import Prelude hiding (mod) @@ -53,12 +54,6 @@ hasBinding e = isNothing $ astMap m e m = identityMapper {mapOnExp = \e' -> if hasBinding e' then Nothing else Just e'} -overloadedTypeVars :: Constraints -> Names -overloadedTypeVars = mconcat . map f . M.elems - where - f (_, HasFields _ fs _) = mconcat $ map typeVars $ M.elems fs - f _ = mempty - --- Basic checking -- | Determine if the two types are identical, ignoring uniqueness. @@ -167,8 +162,8 @@ sliceShape r slice t@(Array u (Shape orig_dims) et) = ( BinOp (qualName (intrinsicVar "-"), mempty) sizeBinOpInfo - (j, Info Nothing) - (i, Info Nothing) + (j, Info (Nothing, mempty)) + (i, Info (Nothing, mempty)) mempty ) $ Info @@ -181,8 +176,8 @@ sliceShape _ _ t = pure (t, []) checkAscript :: SrcLoc -> - TypeExp (ExpBase NoInfo VName) VName -> - ExpBase NoInfo VName -> + TypeExp Exp VName -> + Exp -> TermTypeM (TypeExp Exp VName, Exp) checkAscript loc te e = do (te', decl_t, _) <- checkTypeExpNonrigid te @@ -196,8 +191,8 @@ checkAscript loc te e = do checkCoerce :: SrcLoc -> - TypeExp (ExpBase NoInfo VName) VName -> - ExpBase NoInfo VName -> + TypeExp Exp VName -> + Exp -> TermTypeM (TypeExp Exp VName, StructType, Exp) checkCoerce loc te e = do (te', te_t, ext) <- checkTypeExpNonrigid te @@ -260,48 +255,33 @@ unscopeType :: unscopeType tloc unscoped = sizeFree tloc $ find (`elem` unscoped) . fvVars . freeInExp -checkExp :: ExpBase NoInfo VName -> TermTypeM Exp +checkExp :: Exp -> TermTypeM Exp +checkExp (Var qn (Info t) loc) = do + t' <- lookupVar loc qn t + pure $ Var qn (Info t') loc checkExp (Literal val loc) = pure $ Literal val loc -checkExp (Hole _ loc) = do - t <- newTypeVar loc "t" - pure $ Hole (Info t) loc +checkExp (Hole (Info t) loc) = do + t' <- replaceTyVars loc t + pure $ Hole (Info t') loc checkExp (StringLit vs loc) = pure $ StringLit vs loc -checkExp (IntLit val NoInfo loc) = do - t <- newTypeVar loc "t" - mustBeOneOf anyNumberType (mkUsage loc "integer literal") t - pure $ IntLit val (Info t) loc -checkExp (FloatLit val NoInfo loc) = do - t <- newTypeVar loc "t" - mustBeOneOf anyFloatType (mkUsage loc "float literal") t - pure $ FloatLit val (Info t) loc +checkExp (IntLit val (Info t) loc) = do + t' <- replaceTyVars loc t + pure $ IntLit val (Info t') loc +checkExp (FloatLit val (Info t) loc) = do + t' <- replaceTyVars loc t + pure $ FloatLit val (Info t') loc checkExp (TupLit es loc) = TupLit <$> mapM checkExp es <*> pure loc checkExp (RecordLit fs loc) = - RecordLit <$> evalStateT (mapM checkField fs) mempty <*> pure loc + RecordLit <$> mapM checkField fs <*> pure loc where - checkField (RecordFieldExplicit f e rloc) = do - errIfAlreadySet (unLoc f) rloc - modify $ M.insert (unLoc f) rloc - RecordFieldExplicit f <$> lift (checkExp e) <*> pure rloc - checkField (RecordFieldImplicit name NoInfo rloc) = do - errIfAlreadySet (baseName (unLoc name)) rloc - t <- lift $ lookupVar rloc $ qualName $ unLoc name - modify $ M.insert (baseName (unLoc name)) rloc - pure $ RecordFieldImplicit name (Info t) rloc - - errIfAlreadySet f rloc = do - maybe_sloc <- gets $ M.lookup f - case maybe_sloc of - Just sloc -> - lift . typeError rloc mempty $ - "Field" - <+> dquotes (pretty f) - <+> "previously defined at" - <+> pretty (locStrRel rloc sloc) - <> "." - Nothing -> pure () + checkField (RecordFieldExplicit f e rloc) = + RecordFieldExplicit f <$> checkExp e <*> pure rloc + checkField (RecordFieldImplicit name (Info t) rloc) = do + t' <- lookupVar rloc (qualName (unLoc name)) t + pure $ RecordFieldImplicit name (Info t') rloc -- No need to type check this, as these are only produced by the -- parser if the elements are monomorphic and all match. checkExp (ArrayVal vs t loc) = @@ -316,15 +296,17 @@ checkExp (ArrayLit all_es _ loc) = [] -> do et <- newTypeVar loc "t" t <- arrayOfM loc et (Shape [sizeFromInteger 0 mempty]) + mustBeUnlifted (locOf loc) et pure $ ArrayLit [] (Info t) loc e : es -> do e' <- checkExp e et <- expType e' es' <- mapM (unifies "type of first array element" et <=< checkExp) es t <- arrayOfM loc et (Shape [sizeFromInteger (genericLength all_es) mempty]) + mustBeUnlifted (locOf loc) et pure $ ArrayLit (e' : es') (Info t) loc checkExp (AppExp (Range start maybe_step end loc) _) = do - start' <- require "use in range expression" anySignedType =<< checkExp start + start' <- checkExp start start_t <- expType start' maybe_step' <- case maybe_step of Nothing -> pure Nothing @@ -388,8 +370,8 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do ( BinOp (qualName (intrinsicVar op), mempty) sizeBinOpInfo - (x, Info Nothing) - (y, Info Nothing) + (x, Info (Nothing, mempty)) + (y, Info (Nothing, mempty)) mempty ) (Info $ AppRes t []) @@ -401,54 +383,57 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do checkExp (Ascript e te loc) = do (te', e') <- checkAscript loc te e pure $ Ascript e' te' loc -checkExp (Coerce e te NoInfo loc) = do +checkExp (Coerce e te _ loc) = do (te', te_t, e') <- checkCoerce loc te e t <- expTypeFully e' t' <- matchDims (const . const pure) t te_t pure $ Coerce e' te' (Info t') loc -checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do - ftype <- lookupVar oploc op +checkExp (AppExp (Apply fe args loc) _) = do + fe' <- checkExp fe + let ams = fmap (snd . unInfo . fst) args + args' <- mapM (checkExp . snd) args + t <- expType fe' + let fname = + case fe' of + Var v _ _ -> Just v + _ -> Nothing + ((_, exts, rt), args'') <- mapAccumLM (onArg fname) (0, [], t) (NE.zip args' ams) + + pure $ AppExp (Apply fe' args'' loc) $ Info $ AppRes rt exts + where + onArg fname (i, all_exts, t) (arg', am) = do + (_, rt, argext, exts, am') <- checkApply loc (fname, i) t arg' am + pure + ( (i + 1, all_exts <> exts, rt), + (Info (argext, am'), arg') + ) +checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, Info (_, xam)) (e2, Info (_, yam)) loc) _) = do + ftype <- lookupVar oploc op op_t e1' <- checkExp e1 e2' <- checkExp e2 - -- Note that the application to the first operand cannot fix any -- existential sizes, because it must by necessity be a function. - (_, rt, p1_ext, _) <- checkApply loc (Just op, 0) ftype e1' - (_, rt', p2_ext, retext) <- checkApply loc (Just op, 1) rt e2' + (_, rt, p1_ext, _, am1) <- checkApply loc (Just op, 0) ftype e1' xam + (_, rt', p2_ext, retext, am2) <- checkApply loc (Just op, 1) rt e2' yam pure $ AppExp ( BinOp (op, oploc) (Info ftype) - (e1', Info p1_ext) - (e2', Info p2_ext) + (e1', Info (p1_ext, am1)) + (e2', Info (p2_ext, am2)) loc ) (Info (AppRes rt' retext)) -checkExp (Project k e NoInfo loc) = do +checkExp (Project k e _ loc) = do e' <- checkExp e t <- expType e' - kt <- mustHaveField (mkUsage loc $ docText $ "projection of field " <> dquotes (pretty k)) k t - pure $ Project k e' (Info kt) loc -checkExp (AppExp (If e1 e2 e3 loc) _) = do - e1' <- checkExp e1 - e2' <- checkExp e2 - e3' <- checkExp e3 - - let bool = Scalar $ Prim Bool - e1_t <- expType e1' - onFailure (CheckingRequired [bool] e1_t) $ - unify (mkUsage e1' "use as 'if' condition") bool e1_t - - (brancht, retext) <- unifyBranches loc e2' e3' - - zeroOrderType - (mkUsage loc "returning value of this type from 'if' expression") - "type returned from branch" - brancht - - pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes brancht retext) + case t of + Scalar (Record fs) + | Just kt <- M.lookup k fs -> + pure $ Project k e' (Info kt) loc + _ -> error $ "checkExp Project: " <> show t checkExp (Parens e loc) = Parens <$> checkExp e <*> pure loc checkExp (QualParens (modname, modnameloc) e loc) = do @@ -460,33 +445,12 @@ checkExp (QualParens (modname, modnameloc) e loc) = do ModFun {} -> typeError loc mempty . withIndexLink "module-is-parametric" $ "Module" <+> pretty modname <+> " is a parametric module." -checkExp (Var qn NoInfo loc) = do - t <- lookupVar loc qn - pure $ Var qn (Info t) loc checkExp (Negate arg loc) = do - arg' <- require "numeric negation" anyNumberType =<< checkExp arg + arg' <- checkExp arg pure $ Negate arg' loc checkExp (Not arg loc) = do - arg' <- require "logical negation" (Bool : anyIntType) =<< checkExp arg + arg' <- checkExp arg pure $ Not arg' loc -checkExp (AppExp (Apply fe args loc) NoInfo) = do - fe' <- checkExp fe - args' <- mapM (checkExp . snd) args - t <- expType fe' - let fname = - case fe' of - Var v _ _ -> Just v - _ -> Nothing - ((_, exts, rt), args'') <- mapAccumLM (onArg fname) (0, [], t) args' - - pure $ AppExp (Apply fe' args'' loc) $ Info $ AppRes rt exts - where - onArg fname (i, all_exts, t) arg' = do - (_, rt, argext, exts) <- checkApply loc (fname, i) t arg' - pure - ( (i + 1, all_exts <> exts, rt), - (Info argext, arg') - ) checkExp (AppExp (LetPat sizes pat e body loc) _) = do e' <- checkExp e @@ -515,7 +479,7 @@ checkExp (AppExp (LetPat sizes pat e body loc) _) = do AppExp (LetPat sizes (fmap toStruct pat') e' body' loc) (Info $ AppRes body_t' retext) -checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body loc) _) = do +checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, _, e) body loc) _) = do (tparams', params', maybe_retdecl', rettype, e') <- checkBinding (name, maybe_retdecl, tparams, params, e, loc) @@ -538,16 +502,18 @@ checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body l ) (Info $ AppRes body_t ext) checkExp (AppExp (LetWith dest src slice ve body loc) _) = do - src' <- checkIdent src + src_t <- lookupVar loc (qualName (identName src)) (unInfo (identType src)) + let src' = src {identType = Info src_t} + dest' = dest {identType = Info src_t} slice' <- checkSlice slice - (t, _) <- newArrayType (mkUsage src "type of source array") "src" $ sliceDims slice' + (t, _) <- newArrayType (mkUsage src' "type of source array") "src" $ sliceDims slice' unify (mkUsage loc "type of target array") t $ unInfo $ identType src' (elemt, _) <- sliceShape (Just (loc, Nonrigid)) slice' =<< normTypeFully t ve' <- unifies "type of target array" elemt =<< checkExp ve - bindingIdent dest (unInfo (identType src')) $ \dest' -> do + bindingIdent dest' $ do body' <- checkExp body (body_t, ext) <- unscopeType loc [identName dest'] =<< expTypeFully body' pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes body_t ext) @@ -562,11 +528,9 @@ checkExp (Update src slice ve loc) = do -- Record updates are a bit hacky, because we do not have row typing -- (yet?). For now, we only permit record updates where we know the -- full type up to the field we are updating. -checkExp (RecordUpdate src fields ve NoInfo loc) = do +checkExp (RecordUpdate src fields ve _ loc) = do src' <- checkExp src ve' <- checkExp ve - a <- expTypeFully src' - foldM_ (flip $ mustHaveField usage) a fields ve_t <- expType ve' updated_t <- updateField fields ve_t =<< expTypeFully src' pure $ RecordUpdate src' fields ve' (Info updated_t) loc @@ -598,30 +562,35 @@ checkExp (AppExp (Index e slice loc) _) = do =<< expTypeFully e' pure $ AppExp (Index e' slice' loc) (Info $ AppRes t' retext) -checkExp (Assert e1 e2 NoInfo loc) = do - e1' <- require "being asserted" [Bool] =<< checkExp e1 +checkExp (Assert e1 e2 _ loc) = do + e1' <- checkExp e1 e2' <- checkExp e2 pure $ Assert e1' e2' (Info (prettyText e1)) loc -checkExp (Lambda params body rettype_te NoInfo loc) = do +checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do (params', body', rettype', RetType dims ty) <- incLevel . bindingParams [] params $ \params' -> do + rt' <- replaceTyVars loc rt rettype_checked <- traverse checkTypeExpNonrigid rettype_te - let declared_rettype = - case rettype_checked of - Just (_, st, _) -> Just st - Nothing -> Nothing + declared_rettype <- + case rettype_checked of + Just (_, st, _) -> do + unify (mkUsage body "lambda return type ascription") (toStruct rt') (toStruct st) + pure $ Just st + Nothing -> pure Nothing body' <- checkFunBody params' body declared_rettype loc body_t <- expTypeFully body' + unify (mkUsage body "inferred return type") (toStruct rt') body_t + params'' <- mapM updateTypes params' - (rettype', rettype_st) <- - case rettype_checked of - Just (te, st, ext) -> - pure (Just te, RetType ext st) - Nothing -> do - ret <- inferReturnSizes params'' $ toRes Nonunique body_t - pure (Nothing, ret) + (rettype', rettype_st) <- case rettype_checked of + Just (te, ret, ext) -> do + ret' <- normTypeFully ret + pure (Just te, RetType ext ret') + Nothing -> do + ret <- inferReturnSizes params'' $ toRes Nonunique body_t + pure (Nothing, ret) pure (params'', body', rettype', rettype_st) @@ -654,37 +623,38 @@ checkExp (Lambda params body rettype_te NoInfo loc) = do onDim _ = mempty pure $ RetType (S.toList $ foldMap onDim $ fvVars $ freeInType ret) ret -checkExp (OpSection op _ loc) = do - ftype <- lookupVar loc op +checkExp (OpSection op (Info op_t) loc) = do + ftype <- lookupVar loc op op_t pure $ OpSection op (Info ftype) loc -checkExp (OpSectionLeft op _ e _ _ loc) = do - ftype <- lookupVar loc op +checkExp (OpSectionLeft op (Info op_t) e (Info (_, _, _, am), _) _ loc) = do + ftype <- lookupVar loc op op_t e' <- checkExp e - (t1, rt, argext, retext) <- checkApply loc (Just op, 0) ftype e' + (t1, rt, argext, retext, am') <- checkApply loc (Just op, 0) ftype e' am case (ftype, rt) of - (Scalar (Arrow _ m1 d1 _ _), Scalar (Arrow _ m2 d2 t2 rettype)) -> + (Scalar (Arrow _ m1 d1 _ _), Scalar (Arrow _ m2 d2 t2 (RetType ds rt2))) -> pure $ OpSectionLeft op (Info ftype) e' - (Info (m1, toParam d1 t1, argext), Info (m2, toParam d2 t2)) - (Info rettype, Info retext) + (Info (m1, toParam d1 t1, argext, am'), Info (m2, toParam d2 t2)) + (Info $ RetType ds $ arrayOfWithAliases (uniqueness rt2) (autoFrame am') rt2, Info retext) loc _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype -checkExp (OpSectionRight op _ e _ NoInfo loc) = do - ftype <- lookupVar loc op +checkExp (OpSectionRight op (Info op_t) e (_, Info (_, _, _, am)) _ loc) = do + ftype <- lookupVar loc op op_t e' <- checkExp e case ftype of Scalar (Arrow _ m1 d1 t1 (RetType [] (Scalar (Arrow _ m2 d2 t2 (RetType dims2 ret))))) -> do - (t2', arrow', argext, _) <- + (t2', arrow', argext, _, am') <- checkApply loc (Just op, 1) (Scalar $ Arrow mempty m2 d2 t2 $ RetType [] $ Scalar $ Arrow Nonunique m1 d1 t1 $ RetType dims2 ret) e' + am case arrow' of Scalar (Arrow _ _ _ t1' (RetType dims2' ret')) -> pure $ @@ -692,20 +662,22 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do op (Info ftype) e' - (Info (m1, toParam d1 t1'), Info (m2, toParam d2 t2', argext)) - (Info $ RetType dims2' ret') + (Info (m1, toParam d1 t1'), Info (m2, toParam d2 t2', argext, am')) + (Info $ RetType dims2' $ arrayOfWithAliases (uniqueness ret') (autoFrame am') ret') loc _ -> error $ "OpSectionRight: impossible type\n" <> prettyString arrow' _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype -checkExp (ProjectSection fields NoInfo loc) = do - a <- newTypeVar loc "a" - let usage = mkUsage loc "projection at" - b <- foldM (flip $ mustHaveField usage) a fields - let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ toRes Nonunique b - pure $ ProjectSection fields (Info ft) loc -checkExp (IndexSection slice NoInfo loc) = do +checkExp (ProjectSection fields (Info t) loc) = do + t' <- replaceTyVars loc t + case t' of + Scalar (Arrow _ _ _ t'' (RetType _ rt)) + | Just ft <- recordField fields t'' -> + unify (mkUsage loc "result of projection") ft $ toStruct rt + _ -> error $ "checkExp ProjectSection: " <> show t' + pure $ ProjectSection fields (Info t') loc +checkExp (IndexSection slice _ loc) = do slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage' loc) "e" $ sliceDims slice' (t', retext) <- sliceShape Nothing slice' t @@ -718,27 +690,57 @@ checkExp (AppExp (Loop _ mergepat loopinit form loopbody loc) _) = do AppExp (Loop sparams mergepat' loopinit' form' loopbody' loc) (Info appres) -checkExp (Constr name es NoInfo loc) = do - t <- newTypeVar loc "t" +checkExp (Constr name es (Info t) loc) = do + t' <- replaceTyVars loc t es' <- mapM checkExp es - ets <- mapM expType es' - mustHaveConstr (mkUsage loc "use of constructor") name t ets - pure $ Constr name es' (Info t) loc + case t' of + Scalar (Sum cs) + | Just name_ts <- M.lookup name cs -> + zipWithM_ (unify $ mkUsage loc "inferred variant") name_ts $ + map typeOf es' + _ -> + error $ "checkExp Constr: " <> prettyString t' + pure $ Constr name es' (Info t') loc +checkExp (AppExp (If e1 e2 e3 loc) _) = do + e1' <- checkExp e1 + e2' <- checkExp e2 + e3' <- checkExp e3 + + let bool = Scalar $ Prim Bool + e1_t <- expType e1' + onFailure (CheckingRequired [bool] e1_t) $ + unify (mkUsage e1' "use as 'if' condition") bool e1_t + + (t, retext) <- unifyBranches loc e2' e3' + + mustBeOrderZero (locOf loc) t + + pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes t retext) checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e mt <- expType e' (cs', t, retext) <- checkCases mt cs - zeroOrderType - (mkUsage loc "being returned 'match'") - "type returned from pattern match" - t + + mustBeOrderZero (locOf loc) t + pure $ AppExp (Match e' cs' loc) (Info $ AppRes t retext) checkExp (Attr info e loc) = Attr <$> checkAttr info <*> checkExp e <*> pure loc +checkCase :: + StructType -> + CaseBase Info VName -> + TermTypeM (CaseBase Info VName, StructType, [VName]) +checkCase mt (CasePat p e loc) = + bindingPat [] p mt $ \p' -> do + e' <- checkExp e + e_t <- expTypeFully e' + (e_t', retext) <- unscopeType loc (patNames p') e_t + pure (CasePat (fmap toStruct p') e' loc, e_t', retext) + checkCases :: StructType -> - NE.NonEmpty (CaseBase NoInfo VName) -> + NE.NonEmpty (CaseBase Info VName) -> TermTypeM (NE.NonEmpty (CaseBase Info VName), StructType, [VName]) checkCases mt rest_cs = case NE.uncons rest_cs of @@ -751,17 +753,6 @@ checkCases mt rest_cs = (brancht, retext) <- unifyBranchTypes (srclocOf c) c_t cs_t pure (NE.cons c' cs', brancht, retext) -checkCase :: - StructType -> - CaseBase NoInfo VName -> - TermTypeM (CaseBase Info VName, StructType, [VName]) -checkCase mt (CasePat p e loc) = - bindingPat [] p mt $ \p' -> do - e' <- checkExp e - e_t <- expTypeFully e' - (e_t', retext) <- unscopeType loc (patNames p') e_t - pure (CasePat (fmap toStruct p') e' loc, e_t', retext) - -- | An unmatched pattern. Used in in the generation of -- unmatched pattern warnings by the type checker. data Unmatched p @@ -790,22 +781,13 @@ instance Pretty (Unmatched (Pat StructType)) where pretty' (PatLit e _ _) = pretty e pretty' (PatConstr n _ ps _) = "#" <> pretty n <+> sep (map pretty' ps) -checkIdent :: IdentBase NoInfo VName StructType -> TermTypeM (Ident StructType) -checkIdent (Ident name _ loc) = do - vt <- lookupVar loc $ qualName name - pure $ Ident name (Info vt) loc - -checkSlice :: SliceBase NoInfo VName -> TermTypeM [DimIndex] +checkSlice :: SliceBase Info VName -> TermTypeM [DimIndex] checkSlice = mapM checkDimIndex where - checkDimIndex (DimFix i) = do - DimFix <$> (require "use as index" anySignedType =<< checkExp i) + checkDimIndex (DimFix i) = + DimFix <$> checkExp i checkDimIndex (DimSlice i j s) = - DimSlice <$> check i <*> check j <*> check s - - check = - maybe (pure Nothing) $ - fmap Just . unifies "use as index" (Scalar $ Prim $ Signed Int64) <=< checkExp + DimSlice <$> traverse checkExp i <*> traverse checkExp j <*> traverse checkExp s -- The number of dimensions affected by this slice (so the minimum -- rank of the array we are slicing). @@ -868,16 +850,55 @@ dimUses = flip execState mempty . traverseDims f where fv = freeInExp e `freeWithout` bound +splitArrayAt :: Int -> StructType -> (Shape Size, StructType) +splitArrayAt x t = + (Shape $ take x $ shapeDims $ arrayShape t, stripArray x t) + checkApply :: SrcLoc -> ApplyOp -> StructType -> Exp -> - TermTypeM (StructType, StructType, Maybe VName, [VName]) -checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do + AutoMap -> + TermTypeM (StructType, StructType, Maybe VName, [VName], AutoMap) +checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = do let argtype = typeOf argexp onFailure (CheckingApply fname argexp tp1 argtype) $ do - unify (mkUsage argexp "use as function argument") tp1 argtype + -- argtype = arg_frame argtype' + -- tp1 = f_frame tp1' + -- + -- Rep case: + -- R arg_frame argtype' = f_frame tp1' + -- ==> R = (autoRepRank am)-length prefix of tp1 + -- ==> frame = f_frame = (autoFrameRank am)-length prefix of tp1 + -- + -- Map case: + -- arg_frame argtype' = M f_frame tp1' + -- ==> M = (autoMapRank am)-length prefix of argtype + -- ==> frame = M f_frame = (autoFrameRank am)-length prefix of argtype + (am_map_shape, argtype_with_frame) <- splitArrayAt (autoMapRank am) <$> normTypeFully argtype + (am_rep_shape, tp1_with_frame) <- splitArrayAt (autoRepRank am) <$> normTypeFully tp1 + (am_frame_shape, _) <- + if autoMapRank am == 0 + then splitArrayAt (autoFrameRank am) <$> normTypeFully tp1 + else splitArrayAt (autoFrameRank am) <$> normTypeFully argtype + + debugTraceM 3 $ + unlines + [ "## checkApply", + "## fn", + prettyString fn, + "## ft", + prettyString ft, + "## tp1_with_frame", + prettyString tp1_with_frame, + "## argtype_with_frame", + prettyString argtype_with_frame, + "## am", + show am + ] + + unify (mkUsage argexp "use as function argument") tp1_with_frame argtype_with_frame -- Perform substitutions of instantiated variables in the types. (tp2', ext) <- instantiateDimsInReturnType loc fname =<< normTypeFully tp2 @@ -917,67 +938,60 @@ checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do in pure (Nothing, applySubst parsubst $ toStruct tp2') _ -> pure (Nothing, toStruct tp2') - pure (tp1, tp2'', argext, ext) -checkApply loc fname tfun@(Scalar TypeVar {}) arg = do - tv <- newTypeVar loc "b" - unify (mkUsage loc "use as function") tfun $ - Scalar (Arrow mempty Unnamed Observe (typeOf arg) $ RetType [] $ paramToRes tv) - tfun' <- normType tfun - checkApply loc fname tfun' arg -checkApply loc (fname, prev_applied) ftype argexp = do - let fname' = maybe "expression" (dquotes . pretty) fname - - typeError loc mempty $ - if prev_applied == 0 - then - "Cannot apply" - <+> fname' - <+> "as function, as it has type:" - indent 2 (pretty ftype) - else - "Cannot apply" - <+> fname' - <+> "to argument #" - <> pretty (prev_applied + 1) - <+> dquotes (shorten $ group $ pretty argexp) - <> "," - "as" - <+> fname' - <+> "only takes" - <+> pretty prev_applied - <+> arguments - <> "." + let am' = + AutoMap + { autoRep = am_rep_shape, + autoMap = am_map_shape, + autoFrame = am_frame_shape + } + + pure (tp1, distribute (arrayOf (autoMap am') tp2''), argext, ext, am') where - arguments - | prev_applied == 1 = "argument" - | otherwise = "arguments" + distribute :: TypeBase dim u -> TypeBase dim u + distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = + Scalar $ + Arrow + u + Unnamed + mempty + (arrayOf s ta) + (RetType rd $ distribute (arrayOfWithAliases (uniqueness tr) s tr)) + distribute t = t +checkApply _ _ _ _ _ = + error "checkApply: array" -- | Type-check a single expression in isolation. This expression may -- turn out to be polymorphic, in which case the list of type -- parameters will be non-empty. checkOneExp :: ExpBase NoInfo VName -> TypeM ([TypeParam], Exp) -checkOneExp e = runTermTypeM checkExp $ do - e' <- checkExp e - let t = typeOf e' - (tparams, _, _) <- - letGeneralise (nameFromString "") (srclocOf e) [] [] $ toRes Nonunique t - fixOverloadedTypes $ typeVars t - e'' <- normTypeFully e' - localChecks e'' - causalityCheck e'' - pure (tparams, e'') +checkOneExp e = do + (maybe_tysubsts, e') <- Terms2.checkSingleExp e + case maybe_tysubsts of + Left err -> throwError err + Right (_generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do + e'' <- checkExp e' + let t = typeOf e'' + (tparams, _, _) <- + letGeneralise (nameFromString "") (srclocOf e) [] [] $ toRes Nonunique t + fixOverloadedTypes $ typeVars t + e''' <- normTypeFully e'' + localChecks e''' + causalityCheck e''' + pure (tparams, e''') -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp -checkSizeExp e = runTermTypeM checkExp $ do - e' <- checkExp e - let t = typeOf e' - when (hasBinding e') $ - typeError (srclocOf e') mempty . withIndexLink "size-expression-bind" $ - "Size expression with binding is forbidden." - unify (mkUsage e' "Size expression") t (Scalar (Prim (Signed Int64))) - normTypeFully e' +checkSizeExp e = do + (maybe_tysubsts, e') <- Terms2.checkSizeExp e + case maybe_tysubsts of + Left err -> throwError err + Right (_generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do + e'' <- checkExp e' + when (hasBinding e'') $ + typeError (srclocOf e'') mempty . withIndexLink "size-expression-bind" $ + "Size expression with binding is forbidden." + normTypeFully e'' -- Verify that all sum type constructors and empty array literals have -- a size that is known (rigid or a type parameter). This is to @@ -1047,7 +1061,7 @@ causalityCheck binding_body = do seqArgs known' [] = do void $ onExp known' f modify (S.fromList (appResExt res) <>) - seqArgs known' ((Info p, x) : xs) = do + seqArgs known' ((Info (p, _), x) : xs) = do new_known <- collectingNewKnown $ onExp known' x void $ seqArgs (new_known <> known') xs modify ((new_known <> S.fromList (maybeToList p)) <>) @@ -1066,7 +1080,7 @@ causalityCheck binding_body = do modify (new_known <>) onExp known - e@(AppExp (BinOp (f, floc) ft (x, Info xp) (y, Info yp) _) (Info res)) = do + e@(AppExp (BinOp (f, floc) ft (x, Info (xp, _)) (y, Info (yp, _)) _) (Info res)) = do args_known <- collectingNewKnown $ sequencePoint known x y $ catMaybes [xp, yp] void $ onExp (args_known <> known) (Var f ft floc) @@ -1171,14 +1185,6 @@ localChecks = void . check e <$ case ty of Info (Scalar (Prim t)) -> errorBounds (inBoundsI (-x) t) (-x) t (loc1 <> loc2) _ -> error "Inferred type of int literal is not a number" - check e@(AppExp (BinOp (QualName [] v, _) _ (x, _) _ loc) _) - | baseName v == "==", - Array {} <- typeOf x, - baseTag v <= maxIntrinsicTag = do - warn loc $ - textwrap - "Comparing arrays with \"==\" is deprecated and will stop working in a future revision of the language." - recurse e check e = recurse e recurse = astMap identityMapper {mapOnExp = check} @@ -1203,107 +1209,18 @@ localChecks = void . check <> pretty ty <> "." --- | Type-check a top-level (or module-level) function definition. --- Despite the name, this is also used for checking constant --- definitions, by treating them as 0-ary functions. -checkFunDef :: - ( VName, - Maybe (TypeExp (ExpBase NoInfo VName) VName), - [TypeParam], - [PatBase NoInfo VName ParamType], - ExpBase NoInfo VName, - SrcLoc - ) -> - TypeM - ( [TypeParam], - [Pat ParamType], - Maybe (TypeExp Exp VName), - ResRetType, - Exp - ) -checkFunDef (fname, maybe_retdecl, tparams, params, body, loc) = - runTermTypeM checkExp $ do - (tparams', params', maybe_retdecl', RetType dims rettype', body') <- - checkBinding (fname, maybe_retdecl, tparams, params, body, loc) - - -- Since this is a top-level function, we also resolve overloaded - -- types, using either defaults or complaining about ambiguities. - fixOverloadedTypes $ - typeVars rettype' <> foldMap (typeVars . patternType) params' - - -- Then replace all inferred types in the body and parameters. - body'' <- normTypeFully body' - params'' <- mapM normTypeFully params' - maybe_retdecl'' <- traverse updateTypes maybe_retdecl' - rettype'' <- normTypeFully rettype' - - -- Check if the function body can actually be evaluated. - causalityCheck body'' - - -- Check for various problems. - mapM_ (mustBeIrrefutable . fmap toStruct) params' - localChecks body'' - - let ((body''', updated_ret), errors) = - Consumption.checkValDef - ( fname, - params'', - body'', - RetType dims rettype'', - maybe_retdecl'', - loc - ) - - mapM_ throwError errors - - pure (tparams', params'', maybe_retdecl'', updated_ret, body''') - -- | This is "fixing" as in "setting them", not "correcting them". We -- only make very conservative fixing. fixOverloadedTypes :: Names -> TermTypeM () fixOverloadedTypes tyvars_at_toplevel = getConstraints >>= mapM_ fixOverloaded . M.toList . M.map snd where - fixOverloaded (v, Overloaded ots usage) - | Signed Int32 `elem` ots = do - unify usage (Scalar (TypeVar mempty (qualName v) [])) $ - Scalar (Prim $ Signed Int32) - when (v `S.member` tyvars_at_toplevel) $ - warn usage "Defaulting ambiguous type to i32." - | FloatType Float64 `elem` ots = do - unify usage (Scalar (TypeVar mempty (qualName v) [])) $ - Scalar (Prim $ FloatType Float64) - when (v `S.member` tyvars_at_toplevel) $ - warn usage "Defaulting ambiguous type to f64." - | otherwise = - typeError usage mempty . withIndexLink "ambiguous-type" $ - "Type is ambiguous (could be one of" - <+> commasep (map pretty ots) - <> ")." - "Add a type annotation to disambiguate the type." fixOverloaded (v, NoConstraint _ usage) = do -- See #1552. unify usage (Scalar (TypeVar mempty (qualName v) [])) $ Scalar (tupleRecord []) when (v `S.member` tyvars_at_toplevel) $ warn usage "Defaulting ambiguous type to ()." - fixOverloaded (_, Equality usage) = - typeError usage mempty . withIndexLink "ambiguous-type" $ - "Type is ambiguous (must be equality type)." - "Add a type annotation to disambiguate the type." - fixOverloaded (_, HasFields _ fs usage) = - typeError usage mempty . withIndexLink "ambiguous-type" $ - "Type is ambiguous. Must be record with fields:" - indent 2 (stack $ map field $ M.toList fs) - "Add a type annotation to disambiguate the type." - where - field (l, t) = pretty l <> colon <+> align (pretty t) - fixOverloaded (_, HasConstrs _ cs usage) = - typeError usage mempty . withIndexLink "ambiguous-type" $ - "Type is ambiguous (must be a sum type with constructors:" - <+> pretty (Sum cs) - <> ")." - "Add a type annotation to disambiguate the type." fixOverloaded (v, Size Nothing (Usage Nothing loc)) = typeError loc mempty . withIndexLink "ambiguous-size" $ "Ambiguous size" <+> dquotes (prettyName v) <> "." @@ -1335,10 +1252,10 @@ inferredReturnType loc params t = do checkBinding :: ( VName, - Maybe (TypeExp (ExpBase NoInfo VName) VName), + Maybe (TypeExp Exp VName), [TypeParam], - [PatBase NoInfo VName ParamType], - ExpBase NoInfo VName, + [PatBase Info VName ParamType], + ExpBase Info VName, SrcLoc ) -> TermTypeM @@ -1376,7 +1293,8 @@ checkBinding (fname, maybe_retdecl, tparams, params, body, loc) = verifyFunctionParams (Just fname) params'' (tparams', params''', rettype') <- - letGeneralise (baseName fname) loc tparams params'' =<< unscopeUnknown rettype + letGeneralise (baseName fname) loc tparams params'' + =<< unscopeUnknown rettype when ( null params @@ -1492,15 +1410,18 @@ closeOverTypes defname defloc tparams paramts ret substs = do case M.lookup v substs of Just (_, UnknownSize {}) -> Just v _ -> Nothing + pure - ( tparams ++ more_tparams, + ( tparams + ++ more_tparams, injectExt (nubOrd $ retext ++ mapMaybe mkExt (S.toList $ fvVars $ freeInType ret)) ret ) where -- Diet does not matter here. t = foldFunType (map (toParam Observe) paramts) $ RetType [] ret - to_close_over = M.filterWithKey (\k _ -> k `S.member` visible) substs visible = typeVars t <> fvVars (freeInType t) + to_close_over = + M.filterWithKey (\k _ -> k `S.member` visible) substs (produced_sizes, param_sizes) = dimUses t @@ -1547,19 +1468,13 @@ letGeneralise defname defloc tparams params restype = -- -- (2) are not used in the (new) definition of any type variables -- known before we checked this function. - -- - -- (3) are not referenced from an overloaded type (for example, - -- are the element types of an incompletely resolved record type). - -- This is a bit more restrictive than I'd like, and SML for - -- example does not have this restriction. - -- + -- Criteria (1) and (2) is implemented by looking at the binding -- level of the type variables. - let keep_type_vars = overloadedTypeVars now_substs cur_lvl <- curLevel - let candidate k (lvl, _) = (k `S.notMember` keep_type_vars) && lvl >= (cur_lvl - length params) - new_substs = M.filterWithKey candidate now_substs + let candidate (lvl, _) = lvl >= (cur_lvl - length params) + new_substs = M.filter candidate now_substs (tparams', RetType ret_dims restype') <- closeOverTypes @@ -1587,7 +1502,7 @@ letGeneralise defname defloc tparams params restype = checkFunBody :: [Pat ParamType] -> - ExpBase NoInfo VName -> + Exp -> Maybe ResType -> SrcLoc -> TermTypeM Exp @@ -1622,3 +1537,67 @@ arrayOfM :: arrayOfM loc t shape = do arrayElemType (mkUsage loc "use as array element") "type used in array" t pure $ arrayOf shape t + +-- | Type-check a top-level (or module-level) function definition. +-- Despite the name, this is also used for checking constant +-- definitions, by treating them as 0-ary functions. +checkFunDef :: + ( VName, + Maybe (TypeExp (ExpBase NoInfo VName) VName), + [TypeParam], + [PatBase NoInfo VName ParamType], + ExpBase NoInfo VName, + SrcLoc + ) -> + TypeM + ( [TypeParam], + [Pat ParamType], + Maybe (TypeExp Exp VName), + ResRetType, + Exp + ) +checkFunDef (fname, retdecl, tparams, params, body, loc) = + doChecks =<< Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) + where + -- TODO: Print out the possibilities. (And also potentially eliminate + --- some of the possibilities to disambiguate). + + doChecks (maybe_tysubsts, params', retdecl', body') = + case maybe_tysubsts of + Left err -> throwError err + Right (generalised, tysubsts) -> + runTermTypeM checkExp tysubsts $ do + (tparams', params'', retdecl'', RetType dims rettype', body'') <- + checkBinding (fname, retdecl', generalised <> tparams, params', body', loc) + + -- Since this is a top-level function, we also resolve overloaded + -- types, using either defaults or complaining about ambiguities. + fixOverloadedTypes $ + typeVars rettype' <> foldMap (typeVars . patternType) params'' + + -- Then replace all inferred types in the body and parameters. + body''' <- normTypeFully body'' + params''' <- mapM normTypeFully params'' + retdecl''' <- traverse updateTypes retdecl'' + rettype'' <- normTypeFully rettype' + + -- Check if the function body can actually be evaluated. + causalityCheck body''' + + -- Check for various problems. + mapM_ (mustBeIrrefutable . fmap toStruct) params'' + localChecks body''' + + let ((body'''', updated_ret), errors) = + Consumption.checkValDef + ( fname, + params''', + body''', + RetType dims rettype'', + retdecl''', + loc + ) + + mapM_ throwError errors + + pure (tparams', params''', retdecl''', updated_ret, body'''') diff --git a/src/Language/Futhark/TypeChecker/Terms/Loop.hs b/src/Language/Futhark/TypeChecker/Terms/Loop.hs index 51067a6537..781b989b6f 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Loop.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Loop.hs @@ -102,7 +102,7 @@ wellTypedLoopArg src sparams pat arg = do -- | An un-checked loop. type UncheckedLoop = - (PatBase NoInfo VName ParamType, LoopInitBase NoInfo VName, LoopFormBase NoInfo VName, ExpBase NoInfo VName) + (Pat ParamType, LoopInitBase Info VName, LoopFormBase Info VName, Exp) -- | A loop that has been type-checked. type CheckedLoop = @@ -129,22 +129,14 @@ checkForImpossible loc known_before pat_t = do -- | Type-check a @loop@ expression, passing in a function for -- type-checking subexpressions. checkLoop :: - (ExpBase NoInfo VName -> TermTypeM Exp) -> + (Exp -> TermTypeM Exp) -> UncheckedLoop -> SrcLoc -> TermTypeM (CheckedLoop, AppRes) checkLoop checkExp (mergepat, loopinit, form, loopbody) loc = do - loopinit' <- checkExp $ case loopinit of - LoopInitExplicit e -> e - LoopInitImplicit _ -> - -- Should have been filled out in Names - error "Unspected LoopInitImplicit" + loopinit' <- checkExp $ loopInitExp loopinit known_before <- M.keysSet <$> getConstraints - zeroOrderType - (mkUsage loopinit' "use as loop variable") - "type used as loop variable" - . toStruct - =<< expTypeFully loopinit' + mustBeOrderZero (locOf loopinit') =<< expTypeFully loopinit' -- The handling of dimension sizes is a bit intricate, but very -- similar to checking a function, followed by checking a call to @@ -245,20 +237,18 @@ checkLoop checkExp (mergepat, loopinit, form, loopbody) loc = do (sparams, mergepat', form', loopbody') <- case form of For i uboundexp -> do - uboundexp' <- - require "being the bound in a 'for' loop" anySignedType - =<< checkExp uboundexp - bound_t <- expTypeFully uboundexp' - bindingIdent i bound_t $ \i' -> - bindingPat [] mergepat merge_t $ \mergepat' -> incLevel $ do - loopbody' <- checkExp loopbody - (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' - pure - ( sparams, - mergepat'', - For i' uboundexp', - loopbody' - ) + uboundexp' <- checkExp uboundexp + it <- expType uboundexp' + let i' = i {identType = Info it} + bindingIdent i' . bindingParam mergepat merge_t $ \mergepat' -> incLevel $ do + loopbody' <- checkExp loopbody + (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' + pure + ( sparams, + mergepat'', + For i' uboundexp', + loopbody' + ) ForIn xpat e -> do (arr_t, _) <- newArrayType (mkUsage' (srclocOf e)) "e" 1 e' <- unifies "being iterated in a 'for-in' loop" arr_t =<< checkExp e @@ -267,7 +257,7 @@ checkLoop checkExp (mergepat, loopinit, form, loopbody) loc = do _ | Just t' <- peelArray 1 t -> bindingPat [] xpat t' $ \xpat' -> - bindingPat [] mergepat merge_t $ \mergepat' -> incLevel $ do + bindingParam mergepat merge_t $ \mergepat' -> incLevel $ do loopbody' <- checkExp loopbody (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' pure @@ -281,7 +271,7 @@ checkLoop checkExp (mergepat, loopinit, form, loopbody) loc = do "Iteratee of a for-in loop must be an array, but expression has type" <+> pretty t While cond -> - bindingPat [] mergepat merge_t $ \mergepat' -> + bindingParam mergepat merge_t $ \mergepat' -> incLevel $ do cond' <- checkExp cond diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index aa95a3cbad..ba30adbaeb 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -25,12 +25,15 @@ module Language.Futhark.TypeChecker.Terms.Monad constrain, newArrayType, allDimsFreshInType, + instTyVars, + replaceTyVars, updateTypes, Names, + mustBeOrderZero, + mustBeUnlifted, -- * Primitive checking unifies, - require, checkTypeExpNonrigid, lookupVar, lookupMod, @@ -50,8 +53,9 @@ import Control.Monad import Control.Monad.Except import Control.Monad.Reader import Control.Monad.State.Strict +import Data.Bifunctor import Data.Bitraversable -import Data.Char (isAscii) +import Data.Foldable import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S @@ -61,6 +65,8 @@ import Futhark.FreshNames qualified import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Traversals +import Language.Futhark.TypeChecker.Constraints (TyVar) +import Language.Futhark.TypeChecker.Error import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod, stateNameSource) import Language.Futhark.TypeChecker.Monad qualified as TypeM import Language.Futhark.TypeChecker.Types @@ -83,6 +89,7 @@ unusedSize p = data Inferred t = NoneInferred | Ascribed t + deriving (Show) instance Functor Inferred where fmap _ NoneInferred = NoneInferred @@ -94,7 +101,7 @@ data Checking | CheckingAscription StructType StructType | CheckingLetGeneralise Name | CheckingParams (Maybe Name) - | CheckingPat (PatBase NoInfo VName StructType) (Inferred StructType) + | CheckingPat (PatBase Info VName StructType) (Inferred StructType) | CheckingLoopBody StructType StructType | CheckingLoopInitial StructType StructType | CheckingRecordUpdate [Name] StructType StructType @@ -196,8 +203,9 @@ data TermEnv = TermEnv { termScope :: TermScope, termChecking :: Maybe Checking, termLevel :: Level, - termChecker :: ExpBase NoInfo VName -> TermTypeM Exp, + termCheckExp :: ExpBase Info VName -> TermTypeM Exp, termOuterEnv :: Env, + termTyVars :: M.Map TyVar (TypeBase () NoUniqueness), termImportName :: ImportName } @@ -345,43 +353,120 @@ instance MonadUnify TermTypeM where indent 2 (pretty t2) "do not match." --- | Instantiate a type scheme with fresh type variables for its type --- parameters. Returns the names of the fresh type variables, the --- instance list, and the instantiated type. -instantiateTypeScheme :: +replaceTyVars :: SrcLoc -> TypeBase Size u -> TermTypeM (TypeBase Size u) +replaceTyVars loc orig_t = do + tyvars <- asks termTyVars + let f :: TypeBase Size u -> TermTypeM (TypeBase Size u) + f (Scalar (Prim t)) = pure $ Scalar $ Prim t + f + (Scalar (TypeVar u (QualName [] v) [])) + | Just t <- M.lookup v tyvars = + fst <$> allDimsFreshInType (mkUsage loc "replaceTyVars") Nonrigid "dv" (second (const u) t) + | otherwise = + pure $ Scalar (TypeVar u (QualName [] v) []) + f (Scalar (TypeVar u qn targs)) = + Scalar . TypeVar u qn <$> mapM onTyArg targs + where + onTyArg (TypeArgDim e) = pure $ TypeArgDim e + onTyArg (TypeArgType t) = TypeArgType <$> f t + f (Scalar (Record fs)) = + Scalar . Record <$> traverse f fs + f (Scalar (Sum fs)) = + Scalar . Sum <$> traverse (mapM f) fs + f (Scalar (Arrow u pname d ta (RetType ext tr))) = do + ta' <- f ta + tr' <- f tr + pure $ Scalar $ Arrow u pname d ta' $ RetType ext tr' + f (Array u shape t) = + arrayOfWithAliases u shape <$> f (Scalar t) + + f orig_t + +instTyVars :: + SrcLoc -> + [VName] -> + TypeBase () u -> + TypeBase Size u -> + TermTypeM (TypeBase Size u) +instTyVars loc names orig_t1 orig_t2 = do + tyvars <- asks termTyVars + let f :: + TypeBase d u -> + TypeBase Size u -> + StateT (M.Map VName (TypeBase Size NoUniqueness)) TermTypeM (TypeBase Size u) + f + (Scalar (TypeVar u (QualName [] v1) [])) + t2 + | Just t <- M.lookup v1 tyvars = + f (second (const u) t) t2 + f (Scalar (Record fs1)) (Scalar (Record fs2)) = + Scalar . Record <$> sequence (M.intersectionWith f fs1 fs2) + f (Scalar (Sum fs1)) (Scalar (Sum fs2)) = + Scalar . Sum <$> sequence (M.intersectionWith (zipWithM f) fs1 fs2) + f + (Scalar (Arrow u _ _ t1a (RetType _ t1r))) + (Scalar (Arrow _ pname d t2a (RetType ext t2r))) = do + ta <- f t1a t2a + tr <- f t1r t2r + pure $ Scalar $ Arrow u pname d ta $ RetType ext tr + f + (Array u (Shape (_ : ds1)) t1) + (Array _ (Shape (d : ds2)) t2) = + arrayOfWithAliases u (Shape [d]) + <$> f (arrayOf (Shape ds1) (Scalar t1)) (arrayOf (Shape ds2) (Scalar t2)) + f + (Scalar (TypeVar u v1 targs1)) + (Scalar (TypeVar _ _ targs2)) + | length targs1 == length targs2 = + Scalar . TypeVar u v1 <$> zipWithM g targs1 targs2 + where + g (TypeArgType t1) (TypeArgType t2) = + TypeArgType <$> f t1 t2 + g _ targ = pure targ + f t1 t2 = do + let mkNew = + fst <$> lift (allDimsFreshInType (mkUsage loc "instantiation") Nonrigid "dv" t1) + case t2 of + Scalar (TypeVar u (QualName [] v2) []) + | v2 `elem` names -> do + seen <- get + case M.lookup v2 seen of + Nothing -> do + t <- mkNew + modify $ M.insert v2 $ second (const NoUniqueness) t + pure t + Just t -> + pure $ second (const u) t + _ -> mkNew + + evalStateT (f orig_t1 orig_t2) mempty + +-- | Instantiate a type scheme with fresh variables for its size and +-- type parameters. Returns the names of the fresh size and type +-- variables and the instantiated type. +instTypeScheme :: QualName VName -> SrcLoc -> [TypeParam] -> StructType -> TermTypeM ([VName], StructType) -instantiateTypeScheme qn loc tparams t = do - let tnames = map typeParamName tparams - (tparam_names, tparam_substs) <- mapAndUnzipM (instantiateTypeParam qn loc) tparams - let substs = M.fromList $ zip tnames tparam_substs - t' = applySubst (`M.lookup` substs) t - pure (tparam_names, t') - --- | Create a new type name and insert it (unconstrained) in the --- substitution map. -instantiateTypeParam :: - (Monoid as) => - QualName VName -> - SrcLoc -> - TypeParam -> - TermTypeM (VName, Subst (RetTypeBase dim as)) -instantiateTypeParam qn loc tparam = do - i <- incCounter - let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) - v <- newID $ mkTypeVarName name i - case tparam of - TypeParamType x _ _ -> do - constrain v . NoConstraint x . mkUsage loc . docText $ - "instantiated type parameter of " <> dquotes (pretty qn) - pure (v, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v) []) - TypeParamDim {} -> do - constrain v . Size Nothing . mkUsage loc . docText $ - "instantiated size parameter of " <> dquotes (pretty qn) - pure (v, ExpSubst $ sizeFromName (qualName v) loc) +instTypeScheme qn loc tparams scheme_t = do + (names, substs) <- fmap unzip . forM tparams $ \tparam -> do + case tparam of + TypeParamType l v _ -> do + i <- incCounter + v' <- newID $ mkTypeVarName (baseName v) i + constrain v' . NoConstraint l . mkUsage loc . docText $ + "instantiated type parameter of " <> dquotes (pretty qn) + pure (v', (v, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v') [])) + TypeParamDim v _ -> do + i <- incCounter + v' <- newID $ mkTypeVarName (baseName v) i + constrain v' . Size Nothing . mkUsage loc . docText $ + "instantiated size parameter of " <> dquotes (pretty qn) + pure (v', (v, ExpSubst $ sizeFromName (qualName v') loc)) + + pure (names, applySubst (`lookup` substs) scheme_t) lookupQualNameEnv :: QualName VName -> TermTypeM TermScope lookupQualNameEnv (QualName [q] _) @@ -446,41 +531,23 @@ instance MonadTypeChecker TermTypeM where Nothing -> throwError $ TypeError (locOf loc) notes s -lookupVar :: SrcLoc -> QualName VName -> TermTypeM StructType -lookupVar loc qn@(QualName qs name) = do +lookupVar :: SrcLoc -> QualName VName -> StructType -> TermTypeM StructType +lookupVar loc qn@(QualName qs name) inst_t = do scope <- lookupQualNameEnv qn - let usage = mkUsage loc $ docText $ "use of " <> dquotes (pretty qn) - case M.lookup name $ scopeVtable scope of Nothing -> error $ "lookupVar: " <> show qn - Just (BoundV tparams t) -> do + Just (BoundV tparams bound_t) -> if null tparams && null qs - then pure t + then pure bound_t else do - (tnames, t') <- instantiateTypeScheme qn loc tparams t + (tnames, t) <- instTypeScheme qn loc tparams bound_t outer_env <- asks termOuterEnv - pure $ qualifyTypeVars outer_env tnames qs t' - Just EqualityF -> do - argtype <- newTypeVar loc "t" - equalityType usage argtype - pure $ - Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ - Scalar $ - Arrow mempty Unnamed Observe argtype $ - RetType [] $ - Scalar $ - Prim Bool - Just (OverloadedF ts pts rt) -> do - argtype <- newTypeVar loc "t" - mustBeOneOf ts usage argtype - let (pts', rt') = instOverloaded argtype pts rt - pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' - where - instOverloaded argtype pts rt = - ( map (maybe (toStruct argtype) (Scalar . Prim)) pts, - maybe (toStruct argtype) (Scalar . Prim) rt - ) + pure $ qualifyTypeVars outer_env tnames qs t + Just EqualityF -> + replaceTyVars loc inst_t + Just OverloadedF {} -> + replaceTyVars loc inst_t onFailure :: Checking -> TermTypeM a -> TermTypeM a onFailure c = local $ \env -> env {termChecking = Just c} @@ -531,8 +598,8 @@ allDimsFreshInType :: Usage -> Rigidity -> Name -> - TypeBase Size als -> - TermTypeM (TypeBase Size als, M.Map VName Size) + TypeBase d als -> + TermTypeM (TypeBase Size als, M.Map VName d) allDimsFreshInType usage r desc t = runStateT (bitraverse onDim pure t) mempty where @@ -554,6 +621,33 @@ updateTypes = astMap tv mapOnResRetType = normTypeFully } +mustBeOrderZero :: Loc -> StructType -> TermTypeM () +mustBeOrderZero loc t = do + constraints <- getConstraints + let liftedType v = + case M.lookup v constraints of + Just (_, ParamType Lifted _) -> True + _ -> False + when (not (orderZero t) || any liftedType (typeVars t)) $ + typeError loc mempty $ + textwrap "This expression may not be of function type, but is inferred to be of type" + indent 2 (align (pretty t)) + "which may be a function." + +mustBeUnlifted :: Loc -> StructType -> TermTypeM () +mustBeUnlifted loc t = do + constraints <- getConstraints + let liftedType v = + case M.lookup v constraints of + Just (_, ParamType Lifted _) -> True + Just (_, ParamType SizeLifted _) -> True + _ -> False + when (not (orderZero t) || any liftedType (typeVars t)) $ + typeError loc mempty $ + textwrap "This expression must be of unlifted type, but is inferred to be of type" + indent 2 (align (pretty t)) + "which may be a function or a value with hidden sizes." + --- Basic checking unifies :: T.Text -> StructType -> Exp -> TermTypeM Exp @@ -561,24 +655,15 @@ unifies why t e = do unify (mkUsage (srclocOf e) why) t . toStruct =<< expType e pure e --- | @require ts e@ causes a 'TypeError' if @expType e@ is not one of --- the types in @ts@. Otherwise, simply returns @e@. -require :: T.Text -> [PrimType] -> Exp -> TermTypeM Exp -require why ts e = do - mustBeOneOf ts (mkUsage (srclocOf e) why) . toStruct =<< expType e - pure e - -checkExpForSize :: ExpBase NoInfo VName -> TermTypeM Exp +checkExpForSize :: ExpBase Info VName -> TermTypeM Exp checkExpForSize e = do - checker <- asks termChecker + checker <- asks termCheckExp e' <- checker e let t = toStruct $ typeOf e' unify (mkUsage (locOf e') "Size expression") t (Scalar (Prim (Signed Int64))) updateTypes e' -checkTypeExpNonrigid :: - TypeExp (ExpBase NoInfo VName) VName -> - TermTypeM (TypeExp Exp VName, ResType, [VName]) +checkTypeExpNonrigid :: TypeExp Exp VName -> TermTypeM (TypeExp Exp VName, ResType, [VName]) checkTypeExpNonrigid te = do (te', svars, rettype, _l) <- checkTypeExp checkExpForSize te @@ -632,8 +717,8 @@ initialTermScope = Just (name, EqualityF) addIntrinsicF _ = Nothing -runTermTypeM :: (ExpBase NoInfo VName -> TermTypeM Exp) -> TermTypeM a -> TypeM a -runTermTypeM checker (TermTypeM m) = do +runTermTypeM :: (ExpBase Info VName -> TermTypeM Exp) -> M.Map TyVar (TypeBase () NoUniqueness) -> TermTypeM a -> TypeM a +runTermTypeM checker tyvars (TermTypeM m) = do initial_scope <- (initialTermScope <>) . envToTermScope <$> askEnv name <- askImportName outer_env <- askEnv @@ -643,9 +728,10 @@ runTermTypeM checker (TermTypeM m) = do { termScope = initial_scope, termChecking = Nothing, termLevel = 0, - termChecker = checker, + termCheckExp = checker, termImportName = name, - termOuterEnv = outer_env + termOuterEnv = outer_env, + termTyVars = tyvars } initial_state = TermTypeState diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 9f33a42602..2576922bf3 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -2,6 +2,7 @@ module Language.Futhark.TypeChecker.Terms.Pat ( binding, bindingParams, + bindingParam, bindingPat, bindingIdent, bindingSizes, @@ -11,7 +12,7 @@ where import Control.Monad import Data.Bifunctor import Data.Either -import Data.List (find, isPrefixOf, sort, sortBy) +import Data.List (find, isPrefixOf, sortBy) import Data.Map.Strict qualified as M import Data.Maybe import Data.Ord (comparing) @@ -100,48 +101,30 @@ bindingSizes sizes m = binding (map sizeWithType sizes) m Ident (sizeName size) (Info (Scalar (Prim (Signed Int64)))) (srclocOf size) -- | Bind a single term-level identifier. -bindingIdent :: - IdentBase NoInfo VName StructType -> - StructType -> - (Ident StructType -> TermTypeM a) -> - TermTypeM a -bindingIdent (Ident v NoInfo vloc) t m = do - let ident = Ident v (Info t) vloc - binding [ident] $ m ident - --- All this complexity is just so we can handle un-suffixed numeric --- literals in patterns. -patLitMkType :: PatLit -> SrcLoc -> TermTypeM ParamType -patLitMkType (PatLitInt _) loc = do - t <- newTypeVar loc "t" - mustBeOneOf anyNumberType (mkUsage loc "integer literal") (toStruct t) - pure t -patLitMkType (PatLitFloat _) loc = do - t <- newTypeVar loc "t" - mustBeOneOf anyFloatType (mkUsage loc "float literal") (toStruct t) - pure t -patLitMkType (PatLitPrim v) _ = - pure $ Scalar $ Prim $ primValueType v +bindingIdent :: Ident StructType -> TermTypeM a -> TermTypeM a +bindingIdent ident = binding [ident] checkPat' :: [(SizeBinder VName, QualName VName)] -> - PatBase NoInfo VName ParamType -> + Pat ParamType -> Inferred ParamType -> TermTypeM (Pat ParamType) checkPat' sizes (PatParens p loc) t = PatParens <$> checkPat' sizes p t <*> pure loc checkPat' sizes (PatAttr attr p loc) t = PatAttr <$> checkAttr attr <*> checkPat' sizes p t <*> pure loc -checkPat' _ (Id name NoInfo loc) (Ascribed t) = - pure $ Id name (Info t) loc -checkPat' _ (Id name NoInfo loc) NoneInferred = do - t <- newTypeVar loc "t" - pure $ Id name (Info t) loc -checkPat' _ (Wildcard _ loc) (Ascribed t) = - pure $ Wildcard (Info t) loc -checkPat' _ (Wildcard NoInfo loc) NoneInferred = do - t <- newTypeVar loc "t" - pure $ Wildcard (Info t) loc +checkPat' _ (Id name (Info t) loc) NoneInferred = do + t' <- replaceTyVars loc t + pure $ Id name (Info t') loc +checkPat' _ (Id name (Info t1) loc) (Ascribed t2) = do + t' <- instTyVars loc [] (first (const ()) t1) t2 + pure $ Id name (Info t') loc +checkPat' _ (Wildcard (Info t) loc) NoneInferred = do + t' <- replaceTyVars loc t + pure $ Wildcard (Info t') loc +checkPat' _ (Wildcard (Info t1) loc) (Ascribed t2) = do + t' <- instTyVars loc [] (first (const ()) t1) t2 + pure $ Wildcard (Info t') loc checkPat' sizes p@(TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, length ts == length ps = @@ -169,11 +152,6 @@ checkPat' sizes p@(RecordPat p_fs loc) (Ascribed t) RecordPat <$> zipWithM check p_fs' t_fs' <*> pure loc | otherwise = do p_fs' <- traverse (const $ newTypeVar loc "t") $ M.fromList $ map (first unLoc) p_fs - - when (sort (M.keys p_fs') /= sort (map (unLoc . fst) p_fs)) $ - typeError loc mempty $ - "Duplicate fields in record pattern" <+> pretty p <> "." - unify (mkUsage loc "matching a record pattern") (Scalar (Record p_fs')) (toStruct t) checkPat' sizes p $ Ascribed $ toParam Observe $ Scalar (Record p_fs') where @@ -199,54 +177,33 @@ checkPat' sizes (PatAscription p t loc) maybe_outer_t = do <$> checkPat' sizes p (Ascribed (resToParam st)) <*> pure t' <*> pure loc -checkPat' _ (PatLit l NoInfo loc) (Ascribed t) = do - t' <- patLitMkType l loc - unify (mkUsage loc "matching against literal") (toStruct t') (toStruct t) - pure $ PatLit l (Info t') loc -checkPat' _ (PatLit l NoInfo loc) NoneInferred = do - t' <- patLitMkType l loc +checkPat' _ (PatLit l (Info t) loc) _ = do + t' <- replaceTyVars loc t pure $ PatLit l (Info t') loc -checkPat' sizes (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) +checkPat' sizes (PatConstr n info ps loc) NoneInferred = do + ps' <- mapM (\p -> checkPat' sizes p NoneInferred) ps + pure $ PatConstr n info ps' loc +checkPat' sizes (PatConstr n _ ps loc) (Ascribed (Scalar (Sum cs))) | Just ts <- M.lookup n cs = do - when (length ps /= length ts) $ - typeError loc mempty $ - "Pattern #" - <> pretty n - <> " expects" - <+> pretty (length ps) - <+> "constructor arguments, but type provides" - <+> pretty (length ts) - <+> "arguments." - ps' <- zipWithM (checkPat' sizes) ps $ map Ascribed ts + ps' <- zipWithM (\p t -> checkPat' sizes p (Ascribed t)) ps ts pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc -checkPat' sizes (PatConstr n NoInfo ps loc) (Ascribed t) = do - t' <- newTypeVar loc "t" - ps' <- forM ps $ \p -> do - p_t <- newTypeVar (srclocOf p) "t" - checkPat' sizes p $ Ascribed p_t - mustHaveConstr usage n (toStruct t') (patternStructType <$> ps') - unify usage t' (toStruct t) - pure $ PatConstr n (Info t) ps' loc - where - usage = mkUsage loc "matching against constructor" -checkPat' sizes (PatConstr n NoInfo ps loc) NoneInferred = do - ps' <- mapM (\p -> checkPat' sizes p NoneInferred) ps - t <- newTypeVar loc "t" - mustHaveConstr usage n (toStruct t) (patternStructType <$> ps') - pure $ PatConstr n (Info t) ps' loc - where - usage = mkUsage loc "matching against constructor" +checkPat' _ p t = + error . unlines $ + [ "checkPat': bad case", + prettyString p, + show t + ] checkPat :: [(SizeBinder VName, QualName VName)] -> - PatBase NoInfo VName (TypeBase Size u) -> + Pat ParamType -> Inferred StructType -> (Pat ParamType -> TermTypeM a) -> TermTypeM a checkPat sizes p t m = do p' <- onFailure (CheckingPat (fmap toStruct p) t) $ - checkPat' sizes (fmap (toParam Observe) p) (fmap (toParam Observe) t) + checkPat' sizes p (fmap (toParam Observe) t) let explicit = mustBeExplicitInType $ patternStructType p' @@ -259,19 +216,30 @@ checkPat sizes p t m = do [] -> m p' +-- | Check and bind a single parameter. +bindingParam :: + Pat ParamType -> + StructType -> + (Pat ParamType -> TermTypeM a) -> + TermTypeM a +bindingParam p t m = do + checkPat mempty p (Ascribed t) $ \p' -> + binding (patIdents (fmap toStruct p')) $ m p' + -- | Check and bind a @let@-pattern. bindingPat :: [SizeBinder VName] -> - PatBase NoInfo VName (TypeBase Size u) -> + Pat (TypeBase Size u) -> StructType -> (Pat ParamType -> TermTypeM a) -> TermTypeM a bindingPat sizes p t m = do substs <- mapM mkSizeSubst sizes - checkPat substs p (Ascribed t) $ \p' -> binding (patIdents (fmap toStruct p')) $ - case filter ((`S.notMember` fvVars (freeInPat p')) . sizeName) sizes of - [] -> m p' - size : _ -> unusedSize size + checkPat substs (fmap (toParam Observe) p) (Ascribed t) $ \p' -> + binding (patIdents (fmap toStruct p')) $ + case filter ((`S.notMember` fvVars (freeInPat p')) . sizeName) sizes of + [] -> m p' + size : _ -> unusedSize size where mkSizeSubst v = do v' <- newID $ baseName $ sizeName v @@ -282,13 +250,15 @@ bindingPat sizes p t m = do -- | Check and bind type and value parameters. bindingParams :: [TypeParam] -> - [PatBase NoInfo VName ParamType] -> + [Pat ParamType] -> ([Pat ParamType] -> TermTypeM a) -> TermTypeM a bindingParams tps orig_ps m = bindingTypeParams tps $ do let descend ps' (p : ps) = checkPat [] p NoneInferred $ \p' -> - binding (patIdents $ fmap toStruct p') $ incLevel $ descend (p' : ps') ps + binding (patIdents $ fmap toStruct p') $ + incLevel $ + descend (p' : ps') ps descend ps' [] = m $ reverse ps' incLevel $ descend [] orig_ps diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs new file mode 100644 index 0000000000..a5895e708e --- /dev/null +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -0,0 +1,1465 @@ +-- | A very WIP reimplementation of type checking of terms. +-- +-- The strategy is to split type checking into two (main) passes: +-- +-- 1) A size-agnostic pass that generates constraints (type Ct) which +-- are then solved offline to find a solution. This produces an AST +-- where most of the type annotations are just references to type +-- variables. Further, all the size-specific annotations (e.g. +-- existential sizes) just contain dummy values, such as empty lists. +-- The constraints use a type representation where all dimensions are +-- the same. However, we do try to take to store the sizes resulting +-- from explicit type ascriptions - these cannot refer to inferred +-- existentials, so it is safe to resolve them here. We don't do +-- anything with this information, however. +-- +-- 2) Pass (1) has given us a program where we know the types of +-- everything, but the sizes of nothing. Pass (2) then does +-- essentially size inference, much like the current/old type checker, +-- but of course with the massive benefit of already knowing the full +-- type of everything. This can be implemented using online constraint +-- solving (as before), or perhaps a completely syntax-driven +-- approach. +-- +-- As of this writing, only the constraint generation part of pass (1) +-- has been implemented, and it is very likely that some of the +-- constraints are actually wrong. Next step is to imlement the +-- solver. Currently all we do is dump the constraints to the +-- terminal. +-- +-- Also, no thought whatsoever has been put into quality of type +-- errors yet. However, I think an approach based on tacking source +-- information onto constraints should work well, as all constraints +-- ultimately originate from some bit of program syntax. +-- +-- Also no thought has been put into how to handle the liftedness +-- stuff. Since it does not really affect choices made during +-- inference, perhaps we can do it in a post-inference check. +module Language.Futhark.TypeChecker.Terms2 + ( checkValDef, + checkSingleExp, + checkSizeExp, + Solution, + ) +where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Data.Bifunctor +import Data.Bitraversable +import Data.Char (isAscii) +import Data.Either (partitionEithers) +import Data.List qualified as L +import Data.List.NonEmpty qualified as NE +import Data.Loc (Loc (NoLoc)) +import Data.Map qualified as M +import Data.Maybe +import Data.Ord (comparing) +import Data.Set qualified as S +import Data.Text qualified as T +import Futhark.FreshNames qualified as FreshNames +import Futhark.MonadFreshNames hiding (newName) +import Futhark.Util (debugTraceM, mapAccumLM, nubOrd) +import Futhark.Util.Pretty +import Language.Futhark +import Language.Futhark.TypeChecker.Constraints +import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) +import Language.Futhark.TypeChecker.Monad qualified as TypeM +import Language.Futhark.TypeChecker.Rank +import Language.Futhark.TypeChecker.Types +import Language.Futhark.TypeChecker.Unify (Level, mkUsage) +import Prelude hiding (mod) + +data Inferred t + = NoneInferred + | Ascribed t + +instance Functor Inferred where + fmap _ NoneInferred = NoneInferred + fmap f (Ascribed t) = Ascribed (f t) + +data ValBinding + = BoundV [TypeParam] Type + | OverloadedF [PrimType] [Maybe PrimType] (Maybe PrimType) + | EqualityF + deriving (Show) + +data TermScope = TermScope + { scopeVtable :: M.Map VName ValBinding, + scopeTypeTable :: M.Map VName TypeBinding, + scopeModTable :: M.Map VName Mod + } + deriving (Show) + +instance Semigroup TermScope where + TermScope vt1 tt1 mt1 <> TermScope vt2 tt2 mt2 = + TermScope (vt2 `M.union` vt1) (tt2 `M.union` tt1) (mt1 `M.union` mt2) + +-- | Type checking happens with access to this environment. The +-- 'TermScope' will be extended during type-checking as bindings come into +-- scope. +data TermEnv = TermEnv + { termScope :: TermScope, + termLevel :: Level, + termOuterEnv :: Env, + termImportName :: ImportName + } + +-- | The state is a set of constraints and a counter for generating +-- type names. This is distinct from the usual counter we use for +-- generating unique names, as these will be user-visible. +data TermState = TermState + { termConstraints :: Constraints, + termTyVars :: TyVars, + termTyParams :: TyParams, + termCounter :: !Int, + termWarnings :: Warnings, + termNameSource :: VNameSource, + -- | Mapping from artificial type variables to the actual types they represent. + termArtificial :: M.Map TyVar Type + } + +newtype TermM a + = TermM + ( ReaderT + TermEnv + (StateT TermState (Except (Warnings, TypeError))) + a + ) + deriving + ( Monad, + Functor, + Applicative, + MonadReader TermEnv, + MonadState TermState + ) + +envToTermScope :: Env -> TermScope +envToTermScope env = + TermScope + { scopeVtable = vtable, + scopeTypeTable = envTypeTable env, + scopeModTable = envModTable env + } + where + vtable = M.map valBinding $ envVtable env + valBinding (TypeM.BoundV tps v) = BoundV tps $ toType v + +initialTermScope :: TermScope +initialTermScope = + TermScope + { scopeVtable = initialVtable, + scopeTypeTable = mempty, + scopeModTable = mempty + } + where + initialVtable = M.fromList $ mapMaybe addIntrinsicF $ M.toList intrinsics + + prim = Scalar . Prim + arrow x y = Scalar $ Arrow mempty Unnamed Observe x y + + addIntrinsicF (name, IntrinsicMonoFun pts t) = + Just (name, BoundV [] $ arrow pts' $ RetType [] $ prim t) + where + pts' = case pts of + [pt] -> prim pt + _ -> Scalar $ tupleRecord $ map prim pts + addIntrinsicF (name, IntrinsicOverloadedFun ts pts rts) = + Just (name, OverloadedF ts pts rts) + addIntrinsicF (name, IntrinsicPolyFun tvs pts rt) = + Just + ( name, + BoundV tvs $ toType $ foldFunType pts rt + ) + addIntrinsicF (name, IntrinsicEquality) = + Just (name, EqualityF) + addIntrinsicF _ = Nothing + +runTermM :: TermM a -> TypeM a +runTermM (TermM m) = do + initial_scope <- (initialTermScope <>) . envToTermScope <$> askEnv + name <- askImportName + outer_env <- askEnv + src <- gets stateNameSource + let initial_env = + TermEnv + { termScope = initial_scope, + termLevel = 0, + termImportName = name, + termOuterEnv = outer_env + } + initial_state = + TermState + { termConstraints = mempty, + termTyVars = mempty, + termTyParams = mempty, + termWarnings = mempty, + termNameSource = src, + termCounter = 0, + termArtificial = mempty + } + case runExcept (runStateT (runReaderT m initial_env) initial_state) of + Left (ws, e) -> do + warnings ws + throwError e + Right (a, TermState {termNameSource, termWarnings}) -> do + warnings termWarnings + modify $ \s -> s {stateNameSource = termNameSource} + pure a + +incLevel :: TermM a -> TermM a +incLevel = local $ \env -> env {termLevel = termLevel env + 1} + +curLevel :: TermM Int +curLevel = asks termLevel + +incCounter :: TermM Int +incCounter = do + s <- get + put s {termCounter = termCounter s + 1} + pure $ termCounter s + +tyVarType :: u -> TyVar -> TypeBase dim u +tyVarType u v = Scalar $ TypeVar u (qualName v) [] + +newTyVarWith :: Name -> TyVarInfo -> TermM TyVar +newTyVarWith desc info = do + i <- incCounter + v <- newID $ mkTypeVarName desc i + lvl <- curLevel + modify $ \s -> s {termTyVars = M.insert v (lvl, info) $ termTyVars s} + pure v + +newTyVar :: (Located loc) => loc -> Liftedness -> Name -> TermM TyVar +newTyVar loc l desc = newTyVarWith desc $ TyVarFree (locOf loc) l + +newType :: (Located loc) => loc -> Liftedness -> Name -> u -> TermM (TypeBase dim u) +newType loc l desc u = tyVarType u <$> newTyVar loc l desc + +-- | New type that must be allowed as an array element. +newElemType :: (Located loc) => loc -> Name -> u -> TermM (TypeBase dim u) +newElemType loc desc u = tyVarType u <$> newTyVar loc Unlifted desc + +newTypeWithField :: SrcLoc -> Name -> Name -> Type -> TermM Type +newTypeWithField loc desc k t = + tyVarType NoUniqueness + <$> newTyVarWith desc (TyVarRecord (locOf loc) $ M.singleton k t) + +newTypeWithConstr :: SrcLoc -> Name -> u -> Name -> [TypeBase SComp u] -> TermM (TypeBase d u) +newTypeWithConstr loc desc u k ts = + tyVarType u <$> newTyVarWith desc (TyVarSum (locOf loc) $ M.singleton k ts') + where + ts' = map (`setUniqueness` NoUniqueness) ts + +newTypeOverloaded :: SrcLoc -> Name -> [PrimType] -> TermM (TypeBase d NoUniqueness) +newTypeOverloaded loc name pts = + tyVarType NoUniqueness <$> newTyVarWith name (TyVarPrim (locOf loc) pts) + +newSVar :: loc -> Name -> TermM SVar +newSVar _loc desc = do + i <- incCounter + newID $ mkTypeVarName desc i + +newArtificial :: u -> TypeBase SComp u -> TermM (TypeBase Size u) +newArtificial u t = do + v <- newID "artificial" + let t' = tyVarType u v + modify $ \s -> s {termArtificial = M.insert v (second (const NoUniqueness) t) $ termArtificial s} + pure t' + +-- The AST requires annotations to be StructTypes, but the type +-- checker works with Types. This creates artificial type "variables" +-- that allow us to connect the AST annotations with the actual +-- inferred types. The artificial variables should never occur in +-- constraints - they can be substituted away with asType. +asStructType :: TypeBase SComp u -> TermM (TypeBase Size u) +asStructType (Scalar (Prim pt)) = pure $ Scalar $ Prim pt +asStructType (Scalar (TypeVar u v [])) = pure $ Scalar $ TypeVar u v [] +asStructType (Scalar (Arrow u pname d t1 (RetType ext t2))) = do + t1' <- asStructType t1 + t2' <- asStructType t2 + pure $ Scalar $ Arrow u pname d t1' $ RetType ext t2' +asStructType (Scalar (Record fs)) = + Scalar . Record <$> traverse asStructType fs +asStructType (Scalar (Sum cs)) = + Scalar . Sum <$> traverse (mapM asStructType) cs +asStructType t@(Scalar (TypeVar u _ _)) = + newArtificial u t +asStructType t@(Array u _ _) = do + newArtificial u t + +asType :: (Monoid u) => TypeBase Size u -> TermM (TypeBase SComp u) +asType t = do + artificial <- gets termArtificial + pure $ substTyVars (`M.lookup` artificial) (toType t) + +expType :: Exp -> TermM Type +expType = asType . typeOf -- NOTE: Only place you should use typeOf. + +addCt :: Ct -> TermM () +addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} + +ctEq :: Reason -> TypeBase SComp u1 -> TypeBase SComp u2 -> TermM () +ctEq reason t1 t2 = + -- As a minor optimisation, do not add constraint if the types are + -- equal. + unless (t1' == t2') $ addCt $ CtEq reason t1' t2' + where + t1' = t1 `setUniqueness` NoUniqueness + t2' = t2 `setUniqueness` NoUniqueness + +ctAM :: Reason -> SVar -> SVar -> Shape SComp -> TermM () +ctAM reason r m f = addCt $ CtAM reason r m f + +localScope :: (TermScope -> TermScope) -> TermM a -> TermM a +localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} + +withEnv :: TermEnv -> Env -> TermEnv +withEnv tenv env = tenv {termScope = termScope tenv <> envToTermScope env} + +lookupQualNameEnv :: QualName VName -> TermM TermScope +lookupQualNameEnv (QualName [q] _) + | baseTag q <= maxIntrinsicTag = asks termScope -- Magical intrinsic module. +lookupQualNameEnv qn@(QualName quals _) = do + scope <- asks termScope + descend scope quals + where + descend scope [] = pure scope + descend scope (q : qs) + | Just (ModEnv q_scope) <- M.lookup q $ scopeModTable scope = + descend (envToTermScope q_scope) qs + | otherwise = + error $ "lookupQualNameEnv " <> show qn + +instance MonadError TypeError TermM where + throwError e = TermM $ do + ws <- gets termWarnings + throwError (ws, e) + + catchError (TermM m) f = + TermM $ m `catchError` f' + where + f' (_, e) = let TermM m' = f e in m' + +instance MonadTypeChecker TermM where + warnings ws = modify $ \s -> s {termWarnings = termWarnings s <> ws} + + warn loc problem = warnings $ singleWarning (locOf loc) problem + + newName v = do + s <- get + let (v', src') = FreshNames.newName (termNameSource s) v + put $ s {termNameSource = src'} + pure v' + + newID s = newName $ VName s 0 + + newTypeName name = do + i <- incCounter + newID $ mkTypeVarName name i + + bindVal v (TypeM.BoundV tps t) m = do + t' <- asType t + let f scope = scope {scopeVtable = M.insert v (BoundV tps t') $ scopeVtable scope} + localScope f m + + lookupType qn = do + outer_env <- asks termOuterEnv + scope <- lookupQualNameEnv qn + case M.lookup (qualLeaf qn) $ scopeTypeTable scope of + Nothing -> error $ "lookupType: " <> show qn + Just (TypeAbbr l ps (RetType dims def)) -> + pure + ( ps, + RetType dims $ qualifyTypeVars outer_env (map typeParamName ps) (qualQuals qn) def, + l + ) + + typeError loc notes s = + throwError $ TypeError (locOf loc) notes s + +--- All the general machinery goes above. + +arrayOfRank :: Int -> Type -> Type +arrayOfRank n = arrayOf $ Shape $ replicate n SDim + +require :: T.Text -> [PrimType] -> Exp -> TermM Exp +require _why [pt] e = do + e_t <- expType e + ctEq (Reason (locOf e)) (Scalar $ Prim pt) e_t + pure e +require _why pts e = do + t :: Type <- newTypeOverloaded (srclocOf e) "t" pts + e_t <- expType e + ctEq (Reason (locOf e)) t e_t + pure e + +-- | Instantiate a type scheme with fresh type variables for its type +-- parameters. Returns the names of the fresh type variables, the +-- instance list, and the instantiated type. +instTypeScheme :: + QualName VName -> + SrcLoc -> + [TypeParam] -> + Type -> + TermM ([VName], Type) +instTypeScheme _qn loc tparams t = do + (names, substs) <- fmap (unzip . catMaybes) $ + forM tparams $ \tparam -> + case tparam of + TypeParamType l v _ -> do + v' <- newTyVar loc l $ nameFromString $ takeWhile isAscii $ baseString v + pure $ Just (v, (typeParamName tparam, tyVarType NoUniqueness v')) + TypeParamDim {} -> + pure Nothing + let t' = substTyVars (`lookup` substs) t + pure (names, t') + +lookupMod :: QualName VName -> TermM Mod +lookupMod qn@(QualName _ name) = do + scope <- lookupQualNameEnv qn + case M.lookup name $ scopeModTable scope of + Nothing -> error $ "lookupMod: " <> show qn + Just m -> pure m + +lookupVar :: SrcLoc -> QualName VName -> TermM Type +lookupVar loc qn@(QualName qs name) = do + scope <- lookupQualNameEnv qn + case M.lookup name $ scopeVtable scope of + Nothing -> + error $ "lookupVar: " <> show qn + Just (BoundV tparams t) -> do + if null tparams && null qs + then pure t + else do + (_tnames, t') <- instTypeScheme qn loc tparams t + -- TODO - qualify type names, like in the old type checker. + pure t' + Just EqualityF -> do + argtype <- tyVarType Observe <$> newTyVarWith "t" (TyVarEql (locOf loc)) + pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool + Just (OverloadedF ts pts rt) -> do + argtype <- newTypeOverloaded loc "t" ts + let (pts', rt') = instOverloaded argtype pts rt + pure $ foldFunType (map (second $ const Observe) pts') $ RetType [] $ second (const Nonunique) rt' + where + instOverloaded argtype pts rt = + ( map (maybe argtype (Scalar . Prim)) pts, + maybe argtype (Scalar . Prim) rt + ) + +bind :: + [Ident StructType] -> + TermM a -> + TermM a +bind idents m = do + let names = map identName idents + ts <- mapM (asType . unInfo . identType) idents + localScope (`bindVars` zip names ts) m + where + bindVars = foldl bindVar + + bindVar scope (name, t) = + scope + { scopeVtable = M.insert name (BoundV [] t) $ scopeVtable scope + } + +-- All this complexity is just so we can handle un-suffixed numeric +-- literals in patterns. +patLitMkType :: PatLit -> SrcLoc -> TermM ParamType +patLitMkType (PatLitInt _) loc = + toParam Observe <$> newTypeOverloaded loc "t" anyNumberType +patLitMkType (PatLitFloat _) loc = + toParam Observe <$> newTypeOverloaded loc "t" anyFloatType +patLitMkType (PatLitPrim v) _ = + pure $ Scalar $ Prim $ primValueType v + +checkSizeExp' :: ExpBase NoInfo VName -> TermM Exp +checkSizeExp' e = do + e' <- checkExp e + e_t <- expType e' + ctEq (Reason (locOf e)) e_t (Scalar (Prim (Signed Int64))) + pure e' + +checkPat' :: + PatBase NoInfo VName ParamType -> + Inferred (TypeBase SComp Diet) -> + TermM (Pat ParamType) +checkPat' (PatParens p loc) t = + PatParens <$> checkPat' p t <*> pure loc +checkPat' (PatAttr attr p loc) t = + PatAttr <$> checkAttr attr <*> checkPat' p t <*> pure loc +checkPat' (Id name NoInfo loc) (Ascribed t) = do + t' <- asStructType t + pure $ Id name (Info t') loc +checkPat' (Id name NoInfo loc) NoneInferred = do + t <- newType loc Lifted "t" Observe + pure $ Id name (Info t) loc +checkPat' (Wildcard _ loc) (Ascribed t) = do + t' <- asStructType t + pure $ Wildcard (Info t') loc +checkPat' (Wildcard NoInfo loc) NoneInferred = do + t <- newType loc Lifted "t" Observe + pure $ Wildcard (Info t) loc +checkPat' (TuplePat ps loc) (Ascribed t) + | Just ts <- isTupleRecord t, + length ts == length ps = + TuplePat + <$> zipWithM checkPat' ps (map Ascribed ts) + <*> pure loc + | otherwise = do + ps_tvs <- replicateM (length ps) (newTyVar loc Lifted "t") + ctEq + (ReasonPatMatch (locOf loc) (TuplePat ps loc) (toStruct t)) + (Scalar (tupleRecord $ map (tyVarType NoUniqueness) ps_tvs)) + t + TuplePat <$> zipWithM checkPat' ps (map (Ascribed . tyVarType Observe) ps_tvs) <*> pure loc +checkPat' (TuplePat ps loc) NoneInferred = + TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc +checkPat' p@(RecordPat p_fs loc) _ + | Just (L floc f, _) <- L.find (("_" `T.isPrefixOf`) . nameToText . unLoc . fst) p_fs = + typeError floc mempty $ + "Underscore-prefixed fields are not allowed." + "Did you mean" + <> dquotes (pretty (T.drop 1 (nameToText f)) <> "=_") + <> "?" + | length (nubOrd (map fst p_fs)) /= length (map fst p_fs) = + typeError loc mempty $ + "Duplicate fields in record pattern" <+> pretty p <> "." +checkPat' p@(RecordPat p_fs loc) (Ascribed t) + | Scalar (Record t_fs) <- t, + p_fs' <- L.sortBy (comparing fst) p_fs, + t_fs' <- L.sortBy (comparing fst) (M.toList t_fs), + map fst t_fs' == map (unLoc . fst) p_fs' = + RecordPat <$> zipWithM check p_fs' t_fs' <*> pure loc + | otherwise = do + p_fs' <- + traverse (const $ newType loc Lifted "t" NoUniqueness) $ + M.fromList $ + map (first unLoc) p_fs + ctEq (Reason (locOf loc)) (Scalar (Record p_fs')) t + checkPat' p $ Ascribed $ Observe <$ Scalar (Record p_fs') + where + check (L f_loc f, p_f) (_, t_f) = + (L f_loc f,) <$> checkPat' p_f (Ascribed t_f) +checkPat' (RecordPat fs loc) NoneInferred = + RecordPat . M.toList + <$> traverse (`checkPat'` NoneInferred) (M.fromList fs) + <*> pure loc +checkPat' (PatAscription p t loc) maybe_outer_t = do + (t', _, RetType _ st, _) <- checkTypeExp checkSizeExp' t + + -- Uniqueness kung fu to make the Monoid(mempty) instance give what + -- we expect. We should perhaps stop being so implicit. + st' <- asType $ resToParam st + + case maybe_outer_t of + Ascribed outer_t -> do + ctEq + (ReasonAscription (locOf loc) (toStruct st') (toStruct outer_t)) + st' + outer_t + PatAscription + <$> checkPat' p (Ascribed st') + <*> pure t' + <*> pure loc + NoneInferred -> + PatAscription + <$> checkPat' p (Ascribed st') + <*> pure t' + <*> pure loc +checkPat' (PatLit l NoInfo loc) (Ascribed t) = do + t' <- patLitMkType l loc + ctEq (Reason (locOf loc)) (toType t') t + pure $ PatLit l (Info t') loc +checkPat' (PatLit l NoInfo loc) NoneInferred = do + t' <- patLitMkType l loc + pure $ PatLit l (Info t') loc +checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) + | Just ts <- M.lookup n cs = do + when (length ps /= length ts) $ + typeError loc mempty $ + "Pattern #" + <> pretty n + <> " expects" + <+> pretty (length ps) + <+> "constructor arguments, but type provides" + <+> pretty (length ts) + <+> "arguments." + ps' <- zipWithM checkPat' ps $ map Ascribed ts + cs' <- traverse (mapM asStructType) cs + pure $ PatConstr n (Info (Scalar (Sum cs'))) ps' loc +checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do + ps' <- forM ps $ \p -> do + p_t <- newType (srclocOf p) Lifted "t" Observe + checkPat' p $ Ascribed p_t + t' <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' + ctEq (Reason (locOf loc)) t' t + t'' <- asStructType t' + pure $ PatConstr n (Info $ toParam Observe t'') ps' loc +checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do + ps' <- mapM (`checkPat'` NoneInferred) ps + t <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' + t' <- asStructType t + pure $ PatConstr n (Info $ toParam Observe t') ps' loc + +checkPat :: + PatBase NoInfo VName (TypeBase Size u) -> + Inferred Type -> + (Pat ParamType -> TermM a) -> + TermM a +checkPat p t m = + m =<< checkPat' (fmap (toParam Observe) p) (fmap (fmap (const Observe)) t) + +-- | Bind @let@-bound sizes. This is usually followed by 'bindletPat' +-- immediately afterwards. +bindSizes :: [SizeBinder VName] -> TermM a -> TermM a +bindSizes [] m = m -- Minor optimisation. +bindSizes sizes m = bind (map sizeWithType sizes) m + where + sizeWithType size = + Ident (sizeName size) (Info (Scalar (Prim (Signed Int64)))) (srclocOf size) + +bindLetPat :: + PatBase NoInfo VName (TypeBase Size u) -> + Type -> + (Pat ParamType -> TermM a) -> + TermM a +bindLetPat p t m = do + checkPat p (Ascribed t) $ \p' -> + bind (patIdents (fmap toStruct p')) $ + m p' + +typeParamIdent :: TypeParam -> Maybe (Ident StructType) +typeParamIdent (TypeParamDim v loc) = + Just $ Ident v (Info $ Scalar $ Prim $ Signed Int64) loc +typeParamIdent _ = Nothing + +bindTypes :: + [(VName, TypeBinding)] -> + TermM a -> + TermM a +bindTypes tbinds = localScope extend + where + extend scope = + scope + { scopeTypeTable = M.fromList tbinds <> scopeTypeTable scope + } + +bindTypeParams :: [TypeParam] -> TermM a -> TermM a +bindTypeParams tparams m = + bind idents . bindTypes types $ do + lvl <- curLevel + modify $ \s -> + s + { termTyParams = + termTyParams s + <> M.fromList (mapMaybe (typeParam lvl) tparams) + } + m + where + idents = mapMaybe typeParamIdent tparams + types = mapMaybe typeParamType tparams + typeParamType (TypeParamType l v _) = + Just (v, TypeAbbr l [] $ RetType [] $ Scalar (TypeVar mempty (qualName v) [])) + typeParamType TypeParamDim {} = Nothing + typeParam lvl (TypeParamType l v loc) = Just (v, (lvl, l, locOf loc)) + typeParam _ _ = Nothing + +bindParams :: + [TypeParam] -> + [PatBase NoInfo VName ParamType] -> + ([Pat ParamType] -> TermM a) -> + TermM a +bindParams tps orig_ps m = bindTypeParams tps $ do + let descend ps' (p : ps) = + checkPat p NoneInferred $ \p' -> + bind (patIdents $ fmap toStruct p') $ incLevel $ descend (p' : ps') ps + descend ps' [] = m $ reverse ps' + + incLevel $ descend [] orig_ps + +checkApplyOne :: + SrcLoc -> + (Maybe (QualName VName), Int) -> + (Shape Size, Type) -> + (Maybe Exp, Shape Size, Type) -> + TermM (Type, AutoMap) +checkApplyOne loc fname (fframe, ftype) (arg, argframe, argtype) = do + (a, b) <- split ftype + r <- newSVar loc "R" + m <- newSVar loc "M" + let unit_info = Info $ Scalar $ Prim Bool + r_var = Var (QualName [] r) unit_info mempty + m_var = Var (QualName [] m) unit_info mempty + lhs = arrayOf (toShape (SVar r)) argtype + rhs = arrayOf (toShape (SVar m)) a + ctAM (Reason (locOf loc)) r m $ fmap toSComp (toShape m_var <> fframe) + let reason = case arg of + Just arg' -> + ReasonApply (locOf loc) (fst fname) arg' lhs rhs + Nothing -> Reason (locOf loc) + ctEq reason lhs rhs + debugTraceM 3 $ + unlines + [ "## checkApplyOne", + "## fname", + prettyString fname, + "## (fframe, ftype)", + prettyString (fframe, ftype), + "## (argframe, argtype)", + prettyString (argframe, argtype), + "## r", + prettyString r, + "## m", + prettyString m, + "## lhs", + prettyString lhs, + "## rhs", + prettyString rhs, + "## ret", + prettyString $ arrayOf (toShape (SVar m)) b + ] + pure + ( arrayOf (toShape (SVar m)) b, + AutoMap + { autoRep = toShape r_var, + autoMap = toShape m_var, + autoFrame = toShape m_var <> fframe + } + ) + where + toSComp (Var (QualName [] x) _ _) = SVar x + toSComp _ = error "" + toShape = Shape . pure + split (Scalar (Arrow _ _ _ a (RetType _ b))) = + pure (a, b `setUniqueness` NoUniqueness) + split (Array _u s t) = do + (a, b) <- split $ Scalar t + pure (arrayOf s a, arrayOf s b) + split ftype' = do + a <- newType loc Lifted "arg" NoUniqueness + b <- newType loc Lifted "res" Nonunique + ctEq (Reason (locOf loc)) ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b + pure (a, b `setUniqueness` NoUniqueness) + +checkApply :: + SrcLoc -> + Maybe (QualName VName) -> + (Shape Size, Type) -> + NE.NonEmpty (Maybe Exp, Shape Size, Type) -> + TermM (Type, NE.NonEmpty AutoMap) +checkApply loc fname (fframe, ftype) args = do + ((_, _, rt), argts) <- mapAccumLM onArg (0, fframe, ftype) args + pure (rt, argts) + where + onArg (i, f_f, f_t) arg = do + (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) arg + pure + ( (i + 1, autoFrame am, rt), + am + ) + +checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] +checkSlice = mapM checkDimIndex + where + checkDimIndex (DimFix i) = + DimFix <$> (require "use as index" anySignedType =<< checkExp i) + checkDimIndex (DimSlice i j s) = + DimSlice <$> traverse check i <*> traverse check j <*> traverse check s + + check = require "use in slice" [Signed Int64] <=< checkExp + +isSlice :: DimIndexBase f vn -> Bool +isSlice DimSlice {} = True +isSlice DimFix {} = False + +-- Add constraints saying that the first type has a (potentially +-- nested) field containing the second type. +mustHaveFields :: SrcLoc -> Type -> [Name] -> Type -> TermM () +mustHaveFields loc t [] ve_t = + -- This case is probably never reached. + ctEq (Reason (locOf loc)) t ve_t +mustHaveFields loc t [f] ve_t = do + rt :: Type <- newTypeWithField loc "ft" f ve_t + ctEq (Reason (locOf loc)) t rt +mustHaveFields loc t (f : fs) ve_t = do + ft <- newType loc Lifted "ft" NoUniqueness + rt <- newTypeWithField loc "rt" f ft + mustHaveFields loc ft fs ve_t + ctEq (Reason (locOf loc)) t rt + +checkCase :: + Type -> + CaseBase NoInfo VName -> + TermM (CaseBase Info VName, Type) +checkCase mt (CasePat p e loc) = + bindLetPat p mt $ \p' -> do + e' <- checkExp e + e_t <- expType e' + pure (CasePat (fmap toStruct p') e' loc, e_t) + +checkCases :: + Type -> + NE.NonEmpty (CaseBase NoInfo VName) -> + TermM (NE.NonEmpty (CaseBase Info VName), Type) +checkCases mt rest_cs = + case NE.uncons rest_cs of + (c, Nothing) -> do + (c', t) <- checkCase mt c + pure (NE.singleton c', t) + (c, Just cs) -> do + (c', c_t) <- checkCase mt c + (cs', cs_t) <- checkCases mt cs + ctEq (ReasonBranches (locOf c) c_t cs_t) c_t cs_t + pure (NE.cons c' cs', c_t) + +-- | An unmatched pattern. Used in in the generation of +-- unmatched pattern warnings by the type checker. +data Unmatched p + = UnmatchedNum p [PatLit] + | UnmatchedBool p + | UnmatchedConstr p + | Unmatched p + deriving (Functor, Show) + +instance Pretty (Unmatched (Pat StructType)) where + pretty um = case um of + (UnmatchedNum p nums) -> pretty' p <+> "where p is not one of" <+> pretty nums + (UnmatchedBool p) -> pretty' p + (UnmatchedConstr p) -> pretty' p + (Unmatched p) -> pretty' p + where + pretty' (PatAscription p t _) = pretty p <> ":" <+> pretty t + pretty' (PatParens p _) = parens $ pretty' p + pretty' (PatAttr _ p _) = parens $ pretty' p + pretty' (Id v _ _) = prettyName v + pretty' (TuplePat pats _) = parens $ commasep $ map pretty' pats + pretty' (RecordPat fs _) = braces $ commasep $ map ppField fs + where + ppField (name, t) = prettyName (unLoc name) <> equals <> pretty' t + pretty' Wildcard {} = "_" + pretty' (PatLit e _ _) = pretty e + pretty' (PatConstr n _ ps _) = "#" <> pretty n <+> sep (map pretty' ps) + +checkRetDecl :: + Exp -> + Maybe (TypeExp (ExpBase NoInfo VName) VName) -> + TermM (Type, Maybe (TypeExp Exp VName)) +checkRetDecl body Nothing = (,Nothing) <$> expType body +checkRetDecl body (Just te) = do + (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te + body_t <- expType body + st' <- toStruct <$> asType st + ctEq (ReasonRetType (locOf body) st' body_t) st' body_t + pure (st', Just te') + +checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) +-- +checkExp (Var qn _ loc) = do + t <- asStructType =<< lookupVar loc qn + pure $ Var qn (Info t) loc +checkExp (OpSection op _ loc) = do + ftype <- asStructType =<< lookupVar loc op + pure $ OpSection op (Info ftype) loc +checkExp (Negate arg loc) = do + arg' <- require "numeric negation" anyNumberType =<< checkExp arg + pure $ Negate arg' loc +checkExp (Not arg loc) = do + arg' <- require "logical negation" (Bool : anyIntType) =<< checkExp arg + pure $ Not arg' loc +checkExp (Hole NoInfo loc) = + Hole <$> (Info <$> newType loc Lifted "hole" NoUniqueness) <*> pure loc +checkExp (Parens e loc) = + Parens <$> checkExp e <*> pure loc +checkExp (TupLit es loc) = + TupLit <$> mapM checkExp es <*> pure loc +checkExp (QualParens (modname, modnameloc) e loc) = do + mod <- lookupMod modname + case mod of + ModEnv env -> local (`withEnv` env) $ do + e' <- checkExp e + pure $ QualParens (modname, modnameloc) e' loc + ModFun {} -> + typeError loc mempty . withIndexLink "module-is-parametric" $ + "Module" <+> pretty modname <+> " is a parametric module." +-- +checkExp (IntLit x NoInfo loc) = do + t <- newTypeOverloaded loc "num" anyNumberType + pure $ IntLit x (Info t) loc +checkExp (FloatLit x NoInfo loc) = do + t <- newTypeOverloaded loc "float" anyFloatType + pure $ FloatLit x (Info t) loc +checkExp (Literal v loc) = + pure $ Literal v loc +checkExp (StringLit vs loc) = + pure $ StringLit vs loc +-- No need to type check this, as these are only produced by the +-- parser if the elements are monomorphic and all match. +checkExp (ArrayVal vs t loc) = + pure $ ArrayVal vs t loc +checkExp (ArrayLit es _ loc) = do + -- TODO: this will produce an enormous number of constraints and + -- type variables for pathologically large arrays with + -- type-unsuffixed integers. Add some special case that handles that + -- more efficiently. + et <- newElemType loc "et" NoUniqueness + es' <- forM es $ \e -> do + e' <- checkExp e + e_t <- expType e' + et' <- asType et + ctEq (Reason (locOf loc)) e_t et' + pure e' + let arr_t = arrayOf (Shape [sizeFromInteger (L.genericLength es) loc]) et + pure $ ArrayLit es' (Info arr_t) loc +checkExp (RecordLit fs loc) = + RecordLit <$> evalStateT (mapM checkField fs) mempty <*> pure loc + where + checkField (RecordFieldExplicit f e rloc) = do + errIfAlreadySet (unLoc f) rloc + modify $ M.insert (unLoc f) rloc + RecordFieldExplicit f <$> lift (checkExp e) <*> pure rloc + checkField (RecordFieldImplicit name NoInfo rloc) = do + errIfAlreadySet (baseName (unLoc name)) rloc + t <- lift $ asStructType =<< lookupVar rloc (qualName (unLoc name)) + modify $ M.insert (baseName (unLoc name)) rloc + pure $ RecordFieldImplicit name (Info t) rloc + + errIfAlreadySet f rloc = do + maybe_sloc <- gets $ M.lookup f + case maybe_sloc of + Just sloc -> + lift . typeError rloc mempty $ + "Field" + <+> dquotes (pretty f) + <+> "previously defined at" + <+> pretty (locStrRel rloc sloc) + <> "." + Nothing -> pure () + +-- +checkExp (Attr info e loc) = + Attr <$> checkAttr info <*> checkExp e <*> pure loc +checkExp (Assert e1 e2 NoInfo loc) = do + e1' <- require "being asserted" [Bool] =<< checkExp e1 + e2' <- checkExp e2 + pure $ Assert e1' e2' (Info (prettyText e1)) loc +-- +checkExp (Constr name es NoInfo loc) = do + es' <- mapM checkExp es + es_ts <- mapM expType es' + t <- newTypeWithConstr loc "t" NoUniqueness name es_ts + pure $ Constr name es' (Info t) loc +-- +checkExp (AppExp (Apply fe args loc) NoInfo) = do + fe' <- checkExp fe + (args', apply_args) <- + fmap NE.unzip . forM args $ \(_, arg) -> do + arg' <- checkExp arg + arg_t <- expType arg' + pure (arg', (Just arg', frameOf arg', arg_t)) + fe_t <- expType fe' + (rt, ams) <- checkApply loc fname (frameOf fe', fe_t) apply_args + rt' <- asStructType rt + let args'' = + NE.zipWith (\am arg -> (Info (Nothing, am), arg)) ams args' + pure $ AppExp (Apply fe' args'' loc) $ Info (AppRes rt' []) + where + fname = + case fe of + Var v _ _ -> Just v + _ -> Nothing +checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do + ftype <- lookupVar oploc op + e1' <- checkExp e1 + e1_t <- expType e1' + e2' <- checkExp e2 + e2_t <- expType e2' + + (rt, ams) <- + checkApply + loc + (Just op) + (mempty, ftype) + ((Just e1', frameOf e1', e1_t) NE.:| [(Just e2', frameOf e2', e2_t)]) + rt' <- asStructType rt + let (am1 NE.:| [am2]) = ams + + ftype' <- asStructType ftype + pure $ + AppExp + (BinOp (op, oploc) (Info ftype') (e1', Info (Nothing, am1)) (e2', Info (Nothing, am2)) loc) + (Info (AppRes rt' [])) +-- +checkExp (OpSectionLeft op _ e _ _ loc) = do + optype <- lookupVar loc op + e' <- checkExp e + e_t <- expType e' + t2 <- newType loc Lifted "t" NoUniqueness + t2' <- asStructType t2 + let f1 = frameOf e' + (rt, ams) <- + checkApply + loc + (Just op) + (mempty, optype) + ((Just e', f1, e_t) NE.:| [(Nothing, mempty, t2)]) + rt' <- asStructType rt + + let (am1 NE.:| _) = ams + t1 <- asStructType e_t + optype' <- asStructType optype + pure $ + OpSectionLeft + op + (Info optype') + e' + ( Info (Unnamed, toParam Observe t1, Nothing, am1), + Info (Unnamed, toParam Observe t2') + ) + (Info (RetType [] (rt' `setUniqueness` Nonunique)), Info []) + loc +checkExp (OpSectionRight op _ e _ NoInfo loc) = do + optype <- lookupVar loc op + e' <- checkExp e + e_t <- expType e' + t1 <- newType loc Lifted "t" NoUniqueness + t1' <- asStructType t1 + let f2 = frameOf e' + (rt, ams) <- + checkApply + loc + (Just op) + (mempty, optype) + ((Nothing, mempty, t1) NE.:| [(Just e', f2, e_t)]) + rt' <- asStructType rt + let (_ NE.:| [am2]) = ams + t2 <- asStructType e_t + + optype' <- asStructType optype + pure $ + OpSectionRight + op + (Info optype') + e' + -- Dummy types. + ( Info (Unnamed, toParam Observe t1'), + Info (Unnamed, toParam Observe t2, Nothing, am2) + ) + (Info $ RetType [] (rt' `setUniqueness` Nonunique)) + loc +-- +checkExp (ProjectSection fields NoInfo loc) = do + a <- newType loc Lifted "a" NoUniqueness + b <- newType loc Lifted "b" NoUniqueness + mustHaveFields loc a fields b + ft <- asStructType $ Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ b `setUniqueness` Nonunique + pure $ ProjectSection fields (Info ft) loc +-- +checkExp (Lambda params body retdecl NoInfo loc) = do + bindParams [] params $ \params' -> do + body' <- checkExp body + + (body_t, retdecl') <- checkRetDecl body' retdecl + body_t' <- asStructType body_t + let ret = RetType [] $ toRes Nonunique body_t' + pure $ Lambda params' body' retdecl' (Info ret) loc +-- +checkExp (AppExp (LetPat sizes pat e body loc) _) = do + e' <- checkExp e + e_t <- expType e' + + bindSizes sizes . incLevel . bindLetPat pat e_t $ \pat' -> do + body' <- incLevel $ checkExp body + body_t <- expType body' + + body_t' <- asStructType body_t + pure $ + AppExp + (LetPat sizes (fmap toStruct pat') e' body' loc) + (Info $ AppRes body_t' []) +-- +checkExp (AppExp (LetFun name (tparams, params, retdecl, NoInfo, e) body loc) _) = do + (tparams', params', retdecl', rettype, e') <- + bindParams tparams params $ \params' -> do + e' <- checkExp e + (e_t, retdecl') <- checkRetDecl e' retdecl + pure (tparams, params', retdecl', fmap (const Nonunique) e_t, e') + + params'' <- mapM (traverse asType) params' + + let entry = BoundV tparams' $ funType params'' $ RetType [] rettype + bindF scope = + scope + { scopeVtable = M.insert name entry $ scopeVtable scope + } + body' <- localScope bindF $ checkExp body + body_t <- expType body' + + body_t' <- asStructType body_t + rettype' <- asStructType rettype + pure $ + AppExp + ( LetFun + name + (tparams', params', retdecl', Info (RetType [] rettype'), e') + body' + loc + ) + (Info $ AppRes body_t' []) +-- +checkExp (AppExp (Range start maybe_step end loc) _) = do + start' <- require "use in range expression" anyIntType =<< checkExp start + let check e = do + e' <- checkExp e + start_t <- expType start' + e_t <- expType e' + ctEq (Reason (locOf e')) start_t e_t + pure e' + maybe_step' <- traverse check maybe_step + end' <- traverse check end + range_t <- newElemType loc "range" NoUniqueness + range_t' <- asType range_t + start_t <- expType start' + ctEq (Reason (locOf start')) range_t' (arrayOfRank 1 start_t) + pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] +-- +checkExp (Project k e NoInfo loc) = do + e' <- checkExp e + kt <- newType loc Lifted "kt" NoUniqueness + t <- newTypeWithField loc "t" k kt + e_t <- expType e' + ctEq (Reason (locOf e')) e_t t + kt' <- asStructType kt + pure $ Project k e' (Info kt') loc +-- +checkExp (RecordUpdate src fields ve NoInfo loc) = do + src' <- checkExp src + src_t <- expType src' + ve' <- checkExp ve + ve_t <- expType ve' + mustHaveFields loc src_t fields ve_t + src_t' <- asStructType src_t + pure $ RecordUpdate src' fields ve' (Info src_t') loc +-- +checkExp (IndexSection slice NoInfo loc) = do + slice' <- checkSlice slice + index_arg_t <- newElemType loc "index" NoUniqueness + index_elem_t <- newElemType loc "index_elem" NoUniqueness + index_res_t <- newElemType loc "index_res" NoUniqueness + let num_slices = length $ filter isSlice slice + ctEq (Reason (locOf loc)) index_arg_t $ arrayOfRank num_slices index_elem_t + ctEq (Reason (locOf loc)) index_res_t $ arrayOfRank (length slice) index_elem_t + ft <- asStructType $ Scalar $ Arrow mempty Unnamed Observe index_arg_t $ second (const Nonunique) $ RetType [] index_res_t + pure $ IndexSection slice' (Info ft) loc +-- +checkExp (AppExp (Index e slice loc) _) = do + e' <- checkExp e + e_t <- expType e' + slice' <- checkSlice slice + index_tv <- newTyVar loc Unlifted "index" + index_elem_t <- newElemType loc "index_elem" NoUniqueness + let num_slices = length $ filter isSlice slice + ctEq (Reason (locOf loc)) (tyVarType NoUniqueness index_tv) $ arrayOfRank num_slices index_elem_t + ctEq (Reason (locOf e')) e_t $ arrayOfRank (length slice) index_elem_t + pure $ AppExp (Index e' slice' loc) (Info $ AppRes (tyVarType NoUniqueness index_tv) []) +-- +checkExp (Update src slice ve loc) = do + src' <- checkExp src + src_t <- expType src' + slice' <- checkSlice slice + ve' <- checkExp ve + ve_t <- expType ve' + let num_slices = length $ filter isSlice slice + update_elem_t <- newElemType loc "update_elem" NoUniqueness + ctEq (Reason (locOf src')) src_t $ arrayOfRank (length slice) update_elem_t + ctEq (Reason (locOf ve')) ve_t $ arrayOfRank num_slices update_elem_t + pure $ Update src' slice' ve' loc +-- +checkExp (AppExp (LetWith dest src slice ve body loc) _) = do + src_t <- lookupVar (srclocOf src) $ qualName $ identName src + src_t' <- asStructType src_t + let src' = src {identType = Info src_t'} + dest' = dest {identType = Info src_t'} + slice' <- checkSlice slice + ve' <- checkExp ve + ve_t <- expType ve' + let num_slices = length $ filter isSlice slice + update_elem_t <- newElemType loc "update_elem" NoUniqueness + ctEq (Reason (locOf loc)) src_t $ arrayOfRank (length slice) update_elem_t + ctEq (Reason (locOf ve')) ve_t $ arrayOfRank num_slices update_elem_t + bind [dest'] $ do + body' <- checkExp body + body_t <- expType body' + body_t' <- asStructType body_t + pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes body_t' []) +-- +checkExp (AppExp (If e1 e2 e3 loc) _) = do + e1' <- checkExp e1 + e1_t <- expType e1' + e2' <- checkExp e2 + e2_t <- expType e2' + e3' <- checkExp e3 + e3_t <- expType e3' + if_t <- newType loc SizeLifted "if_t" NoUniqueness + + ctEq (Reason (locOf e1')) e1_t (Scalar (Prim Bool)) + ctEq (ReasonBranches (locOf loc) e2_t e3_t) e2_t if_t + ctEq (ReasonBranches (locOf loc) e2_t e3_t) e3_t if_t + + if_t' <- asStructType if_t + pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes if_t' []) +-- +checkExp (AppExp (Match e cs loc) _) = do + e' <- checkExp e + e_t <- expType e' + (cs', t) <- checkCases e_t cs + + match_t <- newType loc SizeLifted "match_t" NoUniqueness + ctEq (Reason (locOf loc)) match_t t + + match_t' <- asStructType match_t + pure $ AppExp (Match e' cs' loc) (Info $ AppRes match_t' []) +-- +checkExp (AppExp (Loop _ pat arg form body loc) _) = do + arg' <- checkExp $ case arg of + LoopInitExplicit e -> e + LoopInitImplicit _ -> + -- Should have been filled out in Names + error "Unspected LoopInitImplicit" + arg_t <- expType arg' + bindLetPat pat arg_t $ \pat' -> do + (form', body') <- + case form of + For (Ident i _ iloc) bound -> do + bound' <- require "loop bound" anyIntType =<< checkExp bound + bound_t <- expType bound' + bound_t' <- asStructType bound_t + let i' = Ident i (Info bound_t') iloc + bind [i'] $ do + body' <- checkExp body + pure (For i' bound', body') + While cond -> do + cond' <- checkExp cond + body' <- checkExp body + pure (While cond', body') + ForIn elemp arr -> do + arr' <- checkExp arr + elem_t <- newElemType elemp "elem" NoUniqueness + arr_t <- expType arr' + elem_t' <- asType elem_t + ctEq (Reason (locOf arr')) arr_t $ arrayOfRank 1 elem_t' + bindLetPat elemp elem_t' $ \elemp' -> do + body' <- checkExp body + pure (ForIn (toStruct <$> elemp') arr', body') + body_t <- expType body' + ctEq (Reason (locOf loc)) arg_t body_t + pure $ + AppExp + (Loop [] pat' (LoopInitExplicit arg') form' body' loc) + (Info (AppRes (patternStructType pat') [])) +-- +checkExp (Ascript e te loc) = do + e' <- checkExp e + (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te + e_t <- expType e' + st' <- asType st + ctEq (ReasonAscription (locOf e') (toStruct st') (toStruct e_t)) e_t st' + pure $ Ascript e' te' loc +checkExp (Coerce e te NoInfo loc) = do + e' <- checkExp e + (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te + e_t <- expType e' + st' <- asType st + ctEq (Reason (locOf e')) e_t st' + pure $ Coerce e' te' (Info (toStruct st)) loc + +doDefault :: + [VName] -> + VName -> + Either [PrimType] (TypeBase () NoUniqueness) -> + TermM (TypeBase () NoUniqueness) +doDefault tyvars_at_toplevel v (Left pts) + | Signed Int32 `elem` pts = do + when (v `elem` tyvars_at_toplevel) $ + warn usage "Defaulting ambiguous type to i32." + pure $ Scalar $ Prim $ Signed Int32 + | FloatType Float64 `elem` pts = do + when (v `elem` tyvars_at_toplevel) $ + warn usage "Defaulting ambiguous type to f64." + pure $ Scalar $ Prim $ FloatType Float64 + | otherwise = + typeError usage mempty . withIndexLink "ambiguous-type" $ + "Type is ambiguous (could be one of" + <+> commasep (map pretty pts) + <> ")." + "Add a type annotation to disambiguate the type." + where + usage = mkUsage NoLoc "overload" +doDefault _ _ (Right t) = pure t + +-- | Apply defaults on otherwise ambiguous types. This may result in +-- some type variables becoming known, so we have to perform +-- substitutions on the RHS of the substitutions afterwards. +doDefaults :: + [VName] -> + M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) -> + TermM (M.Map TyVar (TypeBase () NoUniqueness)) +doDefaults tyvars_at_toplevel substs = do + substs' <- M.traverseWithKey (doDefault tyvars_at_toplevel) substs + pure $ M.map (substTyVars (`M.lookup` substs')) substs' + +generalise :: + TypeBase () NoUniqueness -> + [UnconTyVar] -> + Solution -> + ([TypeParam], [VName]) +generalise fun_t unconstrained solution = + -- Candidates for let-generalisation are those type variables that + -- are used in fun_t. + let visible = foldMap expandTyVars $ typeVars fun_t + onTyVar (v, l) + | v `S.member` visible = Left $ TypeParamType l v mempty + | otherwise = Right v + in partitionEithers $ map onTyVar unconstrained + where + expandTyVars v = + case M.lookup v solution of + Just (Right t) -> foldMap expandTyVars $ typeVars t + _ -> S.singleton v + +generaliseAndDefaults :: + [UnconTyVar] -> + Solution -> + TypeBase () NoUniqueness -> + TermM ([TypeParam], M.Map VName (TypeBase () NoUniqueness)) +generaliseAndDefaults unconstrained solution t = do + let (generalised, unconstrained') = + generalise t unconstrained solution + solution' <- doDefaults (map typeParamName generalised) solution + pure + ( generalised, + -- See #1552 for why we resolve unconstrained and + -- un-generalised type variables to (). + M.fromList (map (,Scalar (Record mempty)) unconstrained') <> solution' + ) + +checkValDef :: + ( VName, + Maybe (TypeExp (ExpBase NoInfo VName) VName), + [TypeParam], + [PatBase NoInfo VName ParamType], + ExpBase NoInfo VName, + SrcLoc + ) -> + TypeM + ( Either TypeError ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), + [Pat ParamType], + Maybe (TypeExp Exp VName), + Exp + ) +checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do + (params', body', retdecl') <- + bindParams tparams params $ \params' -> do + body' <- checkExp body + (_, retdecl') <- checkRetDecl body' retdecl + pure (params', body', retdecl') + + cts <- gets termConstraints + tyvars <- gets termTyVars + typarams <- gets termTyParams + artificial <- gets termArtificial + + debugTraceM 3 $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" + + debugTraceM 3 $ + unlines + [ "## cts:", + unlines $ map prettyString cts, + "## body:", + prettyString body', + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, + "## artificial:", + unlines $ map (\(v, t) -> prettyNameString v <> " => " <> prettyString t) (M.toList artificial) + ] + + onRankSolution typarams + =<< rankAnalysis1 loc cts tyvars artificial params' body' retdecl' + where + onRankSolution typarams ((cts', artificial, tyvars'), params', body'', retdecl') = do + solution <- + bitraverse + pure + (fmap (second (onArtificial artificial)) . onTySolution params' body'') + $ solve (reverse cts') typarams tyvars' + debugTraceM 3 $ + unlines + [ "## constraints:", + unlines $ map prettyString cts', + "## typarams:", + let f (lvl, l, _) = (lvl, l) + in unlines (map (prettyString . bimap prettyNameString f) (M.toList typarams)), + "## tyvars':", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', + "## solution:", + let p (v, t) = prettyNameString v <> " => " <> prettyString t + in either (docString . prettyTypeError) (unlines . map p . M.toList . snd) solution, + either (const mempty) (unlines . ("## generalised:" :) . map prettyString . fst) solution + ] + pure (solution, params', retdecl', body'') + + onTySolution params' body' (unconstrained, solution) = do + body_t <- expType body' + let fun_t = + foldFunType + (map (first (const ()) . patternType) params') + (RetType [] $ bimap (const ()) (const Nonunique) body_t) + generaliseAndDefaults unconstrained solution fun_t + + onArtificial artificial solution = + M.map (substTyVars (`M.lookup` solution) . first (const ())) artificial <> solution + +checkSingleExp :: + ExpBase NoInfo VName -> + TypeM (Either TypeError ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), Exp) +checkSingleExp e = runTermM $ do + e' <- checkExp e + cts <- gets termConstraints + tyvars <- gets termTyVars + typarams <- gets termTyParams + artificial <- gets termArtificial + ((cts', _artificial', tyvars'), _, e'', _) <- + rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' Nothing + case solve cts' typarams tyvars' of + Left err -> pure (Left err, e'') + Right (unconstrained, solution) -> do + e_t <- expType e'' + x <- generaliseAndDefaults unconstrained solution $ first (const ()) e_t + pure (Right x, e'') + +-- | Type-check a single size expression in isolation. This expression may +-- turn out to be polymorphic, in which case it is unified with i64. +checkSizeExp :: + ExpBase NoInfo VName -> + TypeM (Either TypeError ([UnconTyVar], M.Map TyVar (TypeBase () NoUniqueness)), Exp) +checkSizeExp e = runTermM $ do + e' <- checkSizeExp' e + cts <- gets termConstraints + tyvars <- gets termTyVars + typarams <- gets termTyParams + artificial <- gets termArtificial + + (cts_tyvars', _, es', _) <- + L.unzip4 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' Nothing + + solutions <- + forM cts_tyvars' $ \(cts', _artificial', tyvars') -> + bitraverse pure (traverse (doDefaults mempty)) $ solve cts' typarams tyvars' + + case (solutions, es') of + ([solution], [e'']) -> + pure (solution, e'') + _ -> pure (Left $ TypeError (locOf e) mempty "Ambiguous size expression", e') diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index 1bcb62b39c..096551cd4d 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -8,6 +8,7 @@ module Language.Futhark.TypeChecker.Types TypeSubs, Substitutable (..), substTypesAny, + substTyVars, -- * Witnesses mustBeExplicitInType, @@ -59,7 +60,7 @@ mustBeExplicitInBinding :: StructType -> S.Set VName mustBeExplicitInBinding bind_t = let (ts, ret) = unfoldFunType bind_t alsoRet = M.unionWith (&&) $ M.fromList $ map (,True) (S.toList (fvVars (freeInType ret))) - in S.fromList $ M.keys $ M.filter id $ alsoRet $ L.foldl' onType mempty $ map toStruct ts + in S.fromList $ M.keys $ M.filter id $ alsoRet $ L.foldl' onType mempty $ map (toStruct . snd) ts where onType uses t = uses <> mustBeExplicitAux t -- Left-biased union. @@ -533,6 +534,26 @@ substTypesAny lookupSubst ot = toAny d = d in first toAny ot' +-- | Substitution without caring about sizes. +substTyVars :: (Monoid u) => (VName -> Maybe (TypeBase d NoUniqueness)) -> TypeBase d u -> TypeBase d u +substTyVars f (Scalar (TypeVar u qn args)) = + case f $ qualLeaf qn of + Just t' -> second (const mempty) $ substTyVars f t' + Nothing -> Scalar (TypeVar u qn (map onArg args)) + where + onArg (TypeArgType t) = TypeArgType $ substTyVars f t + onArg (TypeArgDim e) = TypeArgDim e +substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt +substTyVars f (Scalar (Record fs)) = Scalar $ Record $ M.map (substTyVars f) fs +substTyVars f (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map $ substTyVars f) cs +substTyVars f (Scalar (Arrow u pname d t1 (RetType ext t2))) = + Scalar $ + Arrow u pname d (substTyVars f t1) $ + RetType ext $ + substTyVars f t2 `setUniqueness` uniqueness t2 +substTyVars f (Array u shape elemt) = + arrayOfWithAliases u shape $ substTyVars f $ Scalar elemt + -- Note [AnySize] -- -- Consider a program: diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 5638acdc9e..5a5fb42cd8 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -12,15 +12,8 @@ module Language.Futhark.TypeChecker.Unify RigidSource (..), BreadCrumbs, sizeFree, - noBreadCrumbs, - hasNoBreadCrumbs, dimNotes, - zeroOrderType, arrayElemType, - mustHaveConstr, - mustHaveField, - mustBeOneOf, - equalityType, normType, normTypeFully, unify, @@ -43,57 +36,10 @@ import Futhark.Util (topologicalSort) import Futhark.Util.Pretty import Language.Futhark import Language.Futhark.Traversals +import Language.Futhark.TypeChecker.Error import Language.Futhark.TypeChecker.Monad hiding (BoundV) import Language.Futhark.TypeChecker.Types --- | A piece of information that describes what process the type --- checker currently performing. This is used to give better error --- messages for unification errors. -data BreadCrumb - = MatchingTypes StructType StructType - | MatchingFields [Name] - | MatchingConstructor Name - | Matching (Doc ()) - -instance Pretty BreadCrumb where - pretty (MatchingTypes t1 t2) = - "When matching type" - indent 2 (pretty t1) - "with" - indent 2 (pretty t2) - pretty (MatchingFields fields) = - "When matching types of record field" - <+> dquotes (mconcat $ punctuate "." $ map pretty fields) - <> dot - pretty (MatchingConstructor c) = - "When matching types of constructor" <+> dquotes (pretty c) <> dot - pretty (Matching s) = - unAnnotate s - --- | Unification failures can occur deep down inside complicated types --- (consider nested records). We leave breadcrumbs behind us so we --- can report the path we took to find the mismatch. -newtype BreadCrumbs = BreadCrumbs [BreadCrumb] - --- | An empty path. -noBreadCrumbs :: BreadCrumbs -noBreadCrumbs = BreadCrumbs [] - --- | Is the path empty? -hasNoBreadCrumbs :: BreadCrumbs -> Bool -hasNoBreadCrumbs (BreadCrumbs xs) = null xs - --- | Drop a breadcrumb on the path behind you. -breadCrumb :: BreadCrumb -> BreadCrumbs -> BreadCrumbs -breadCrumb (MatchingFields xs) (BreadCrumbs (MatchingFields ys : bcs)) = - BreadCrumbs $ MatchingFields (ys ++ xs) : bcs -breadCrumb bc (BreadCrumbs bcs) = - BreadCrumbs $ bc : bcs - -instance Pretty BreadCrumbs where - pretty (BreadCrumbs []) = mempty - pretty (BreadCrumbs bcs) = line <> stack (map pretty bcs) - -- | A usage that caused a type constraint. data Usage = Usage (Maybe T.Text) Loc deriving (Show) @@ -124,10 +70,6 @@ data Constraint = NoConstraint Liftedness Usage | ParamType Liftedness Loc | Constraint StructRetType Usage - | Overloaded [PrimType] Usage - | HasFields Liftedness (M.Map Name StructType) Usage - | Equality Usage - | HasConstrs Liftedness (M.Map Name [StructType]) Usage | ParamSize Loc | -- | Is not actually a type, but a term-level size, -- possibly already set to something specific. @@ -143,10 +85,6 @@ instance Located Constraint where locOf (NoConstraint _ usage) = locOf usage locOf (ParamType _ usage) = locOf usage locOf (Constraint _ usage) = locOf usage - locOf (Overloaded _ usage) = locOf usage - locOf (HasFields _ _ usage) = locOf usage - locOf (Equality usage) = locOf usage - locOf (HasConstrs _ _ usage) = locOf usage locOf (ParamSize loc) = locOf loc locOf (Size _ usage) = locOf usage locOf (UnknownSize loc _) = locOf loc @@ -235,7 +173,7 @@ prettySource ctx loc (RigidOutOfScope boundloc v) = <> pretty (locStrRel ctx boundloc) <> "." prettySource _ _ RigidUnify = - "is an artificial size invented during unification of functions with anonymous sizes." + textwrap "is an artificial size invented during unification of functions with anonymous sizes." prettySource ctx loc (RigidCond t1 t2) = "is unknown due to conditional expression at " <> pretty (locStrRel ctx loc) @@ -266,27 +204,6 @@ typeNotes ctx = . fvVars . freeInType -typeVarNotes :: (MonadUnify m) => VName -> m Notes -typeVarNotes v = maybe mempty (note . snd) . M.lookup v <$> getConstraints - where - note (HasConstrs _ cs _) = - aNote $ - prettyName v - <+> "=" - <+> hsep (map ppConstr (M.toList cs)) - <+> "..." - note (Overloaded ts _) = - aNote $ prettyName v <+> "must be one of" <+> mconcat (punctuate ", " (map pretty ts)) - note (HasFields _ fs _) = - aNote $ - prettyName v - <+> "=" - <+> braces (mconcat (punctuate ", " (map ppField (M.toList fs)))) - note _ = mempty - - ppConstr (c, _) = "#" <> pretty c <+> "..." <+> "|" - ppField (f, _) = prettyName f <> ":" <+> "..." - -- | Monads that which to perform unification must implement this type -- class. class (Monad m) => MonadUnify m where @@ -352,12 +269,6 @@ unsharedConstructorsMsg cs1 cs2 = filter (`notElem` M.keys cs1) (M.keys cs2) ++ filter (`notElem` M.keys cs2) (M.keys cs1) --- | Is the given type variable the name of an abstract type or type --- parameter, which we cannot substitute? -isRigid :: VName -> Constraints -> Bool -isRigid v constraints = - maybe True (rigidConstraint . snd) $ M.lookup v constraints - -- | If the given type variable is nonrigid, what is its level? isNonRigid :: VName -> Constraints -> Maybe Level isNonRigid v constraints = do @@ -368,10 +279,6 @@ isNonRigid v constraints = do type UnifySizes m = BreadCrumbs -> [VName] -> (VName -> Maybe Int) -> Exp -> Exp -> m () -flipUnifySizes :: UnifySizes m -> UnifySizes m -flipUnifySizes onDims bcs bound nonrigid t1 t2 = - onDims bcs bound nonrigid t2 t1 - unifyWith :: (MonadUnify m) => UnifySizes m -> @@ -396,14 +303,7 @@ unifyWith onDims usage = subunify False failure = matchError (srclocOf usage) mempty bcs t1' t2' - link ord' = - linkVarToType linkDims usage bound bcs - where - -- We may have to flip the order of future calls to - -- onDims inside linkVarToType. - linkDims - | ord' = flipUnifySizes onDims - | otherwise = onDims + link = linkVarToType usage bound bcs unifyTypeArg bcs' (TypeArgDim d1) (TypeArgDim d2) = onDims' bcs' (swap ord d1 d2) @@ -443,24 +343,24 @@ unifyWith onDims usage = subunify False ) | tn == arg_tn, length targs == length arg_targs -> do - let bcs' = breadCrumb (Matching "When matching type arguments.") bcs + let bcs' = matching "When matching type arguments." <> bcs zipWithM_ (unifyTypeArg bcs') targs arg_targs ( Scalar (TypeVar _ (QualName [] v1) []), Scalar (TypeVar _ (QualName [] v2) []) ) -> case (nonrigid v1, nonrigid v2) of (Nothing, Nothing) -> failure - (Just lvl1, Nothing) -> link ord v1 lvl1 t2' - (Nothing, Just lvl2) -> link (not ord) v2 lvl2 t1' + (Just lvl1, Nothing) -> link v1 lvl1 t2' + (Nothing, Just lvl2) -> link v2 lvl2 t1' (Just lvl1, Just lvl2) - | lvl1 <= lvl2 -> link ord v1 lvl1 t2' - | otherwise -> link (not ord) v2 lvl2 t1' + | lvl1 <= lvl2 -> link v1 lvl1 t2' + | otherwise -> link v2 lvl2 t1' (Scalar (TypeVar _ (QualName [] v1) []), _) | Just lvl <- nonrigid v1 -> - link ord v1 lvl t2' + link v1 lvl t2' (_, Scalar (TypeVar _ (QualName [] v2) [])) | Just lvl <- nonrigid v2 -> - link (not ord) v2 lvl t1' + link v2 lvl t1' ( Scalar (Arrow _ p1 d1 a1 (RetType b1_dims b1)), Scalar (Arrow _ p2 d2 a2 (RetType b2_dims b2)) ) @@ -495,13 +395,13 @@ unifyWith onDims usage = subunify False subunify (not ord) bound - (breadCrumb (Matching "When matching parameter types.") bcs) + (matching "When matching parameter types." <> bcs) a1 a2 subunify ord bound' - (breadCrumb (Matching "When matching return types.") bcs) + (matching "When matching return types." <> bcs) (toStruct b1') (toStruct b2') @@ -557,7 +457,7 @@ unifySizes usage bcs bound nonrigid e1 (Var v2 _ _) not (anyBound bound e1) || (qualLeaf v2 `elem` bound) = linkVarToDim usage bcs (qualLeaf v2) lvl2 e1 unifySizes usage bcs _ _ e1 e2 = do - notes <- (<>) <$> dimNotes usage e2 <*> dimNotes usage e2 + notes <- (<>) <$> dimNotes usage e1 <*> dimNotes usage e2 unifyError usage notes bcs $ "Sizes" <+> dquotes (pretty e1) @@ -567,7 +467,7 @@ unifySizes usage bcs _ _ e1 e2 = do -- | Unifies two types. unify :: (MonadUnify m) => Usage -> StructType -> StructType -> m () -unify usage = unifyWith (unifySizes usage) usage mempty noBreadCrumbs +unify usage = unifyWith (unifySizes usage) usage mempty mempty occursCheck :: (MonadUnify m) => @@ -691,7 +591,6 @@ sizeFree tloc expKiller orig_t = do linkVarToType :: (MonadUnify m) => - UnifySizes m -> Usage -> [VName] -> BreadCrumbs -> @@ -699,7 +598,7 @@ linkVarToType :: Level -> StructType -> m () -linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do +linkVarToType usage bound bcs vn lvl tp_unnorm = do -- We have to expand anyway for the occurs check, so we might as -- well link the fully expanded type. tp <- normTypeFully tp_unnorm @@ -726,14 +625,13 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do M.insert vn (lvl, Constraint (RetType (ext_new <> ext_witnessed) tp') usage) let unliftedBcs unlifted_usage = - breadCrumb - ( Matching $ - "When verifying that" - <+> dquotes (prettyName vn) - <+> textwrap "is not instantiated with a function type, due to" - <+> pretty unlifted_usage + matching + ( "When verifying that" + <+> dquotes (prettyName vn) + <+> textwrap "is not instantiated with a function type, due to" + <+> pretty unlifted_usage ) - bcs + <> bcs constraints <- getConstraints case snd <$> M.lookup vn constraints of @@ -748,125 +646,7 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do <+> "cannot be instantiated with type containing anonymous sizes:" indent 2 (pretty tp) textwrap "This is usually because the size of an array returned by a higher-order function argument cannot be determined statically. This can also be due to the return size being a value parameter. Add type annotation to clarify." - Just (Equality _) -> do - link - equalityType usage tp - Just (Overloaded ts old_usage) - | tp `notElem` map (Scalar . Prim) ts -> do - link - case tp of - Scalar (TypeVar _ (QualName [] v) []) - | not $ isRigid v constraints -> - linkVarToTypes usage v ts - _ -> - unifyError usage mempty bcs $ - "Cannot instantiate" - <+> dquotes (prettyName vn) - <+> "with type" - indent 2 (pretty tp) - "as" - <+> dquotes (prettyName vn) - <+> "must be one of" - <+> commasep (map pretty ts) - "due to" - <+> pretty old_usage - <> "." - Just (HasFields l required_fields old_usage) -> do - when (l == Unlifted) $ arrayElemTypeWith usage (unliftedBcs old_usage) tp - case tp of - Scalar (Record tp_fields) - | all (`M.member` tp_fields) $ M.keys required_fields -> do - required_fields' <- mapM normTypeFully required_fields - let tp' = Scalar $ Record $ required_fields <> tp_fields -- Crucially left-biased. - ext = filter (`S.member` fvVars (freeInType tp')) bound - modifyConstraints $ - M.insert vn (lvl, Constraint (RetType ext tp') usage) - unifySharedFields onDims usage bound bcs required_fields' tp_fields - Scalar (TypeVar _ (QualName [] v) []) -> do - case M.lookup v constraints of - Just (_, HasFields _ tp_fields _) -> - unifySharedFields onDims usage bound bcs required_fields tp_fields - Just (_, NoConstraint {}) -> pure () - Just (_, Equality {}) -> pure () - _ -> do - notes <- (<>) <$> typeVarNotes vn <*> typeVarNotes v - noRecordType notes - link - modifyConstraints $ - M.insertWith - combineFields - v - (lvl, HasFields l required_fields old_usage) - where - combineFields (_, HasFields l1 fs1 usage1) (_, HasFields l2 fs2 _) = - (lvl, HasFields (l1 `min` l2) (M.union fs1 fs2) usage1) - combineFields hasfs _ = hasfs - _ -> - unifyError usage mempty bcs $ - "Cannot instantiate" - <+> dquotes (prettyName vn) - <+> "with type" - indent 2 (pretty tp) - "as" - <+> dquotes (prettyName vn) - <+> "must be a record with fields" - indent 2 (pretty (Record required_fields)) - "due to" - <+> pretty old_usage - <> "." - -- See Note [Linking variables to sum types] - Just (HasConstrs l required_cs old_usage) -> do - when (l == Unlifted) $ arrayElemTypeWith usage (unliftedBcs old_usage) tp - case tp of - Scalar (Sum ts) - | all (`M.member` ts) $ M.keys required_cs -> do - let tp' = Scalar $ Sum $ required_cs <> ts -- Crucially left-biased. - ext = filter (`S.member` fvVars (freeInType tp')) bound - modifyConstraints $ - M.insert vn (lvl, Constraint (RetType ext tp') usage) - unifySharedConstructors onDims usage bound bcs required_cs ts - | otherwise -> - unsharedConstructors required_cs ts =<< typeVarNotes vn - Scalar (TypeVar _ (QualName [] v) []) -> do - case M.lookup v constraints of - Just (_, HasConstrs _ v_cs _) -> - unifySharedConstructors onDims usage bound bcs required_cs v_cs - Just (_, NoConstraint {}) -> pure () - Just (_, Equality {}) -> pure () - _ -> do - notes <- (<>) <$> typeVarNotes vn <*> typeVarNotes v - noSumType notes - link - modifyConstraints $ - M.insertWith - combineConstrs - v - (lvl, HasConstrs l required_cs old_usage) - where - combineConstrs (_, HasConstrs l1 cs1 usage1) (_, HasConstrs l2 cs2 _) = - (lvl, HasConstrs (l1 `min` l2) (M.union cs1 cs2) usage1) - combineConstrs hasCs _ = hasCs - _ -> noSumType =<< typeVarNotes vn _ -> link - where - unsharedConstructors cs1 cs2 notes = - unifyError - usage - notes - bcs - (unsharedConstructorsMsg cs1 cs2) - noSumType notes = - unifyError - usage - notes - bcs - "Cannot unify a sum type with a non-sum type." - noRecordType notes = - unifyError - usage - notes - bcs - "Cannot unify a record type with a non-record type." linkVarToDim :: (MonadUnify m) => @@ -912,138 +692,6 @@ linkVarToDim usage bcs vn lvl e = do _ -> modifyConstraints $ M.insert dim' (lvl, c) checkVar _ _ = pure () --- | Assert that this type must be one of the given primitive types. -mustBeOneOf :: (MonadUnify m) => [PrimType] -> Usage -> StructType -> m () -mustBeOneOf [req_t] usage t = unify usage (Scalar (Prim req_t)) t -mustBeOneOf ts usage t = do - t' <- normType t - constraints <- getConstraints - let isRigid' v = isRigid v constraints - - case t' of - Scalar (TypeVar _ (QualName [] v) []) - | not $ isRigid' v -> linkVarToTypes usage v ts - Scalar (Prim pt) | pt `elem` ts -> pure () - _ -> failure - where - failure = - unifyError usage mempty noBreadCrumbs $ - "Cannot unify type" - <+> dquotes (pretty t) - <+> "with any of " - <> commasep (map pretty ts) - <> "." - -linkVarToTypes :: (MonadUnify m) => Usage -> VName -> [PrimType] -> m () -linkVarToTypes usage vn ts = do - vn_constraint <- M.lookup vn <$> getConstraints - case vn_constraint of - Just (lvl, Overloaded vn_ts vn_usage) -> - case ts `L.intersect` vn_ts of - [] -> - unifyError usage mempty noBreadCrumbs $ - "Type constrained to one of" - <+> commasep (map pretty ts) - <+> "but also one of" - <+> commasep (map pretty vn_ts) - <+> "due to" - <+> pretty vn_usage - <> "." - ts' -> modifyConstraints $ M.insert vn (lvl, Overloaded ts' usage) - Just (_, HasConstrs _ _ vn_usage) -> - unifyError usage mempty noBreadCrumbs $ - "Type constrained to one of" - <+> commasep (map pretty ts) - <> ", but also inferred to be sum type due to" - <+> pretty vn_usage - <> "." - Just (_, HasFields _ _ vn_usage) -> - unifyError usage mempty noBreadCrumbs $ - "Type constrained to one of" - <+> commasep (map pretty ts) - <> ", but also inferred to be record due to" - <+> pretty vn_usage - <> "." - Just (lvl, _) -> modifyConstraints $ M.insert vn (lvl, Overloaded ts usage) - Nothing -> - unifyError usage mempty noBreadCrumbs $ - "Cannot constrain type to one of" <+> commasep (map pretty ts) - --- | Assert that this type must support equality. -equalityType :: - (MonadUnify m, Pretty (Shape dim), Pretty u) => - Usage -> - TypeBase dim u -> - m () -equalityType usage t = do - unless (orderZero t) $ - unifyError usage mempty noBreadCrumbs $ - "Type " <+> dquotes (pretty t) <+> "does not support equality (may contain function)." - mapM_ mustBeEquality $ typeVars t - where - mustBeEquality vn = do - constraints <- getConstraints - case M.lookup vn constraints of - Just (_, Constraint (RetType [] (Scalar (TypeVar _ (QualName [] vn') []))) _) -> - mustBeEquality vn' - Just (_, Constraint (RetType _ vn_t) cusage) - | not $ orderZero vn_t -> - unifyError usage mempty noBreadCrumbs $ - "Type" - <+> dquotes (pretty t) - <+> "does not support equality." - "Constrained to be higher-order due to" - <+> pretty cusage - <+> "." - | otherwise -> pure () - Just (lvl, NoConstraint _ _) -> - modifyConstraints $ M.insert vn (lvl, Equality usage) - Just (_, Overloaded _ _) -> - pure () -- All primtypes support equality. - Just (_, Equality {}) -> - pure () - _ -> - unifyError usage mempty noBreadCrumbs $ - "Type" <+> prettyName vn <+> "does not support equality." - -zeroOrderTypeWith :: - (MonadUnify m) => - Usage -> - BreadCrumbs -> - StructType -> - m () -zeroOrderTypeWith usage bcs t = do - unless (orderZero t) $ - unifyError usage mempty bcs $ - "Type" indent 2 (pretty t) "found to be functional." - mapM_ mustBeZeroOrder . S.toList . typeVars =<< normType t - where - mustBeZeroOrder vn = do - constraints <- getConstraints - case M.lookup vn constraints of - Just (lvl, NoConstraint _ _) -> - modifyConstraints $ M.insert vn (lvl, NoConstraint Unlifted usage) - Just (lvl, HasFields _ fs _) -> - modifyConstraints $ M.insert vn (lvl, HasFields Unlifted fs usage) - Just (lvl, HasConstrs _ cs _) -> - modifyConstraints $ M.insert vn (lvl, HasConstrs Unlifted cs usage) - Just (_, ParamType Lifted ploc) -> - unifyError usage mempty bcs $ - "Type parameter" - <+> dquotes (prettyName vn) - <+> "at" - <+> pretty (locStr ploc) - <+> "may be a function." - _ -> pure () - --- | Assert that this type must be zero-order. -zeroOrderType :: - (MonadUnify m) => Usage -> T.Text -> StructType -> m () -zeroOrderType usage desc = - zeroOrderTypeWith usage $ breadCrumb bc noBreadCrumbs - where - bc = Matching $ "When checking" <+> textwrap desc - arrayElemTypeWith :: (MonadUnify m, Pretty (Shape dim), Pretty u) => Usage -> @@ -1079,9 +727,7 @@ arrayElemType :: TypeBase dim u -> m () arrayElemType usage desc = - arrayElemTypeWith usage $ breadCrumb bc noBreadCrumbs - where - bc = Matching $ "When checking" <+> textwrap desc + arrayElemTypeWith usage $ matching $ "When checking" <+> textwrap desc unifySharedFields :: (MonadUnify m) => @@ -1094,7 +740,7 @@ unifySharedFields :: m () unifySharedFields onDims usage bound bcs fs1 fs2 = forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(f, (t1, t2)) -> - unifyWith onDims usage bound (breadCrumb (MatchingFields [f]) bcs) t1 t2 + unifyWith onDims usage bound (matchingField f <> bcs) t1 t2 unifySharedConstructors :: (MonadUnify m) => @@ -1111,102 +757,12 @@ unifySharedConstructors onDims usage bound bcs cs1 cs2 = where unifyConstructor c f1 f2 | length f1 == length f2 = do - let bcs' = breadCrumb (MatchingConstructor c) bcs + let bcs' = matchingConstructor c <> bcs zipWithM_ (unifyWith onDims usage bound bcs') f1 f2 | otherwise = unifyError usage mempty bcs $ "Cannot unify constructor" <+> dquotes (prettyName c) <> "." --- | In @mustHaveConstr usage c t fs@, the type @t@ must have a --- constructor named @c@ that takes arguments of types @ts@. -mustHaveConstr :: - (MonadUnify m) => - Usage -> - Name -> - StructType -> - [StructType] -> - m () -mustHaveConstr usage c t fs = do - constraints <- getConstraints - case t of - Scalar (TypeVar _ (QualName _ tn) []) - | Just (lvl, NoConstraint l _) <- M.lookup tn constraints -> do - mapM_ (scopeCheck usage noBreadCrumbs tn lvl) fs - modifyConstraints $ M.insert tn (lvl, HasConstrs l (M.singleton c fs) usage) - | Just (lvl, HasConstrs l cs _) <- M.lookup tn constraints -> - case M.lookup c cs of - Nothing -> - modifyConstraints $ - M.insert tn (lvl, HasConstrs l (M.insert c fs cs) usage) - Just fs' - | length fs == length fs' -> zipWithM_ (unify usage) fs fs' - | otherwise -> - unifyError usage mempty noBreadCrumbs $ - "Different arity for constructor" <+> dquotes (pretty c) <> "." - Scalar (Sum cs) -> - case M.lookup c cs of - Nothing -> - unifyError usage mempty noBreadCrumbs $ - "Constuctor" <+> dquotes (pretty c) <+> "not present in type." - Just fs' - | length fs == length fs' -> zipWithM_ (unify usage) fs fs' - | otherwise -> - unifyError usage mempty noBreadCrumbs $ - "Different arity for constructor" <+> dquotes (pretty c) <+> "." - _ -> - unify usage t $ Scalar $ Sum $ M.singleton c fs - -mustHaveFieldWith :: - (MonadUnify m) => - UnifySizes m -> - Usage -> - [VName] -> - BreadCrumbs -> - Name -> - StructType -> - m StructType -mustHaveFieldWith onDims usage bound bcs l t = do - constraints <- getConstraints - l_type <- newTypeVar (locOf usage) "t" - case t of - Scalar (TypeVar _ (QualName _ tn) []) - | Just (lvl, NoConstraint {}) <- M.lookup tn constraints -> do - scopeCheck usage bcs tn lvl l_type - modifyConstraints $ M.insert tn (lvl, HasFields Lifted (M.singleton l l_type) usage) - pure l_type - | Just (lvl, HasFields lifted fields _) <- M.lookup tn constraints -> do - case M.lookup l fields of - Just t' -> unifyWith onDims usage bound bcs l_type t' - Nothing -> - modifyConstraints $ - M.insert - tn - (lvl, HasFields lifted (M.insert l l_type fields) usage) - pure l_type - Scalar (Record fields) - | Just t' <- M.lookup l fields -> do - unify usage l_type t' - pure t' - | otherwise -> - unifyError usage mempty bcs $ - "Attempt to access field" - <+> dquotes (pretty l) - <+> " of value of type" - <+> pretty (toStructural t) - <> "." - _ -> do - unify usage t $ Scalar $ Record $ M.singleton l l_type - pure l_type - --- | Assert that some type must have a field with this name and type. -mustHaveField :: - (MonadUnify m) => - Usage -> - Name -> - StructType -> - m StructType -mustHaveField usage = mustHaveFieldWith (unifySizes usage) usage mempty noBreadCrumbs - newDimOnMismatch :: (MonadUnify m) => Loc -> @@ -1242,10 +798,31 @@ unifyMostCommon :: StructType -> m (StructType, [VName]) unifyMostCommon usage t1 t2 = do - -- We are ignoring the dimensions here, because any mismatches - -- should be turned into fresh size variables. - let allOK _ _ _ _ _ = pure () - unifyWith allOK usage mempty noBreadCrumbs t1 t2 + -- Like 'unifySizes', except we do not fail on mismatches - these + -- are instead turned into fresh existential sizes in + -- 'newDimOnMismatch'. The most annoying thing is that we have to + -- replicate scope checking, because we don't want to link if it + -- would fail. + constraints <- getConstraints + + let varLevel v = fst <$> M.lookup v constraints + expLevel e = + L.foldl' max 0 $ mapMaybe varLevel $ S.toList $ fvVars $ freeInExp e + + onDims bcs bound nonrigid e1 e2 + | Just es <- similarExps e1 e2 = + mapM_ (uncurry $ onDims bcs bound nonrigid) es + onDims bcs _ nonrigid (Var v1 _ _) e2 + | Just lvl1 <- nonrigid (qualLeaf v1), + expLevel e2 < lvl1 = + linkVarToDim usage bcs (qualLeaf v1) lvl1 e2 + onDims bcs _ nonrigid e1 (Var v2 _ _) + | Just lvl2 <- nonrigid (qualLeaf v2), + expLevel e1 < lvl2 = + linkVarToDim usage bcs (qualLeaf v2) lvl2 e1 + onDims _ _ _ _ _ = pure () + + unifyWith onDims usage mempty mempty t1 t2 t1' <- normTypeFully t1 t2' <- normTypeFully t2 newDimOnMismatch (locOf usage) t1' t2' @@ -1332,22 +909,3 @@ doUnification loc rigid_tparams nonrigid_tparams t1 t2 = runUnifyM rigid_tparams nonrigid_tparams $ do unify (Usage Nothing (locOf loc)) t1 t2 normTypeFully t2 - --- Note [Linking variables to sum types] --- --- Consider the case when unifying a result type --- --- i32 -> ?[n].(#foo [n]bool) --- --- with --- --- i32 -> ?[k].a --- --- where 'a' has a HasConstrs constraint saying that it must have at --- least a constructor of type '#foo [0]bool'. --- --- This unification should succeed, but we must not merely link 'a' to --- '#foo [n]bool', as 'n' is not free. Instead we should instantiate --- 'a' to be a concrete sum type (because now we know exactly which --- constructor labels it must have), and unify each of its constructor --- payloads with the corresponding expected payload. diff --git a/tests/ad/stripmine2.fut b/tests/ad/stripmine2.fut index 1e654969d2..ce6e32a3ef 100644 --- a/tests/ad/stripmine2.fut +++ b/tests/ad/stripmine2.fut @@ -1,7 +1,7 @@ def pow_list [n] y (xs :[n]i32) = #[stripmine(2)] loop accs = (replicate n 1) for _i < y do - map2 (*) accs xs + map2 (*) accs xs -- == -- entry: prim diff --git a/tests/ascription0.fut b/tests/ascription0.fut index 5aff8c054a..8c3a50e026 100644 --- a/tests/ascription0.fut +++ b/tests/ascription0.fut @@ -3,6 +3,6 @@ -- == -- error: match -def main(x: i32, y:i32): i32 = +def main(x: i32, y:i32): (bool,bool) = let (((a): i32), b: i32) : (bool,bool) = (x,y) in (a,b) diff --git a/tests/automap/ambiguous0.fut b/tests/automap/ambiguous0.fut new file mode 100644 index 0000000000..8c1ec556c3 --- /dev/null +++ b/tests/automap/ambiguous0.fut @@ -0,0 +1,4 @@ +-- == +-- error: ambiguous + +def ambig (xss : [][]i32) = i64.sum (length xss) diff --git a/tests/automap/bool1.fut b/tests/automap/bool1.fut new file mode 100644 index 0000000000..f3fe08213e --- /dev/null +++ b/tests/automap/bool1.fut @@ -0,0 +1,6 @@ +-- == +-- entry: f +-- input { [true, true, false] [false, true, true] } +-- output { [true, true, true] } + +def f [m] (xs: [m]bool) (ys: [m]bool) = xs || ys diff --git a/tests/automap/combinations.fut b/tests/automap/combinations.fut new file mode 100644 index 0000000000..7d77e85abb --- /dev/null +++ b/tests/automap/combinations.fut @@ -0,0 +1,38 @@ +-- All the various ways one can imagine automapping a very simple program. + +def plus (x: i32) (y: i32) = x + y + +-- == +-- entry: vecint +-- input { [1,2,3] } output { [3,4,5] } + +entry vecint (x: []i32) = plus x 2 + +-- == +-- entry: vecvec +-- input { [1,2,3] } output { [2,4,6] } + +entry vecvec (x: []i32) = plus x x + +-- == +-- entry: matint +-- input { [[1,2],[3,4]] } output { [[3,4],[5,6]] } + +entry matint (x: [][]i32) = plus x 2 + +-- == +-- entry: matmat +-- input { [[1,2],[3,4]] } output { [[2,4],[6,8]] } + +entry matmat (x: [][]i32) = plus x x + +-- == +-- entry: matvec +-- input { [[1,2],[3,4]] [5,6] } output { [[6,8],[8,10]] } + +entry matvec (x: [][]i32) (y: []i32) = plus x y + +-- == +-- entry: vecvecvec +-- input { [1,2,3] } output { [3,6,9] } +entry vecvecvec (x: []i32) = (\x y z -> x + y + z) x x x diff --git a/tests/automap/equality1.fut b/tests/automap/equality1.fut new file mode 100644 index 0000000000..b2a173f30d --- /dev/null +++ b/tests/automap/equality1.fut @@ -0,0 +1,23 @@ +-- == +-- entry: bigger_to_smaller +-- input { [[1,2],[3,4]] [1,2] } +-- output { [[true, true], [false, false]] } + +-- == +-- entry: smaller_to_bigger +-- input { [[1,2],[3,4]] [1,2] } +-- output { [[true, true], [false, false]] } + +-- == +-- entry: smaller_to_bigger2 +-- input { [[1,2],[3,4]] 1 } +-- output { [[true,false],[false,false]]} + +entry bigger_to_smaller [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]bool = + xss == ys + +entry smaller_to_bigger [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]bool = + ys == xss + +entry smaller_to_bigger2 [n] (xss : [n][n]i32) (z: i32) : [n][n]bool = + z == xss diff --git a/tests/automap/lambda.fut b/tests/automap/lambda.fut new file mode 100644 index 0000000000..1bb7ed26e3 --- /dev/null +++ b/tests/automap/lambda.fut @@ -0,0 +1,6 @@ +-- == +-- entry: main +-- random input { [10]f32 [10]f32 } + +entry main [n](xs: [n]f32) (ys: [n]f32): [n]f32 = + map2 (*) xs ys diff --git a/tests/automap/leetcode.fut b/tests/automap/leetcode.fut new file mode 100644 index 0000000000..43a50cb2b8 --- /dev/null +++ b/tests/automap/leetcode.fut @@ -0,0 +1,4 @@ +def outerprod f x y = map (f >-> flip map y) x +def bidd A = outerprod (==) (indices A) (indices A) +def xmat A = bidd A || reverse (bidd A) +def check_matrix (A : [][]i32) = xmat A == (A != 0) |> flatten |> and diff --git a/tests/automap/map0.fut b/tests/automap/map0.fut new file mode 100644 index 0000000000..a5ab0887ae --- /dev/null +++ b/tests/automap/map0.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { [0,1,2,3] } +-- output { [1,2,3,4] } + +def automap 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = f as + +entry main (x: []i32) = automap (+1) x diff --git a/tests/automap/mri-q-qr.fut b/tests/automap/mri-q-qr.fut new file mode 100644 index 0000000000..8004f7da5d --- /dev/null +++ b/tests/automap/mri-q-qr.fut @@ -0,0 +1,2 @@ +def qr [numX][numK] (expArgs : [numX][numK]f32) (phiMag : [numK]f32) : [numX]f32 = + f32.sum (f32.cos expArgs * phiMag) diff --git a/tests/automap/mri-q.fut b/tests/automap/mri-q.fut new file mode 100644 index 0000000000..270e18195a --- /dev/null +++ b/tests/automap/mri-q.fut @@ -0,0 +1,41 @@ +-- == +-- entry: main +-- random input { [12]f32 [12]f32 [12]f32 [10]f32 [10]f32 [10]f32 [12]f32 [12]f32 } +-- output { true } + +def main_orig [numK][numX] + (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) + (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) + (phiR: [numK]f32) (phiI: [numK]f32) + : ([numX]f32, [numX]f32) = + let phiMag = map2 (\r i -> r*r + i*i) phiR phiI + let expArgs = map3 (\x_e y_e z_e -> + map (2.0f32*f32.pi*) + (map3 (\kx_e ky_e kz_e -> + kx_e * x_e + ky_e * y_e + kz_e * z_e) + kx ky kz)) + x y z + let qr = map1 (map f32.cos >-> map2 (*) phiMag >-> f32.sum) expArgs + let qi = map1 (map f32.sin >-> map2 (*) phiMag >-> f32.sum) expArgs + in (qr, qi) + +def main_am [numK][numX] + (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) + (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) + (phiR: [numK]f32) (phiI: [numK]f32) + : ([numX]f32, [numX]f32) = + let phiMag = phiR * phiR + phiI * phiI + let expArgs = map3 (\x_e y_e z_e -> + 2.0*f32.pi*(kx*x_e + ky*y_e + kz*z_e)) + x y z + let qr = f32.sum (f32.cos expArgs * phiMag) + let qi = f32.sum (f32.sin expArgs * phiMag) + in (qr, qi) + +entry main [numK][numX] + (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) + (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) + (phiR: [numK]f32) (phiI: [numK]f32) = + let (qr, qi) = main_orig kx ky kz x y z phiR phiI + let (qr_am, qi_am) = main_am kx ky kz x y z phiR phiI + in and (qr == qr_am && qi == qi_am) diff --git a/tests/automap/operator1.fut b/tests/automap/operator1.fut new file mode 100644 index 0000000000..464a8b79c4 --- /dev/null +++ b/tests/automap/operator1.fut @@ -0,0 +1,9 @@ +-- == +-- entry: main +-- input { [[1,2],[3,4]] [10,20] } +-- output { [[11, 22],[13, 24]] } + +def (+^) [n] (xs: [n]i32) (ys: [n]i32) : [n]i32 = xs + ys + +--entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = +-- xss +^ ys diff --git a/tests/automap/optionpricing.fut b/tests/automap/optionpricing.fut new file mode 100644 index 0000000000..c4c916521f --- /dev/null +++ b/tests/automap/optionpricing.fut @@ -0,0 +1,78 @@ +-- == +-- entry: sobolIndR +-- random input { [12][10]i32 i32 } +-- output { true } + +-- == +-- entry: sobolRecI +-- random input { [12][10]i32 [12]i32 i32} +-- output { true } + +-- == +-- entry: sobolReci2 +-- random input { [12][10]i32 [12]i32 i32} +-- output { true } + +def grayCode(x: i32): i32 = (x >> 1) ^ x + +def testBit(n: i32, ind: i32): bool = + let t = (1 << ind) in (n & t) == t + +def xorInds [num_bits] (n: i32) (dir_vs: [num_bits]i32): i32 = + let reldv_vals = map (\(dv: i32, i): i32 -> + if testBit(grayCode(n),i32.i64 i) + then dv else 0 + ) (zip (dir_vs) (iota(num_bits)) ) in + reduce (^) 0 (reldv_vals ) + + +def sobolIndI [len] (dir_vs: [len][]i32, n: i32 ): [len]i32 = + map (xorInds(n)) (dir_vs ) + +def index_of_least_significant_0(num_bits: i32, n: i32): i32 = + let (goon,k) = (true,0) in + let (_,k,_) = loop ((goon,k,n)) for i < num_bits do + if(goon) + then if (n & 1) == 1 + then (true, k+1, n>>1) + else (false,k, n ) + else (false,k, n ) + in k + +def recM [len][num_bits] (sob_dirs: [len][num_bits]i32, i: i32 ): [len]i32 = + let bit= index_of_least_significant_0(i32.i64 num_bits,i) in + map (\(row: []i32): i32 -> row[bit]) (sob_dirs ) + +def sobolIndR_orig [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): [m]f32 = + let divisor = 2.0 ** f32.i64(num_bits) + let arri = map (xorInds n) dir_vs + in map (\x -> f32.i32(x) / divisor) arri + +def sobolRecI_orig [num_bits][n] (sob_dir_vs: [n][num_bits]i32, prev: [n]i32, x: i32): [n]i32 = + let bit = index_of_least_significant_0(i32.i64 num_bits, x) + in map2 (\vct_row prev -> vct_row[bit] ^ prev) sob_dir_vs prev + +def sobolReci2_orig [n][num_bits] (sob_dirs: [n][num_bits]i32, prev: [n]i32, i: i32): [n]i32= + let col = recM(sob_dirs, i) + in map2 (^) prev col + +def sobolIndR_am [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): [m]f32 = + let divisor = 2.0 ** f32.i64(num_bits) + let arri = xorInds n dir_vs + in f32.i32 arri / divisor + +def sobolRecI_am [num_bits][n] (sob_dir_vs: [n][num_bits]i32, prev: [n]i32, x: i32): [n]i32 = + let bit = index_of_least_significant_0(i32.i64 num_bits, x) + in sob_dir_vs[:,bit] ^ prev + +def sobolReci2_am [n][num_bits] (sob_dirs: [n][num_bits]i32, prev: [n]i32, i: i32): [n]i32= + prev ^ recM(sob_dirs, i) + +entry sobolIndR [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): bool = + and (sobolIndR_orig dir_vs n == sobolIndR_am dir_vs n) + +entry sobolRecI [num_bits][n] (sob_dir_vs: [n][num_bits]i32) (prev: [n]i32) (x: i32): bool = + and (sobolRecI_orig (sob_dir_vs, prev, x) == sobolRecI_am (sob_dir_vs, prev, x)) + +entry sobolReci2 [n][num_bits] (sob_dirs: [n][num_bits]i32) (prev: [n]i32) (i: i32): bool = + and (sobolReci2_orig (sob_dirs, prev, i) == sobolReci2_am (sob_dirs, prev, i)) diff --git a/tests/automap/pagerank.fut b/tests/automap/pagerank.fut new file mode 100644 index 0000000000..3552990144 --- /dev/null +++ b/tests/automap/pagerank.fut @@ -0,0 +1,18 @@ +-- == +-- entry: calculate_dangling_ranks +-- random input { [12]f32 [12]i32} +-- output { true } + +def calculate_dangling_ranks_orig [n] (ranks: [n]f32) (sizes: [n]i32): *[]f32 = + let zipped = zip sizes ranks + let weights = map (\(size, rank) -> if size == 0 then rank else 0f32) zipped + let total = f32.sum weights / f32.i64 n + in map (+total) ranks + +def calculate_dangling_ranks_am [n] (ranks: [n]f32) (sizes: [n]i32): *[]f32 = + let weights = f32.bool (sizes == 0) * ranks + let total = f32.sum weights / f32.i64 n + in ranks + total + +entry calculate_dangling_ranks [n] (ranks: [n]f32) (sizes: [n]i32): bool = + and (calculate_dangling_ranks_orig ranks sizes == calculate_dangling_ranks_am ranks sizes) diff --git a/tests/automap/project.fut b/tests/automap/project.fut new file mode 100644 index 0000000000..2902d0565a --- /dev/null +++ b/tests/automap/project.fut @@ -0,0 +1,9 @@ +-- == +-- entry: main +-- input { [1,2,3] [4,5,6] } +-- output { [1,2,3,4,5,6] } + +entry main [n] (xs: [n]i32) (ys: [n]i32) : []i32 = + let xsys = zip xs ys + in xsys.0 ++ xsys.1 + diff --git a/tests/automap/projsec1.fut b/tests/automap/projsec1.fut new file mode 100644 index 0000000000..485c977bc5 --- /dev/null +++ b/tests/automap/projsec1.fut @@ -0,0 +1,9 @@ +-- == +-- entry: main +-- input { [1,2,3] [4,5,6] } +-- output { [1,2,3,4,5,6] } + +entry main [n] (xs: [n]i32) (ys: [n]i32) : []i32 = + let xsys = zip xs ys + in (.0) xsys ++ (.1) xsys + diff --git a/tests/automap/same_typevar.fut b/tests/automap/same_typevar.fut new file mode 100644 index 0000000000..260a00b785 --- /dev/null +++ b/tests/automap/same_typevar.fut @@ -0,0 +1,16 @@ +-- == +-- tags { no_wasm } +-- entry: big_to_small +-- no_wasm compiled input { [[1,2],[3,4]] [1,2] 3 } + +-- == +-- entry: small_to_big +-- no_wasm compiled input { [[1,2],[3,4]] [1,2] 3 } + +def f 'a (x: a) (y: a) (z: a) = (x, y, z) + +entry big_to_small [n] (xss : [n][n]i32) (ys: [n]i32) (z: i32) : [n][n](i32,i32,i32) = + f xss ys z + +entry small_to_big [n] (xss : [n][n]i32) (ys: [n]i32) (z: i32) : [n][n](i32,i32,i32) = + f z ys xss diff --git a/tests/automap/sgemm.fut b/tests/automap/sgemm.fut new file mode 100644 index 0000000000..a31ce0188e --- /dev/null +++ b/tests/automap/sgemm.fut @@ -0,0 +1,32 @@ +-- == +-- entry: main +-- random input { [5][10]f32 [10][3]f32 [5][3]f32 f32 f32 } +-- output { true } + +def mult_orig [n][m][p] (xss: [n][m]f32, yss: [m][p]f32): [n][p]f32 = + let dotprod xs ys = f32.sum (map2 (*) xs ys) + in map (\xs -> map (dotprod xs) (transpose yss)) xss + +def add [n][m] (xss: [n][m]f32, yss: [n][m]f32): [n][m]f32 = + map2 (map2 (+)) xss yss + +def scale [n][m] (xss: [n][m]f32, a: f32): [n][m]f32 = + map (map1 (*a)) xss + +def main_orig [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) + (alpha: f32) (beta: f32) + : [n][p]f32 = + add(scale(css,beta), scale(mult_orig(ass,bss), alpha)) + + +def mult_am [n][m][p] (xss: [n][m]f32, yss: [m][p]f32): [n][p]f32 = + f32.sum ((transpose (replicate p xss)) * (replicate n (transpose yss))) + +def main_am [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) + (alpha: f32) (beta: f32) + : [n][p]f32 = + css*beta + mult_am(ass,bss)*alpha + +entry main [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) + (alpha: f32) (beta: f32) = + and (and (main_orig ass bss css alpha beta == main_am ass bss css alpha beta)) diff --git a/tests/automap/simple1.fut b/tests/automap/simple1.fut new file mode 100644 index 0000000000..f8833bb3b6 --- /dev/null +++ b/tests/automap/simple1.fut @@ -0,0 +1,7 @@ +-- == +-- entry: main +-- input { [1,2] 10 } +-- output { [11, 12] } + +entry main [n] (xs: [n]i32) (y : i32) : [n]i32 = + xs + y diff --git a/tests/automap/simple2.fut b/tests/automap/simple2.fut new file mode 100644 index 0000000000..ac57abcbe0 --- /dev/null +++ b/tests/automap/simple2.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { [[1,2],[3,4]] [1,1] } +-- output { [[2,3],[4,5]] } + +entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = + xss + ys + diff --git a/tests/automap/simple3.fut b/tests/automap/simple3.fut new file mode 100644 index 0000000000..adc60bd43f --- /dev/null +++ b/tests/automap/simple3.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { [[1,2],[3,4]] [1,1] } +-- output { [[2,3],[4,5]] } + +entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = + ys + xss + diff --git a/tests/automap/simple4.fut b/tests/automap/simple4.fut new file mode 100644 index 0000000000..d94bbe4a6b --- /dev/null +++ b/tests/automap/simple4.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { 3 [1,1] [[1,2],[3,4]] } +-- output { [[5,6],[7,8]] } + +entry main [n] (x : i32) (ys: [n]i32) (zss : [n][n]i32) : [n][n]i32 = + x + ys + zss + diff --git a/tests/automap/simple5.fut b/tests/automap/simple5.fut new file mode 100644 index 0000000000..46610e6567 --- /dev/null +++ b/tests/automap/simple5.fut @@ -0,0 +1,6 @@ +-- == +-- input { [1,2,3] 4 } +-- output { [5, 6, 7] } + +entry main [n] (xs: [n]i32) (y : i32) : [n]i32 = + (\x y -> x + y) xs y diff --git a/tests/issue1787.fut b/tests/issue1787.fut index 90cb01dd72..ed4aef3fae 100644 --- a/tests/issue1787.fut +++ b/tests/issue1787.fut @@ -1,5 +1,5 @@ -- == --- error: found to be functional +-- error: function type entry main: i32 -> i32 -> i32 = ((true, (.0)), (false, (.1))) diff --git a/tests/issue514.fut b/tests/issue514.fut index 2f70eca04f..057d69b71a 100644 --- a/tests/issue514.fut +++ b/tests/issue514.fut @@ -1,4 +1,4 @@ -- == --- error: issue514.fut:4:26-36 +-- error: issue514.fut:4:13-22 def main = (2.0 + 3.0) / (2 + 3i32) diff --git a/tests/record-update6.fut b/tests/record-update6.fut index 53349ac0ab..fef501e1a2 100644 --- a/tests/record-update6.fut +++ b/tests/record-update6.fut @@ -1,10 +1,9 @@ -- Inference of record in lambda. -- == --- error: Full type of type octnode = {body: i32} -def f (octree: []octnode) (i: i32) = +entry f (octree: []octnode) (i: i32) = map (\n -> if n.body != i then n else n with body = 0) octree diff --git a/tests/shapes/error4.fut b/tests/shapes/error4.fut index b842bdf44a..cf75bfe897 100644 --- a/tests/shapes/error4.fut +++ b/tests/shapes/error4.fut @@ -2,7 +2,7 @@ -- == -- error: Sizes.*"n".*do not match -def f (g: (n: i64) -> [n]i32) (l: i64): i32 = +def f (g: (n: i64) -> [n]i64) (l: i64): i64 = (g l)[0] def main = f (\n : []i64 -> iota (n+1)) diff --git a/tests/shapes/error6.fut b/tests/shapes/error6.fut index 3fda73dd6e..5c7332d94a 100644 --- a/tests/shapes/error6.fut +++ b/tests/shapes/error6.fut @@ -2,7 +2,7 @@ -- == -- error: "n" -def ap (f: (n: i64) -> [n]i32) (k: i64) : [k]i32 = +def ap (f: (n: i64) -> [n]i64) (k: i64) : [k]i64 = f k def main = ap (\n -> iota (n+1)) 10 diff --git a/tests/shapes/polymorphic4.fut b/tests/shapes/polymorphic4.fut index b44af86c34..acab851f67 100644 --- a/tests/shapes/polymorphic4.fut +++ b/tests/shapes/polymorphic4.fut @@ -2,6 +2,6 @@ -- == -- error: do not match -def foo f x : [1]i32 = +def foo (f : (n: i64) -> [n]i32) x : [1]i32 = let r = if true then f x : []i32 else [1i32] in r diff --git a/tests/shapes/shape_duplicate.fut b/tests/shapes/shape_duplicate.fut index 3bbd5f391f..b29e1e7cbe 100644 --- a/tests/shapes/shape_duplicate.fut +++ b/tests/shapes/shape_duplicate.fut @@ -4,7 +4,7 @@ -- == -- error: do not match -def f [n][m] ((_, elems: [n]i32): (i32,[m]i32)): i32 = +def f [n][m] ((_, elems: [n]i64): (i64,[m]i64)): i64 = n + m + elems[0] -def main (x: i32, y: []i32): i32 = f (x, y) +def main (x: i64, y: []i64): i64 = f (x, y) diff --git a/tests/shapes/size-inference2.fut b/tests/shapes/size-inference2.fut index b6f59d4a9a..2804383f72 100644 --- a/tests/shapes/size-inference2.fut +++ b/tests/shapes/size-inference2.fut @@ -2,4 +2,4 @@ -- == -- error: Sizes.*do not match -def main [n] (xs: [n]i32) : [n]i32 = iota (length xs) +def main [n] (xs: [n]i32) : [n]i64 = iota (length xs) diff --git a/tests/sumtypes/coerce1.fut b/tests/sumtypes/coerce1.fut index eeff92a2a3..b6bfe42f3d 100644 --- a/tests/sumtypes/coerce1.fut +++ b/tests/sumtypes/coerce1.fut @@ -1,5 +1,4 @@ -- == --- error: Ambiguous size.*anonymous size type opt 't = #some t | #none diff --git a/tests/tridag.fut b/tests/tridag.fut index a055dca86a..e8cc6718e8 100644 --- a/tests/tridag.fut +++ b/tests/tridag.fut @@ -34,32 +34,31 @@ -- } -def tridag(nn: i32, - b: *[]f64, d: *[]f64, - a: []f64, c: []f64 ): ([]f64,[]f64) = - if (nn == 1) +def tridag [nn] (b: *[]f64, d: *[nn]f64, + a: []f64, c: []f64 ): ([]f64,[]f64) = + if (nn == 1) --then ( b, map(\f64 (f64 x, f64 y) -> x / y, d, b) ) then (b, [d[0]/b[0]]) - else - let (b,d) = loop((b, d)) for i < (nn-1) do - let xm = a[i+1] / b[i] - let b[i+1] = b[i+1] - xm*c[i] - let d[i+1] = d[i+1] - xm*d[i] in - (b, d) + else + let (b,d) = loop((b, d)) for i < (nn-1) do + let xm = a[i+1] / b[i] + let b[i+1] = b[i+1] - xm*c[i] + let d[i+1] = d[i+1] - xm*d[i] in + (b, d) - let d[nn-1] = d[nn-1] / b[nn-1] in + let d[nn-1] = d[nn-1] / b[nn-1] in - let d = loop(d) for i < (nn-1) do - let k = nn - 2 - i - let d[k] = ( d[k] - c[k]*d[k+1] ) / b[k] in - d - in (b, d) + let d = loop(d) for i < (nn-1) do + let k = nn - 2 - i + let d[k] = ( d[k] - c[k]*d[k+1] ) / b[k] in + d + in (b, d) def main: ([]f64,[]f64) = - let nn = reduce (+) 0 ([1,2,3,4]) - let a = replicate nn 3.33 - let b = map (\x -> f64.i64(x) + 1.0) (iota(nn)) - let c = map (\x -> 1.11*f64.i64(x) + 0.5) (iota(nn)) - let d = map (\x -> 1.01*f64.i64(x) + 0.25) (iota(nn)) - in tridag(i32.i64 nn, b, d, a, c) + let nn = reduce (+) 0 ([1,2,3,4]) + let a = replicate nn 3.33 + let b = map (\x -> f64.i64(x) + 1.0) (iota(nn)) + let c = map (\x -> 1.11*f64.i64(x) + 0.5) (iota(nn)) + let d = map (\x -> 1.01*f64.i64(x) + 0.25) (iota(nn)) + in tridag(b, d, a, c) diff --git a/tests/types/inference-error4.fut b/tests/types/inference-error4.fut index 809b98302a..0ff781f33a 100644 --- a/tests/types/inference-error4.fut +++ b/tests/types/inference-error4.fut @@ -1,6 +1,6 @@ -- If something is used in a loop, it cannot later be inferred as a -- function. -- == --- error: functional +-- error: function type def f x = (loop x = x for i < 10 do x, x 2) diff --git a/tests/types/inference22.fut b/tests/types/inference22.fut index 4e367db82f..dbe574e411 100644 --- a/tests/types/inference22.fut +++ b/tests/types/inference22.fut @@ -2,5 +2,5 @@ -- == def main (x: i32) (y: bool) = - let f x y = (y,x) + let f 'a 'b (x: a) (y: b) = (y,x) in (f x y, f y x) diff --git a/tests/types/inference5.fut b/tests/types/inference5.fut deleted file mode 100644 index 900704f21a..0000000000 --- a/tests/types/inference5.fut +++ /dev/null @@ -1,7 +0,0 @@ --- Inference for a local function. --- == --- input { 2 } output { 4 } - -def main x = - let apply f x = f x - in apply (apply (i32.+) x) x diff --git a/tests/issue1599.fut b/tests/types/occurs.fut similarity index 53% rename from tests/issue1599.fut rename to tests/types/occurs.fut index 3ce47c38b1..c37b1448c4 100644 --- a/tests/issue1599.fut +++ b/tests/types/occurs.fut @@ -1,3 +1,4 @@ +-- Simple instance of an occurs check. -- == -- error: Occurs diff --git a/unittests/Futhark/Solve/BranchAndBoundTests.hs b/unittests/Futhark/Solve/BranchAndBoundTests.hs new file mode 100644 index 0000000000..b7e1bfe027 --- /dev/null +++ b/unittests/Futhark/Solve/BranchAndBoundTests.hs @@ -0,0 +1,143 @@ +{-# OPTIONS_GHC -fno-warn-type-defaults #-} + +module Futhark.Solve.BranchAndBoundTests + ( tests, + ) +where + +import Data.Vector.Unboxed qualified as V +import Futhark.Solve.BranchAndBound +import Futhark.Solve.LP +import Futhark.Solve.Matrix qualified as M +import Test.Tasty +import Test.Tasty.HUnit +import Prelude hiding (or) + +tests :: TestTree +tests = + testGroup + "BranchAndBoundTests" + [ -- testCase "1" $ + -- let lpe = + -- LPE + -- { pc = V.fromList [1, 1, 0, 0, 0], + -- pA = + -- M.fromLists + -- [ [-1, 1, 1, 0, 0], + -- [1, 0, 0, 1, 0], + -- [0, 1, 0, 0, 1] + -- ], + -- pd = V.fromList [1, 3, 2] + -- } + -- in simplex lpe @?= Just (5 :: Double, V.fromList [3, 2, 2, 0, 0]), + testCase "2" $ + let lp = + LP + { lpc = V.fromList [40, 30], + lpA = + M.fromLists + [ [1, 1], + [2, 1] + ], + lpd = V.fromList [12, 16] + } + in branchAndBound lp @?= Just (400 :: Double, V.fromList [4, 8]), + testCase "3" $ + let lp = + LP + { lpc = V.fromList [1, 2, 3], + lpA = + M.fromLists + [ [1, 1, 1], + [2, 1, 3] + ], + lpd = V.fromList [12, 18] + } + in branchAndBound lp @?= Just (27 :: Double, V.fromList [0, 9, 3]), + testCase "4" $ + let lp = + LP + { lpc = V.fromList [5.5, 2.1], + lpA = + M.fromLists + [ [-1, 1], + [8, 2] + ], + lpd = V.fromList [2, 17] + } + in assertBool (show $ branchAndBound lp) $ + case branchAndBound lp of + Nothing -> False + Just (z, sol) -> + (z `approxEq` (11.8 :: Double)) + && and (zipWith (==) (V.toList sol) [1, 3]), + -- testCase "5" $ + -- let prog = + -- LinearProg + -- { optType = Maximize, + -- objective = var "x1" ~+~ var "x2", + -- constraints = + -- [ var "x1" ~<=~ constant 10, + -- var "x2" ~<=~ constant 5 + -- ] + -- <> oneIsZero ("b1", "x1") ("b2", "x2") + -- } + -- (lp, _idxmap) = linearProgToLP prog + -- in assertBool + -- (unlines [show $ branchAndBound lp]) + -- $ case branchAndBound lp of + -- Nothing -> False + -- Just (z, _sol) -> + -- and + -- [ z `approxEq` (10 :: Double) + -- ], + -- testCase "6" $ + -- let prog = + -- LinearProg + -- { optType = Maximize, + -- objective = var "x1" ~+~ var "x2", + -- constraints = + -- [ var "x1" ~<=~ constant 10, + -- var "x2" ~<=~ constant 5 + -- ] + -- <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) + -- } + -- (lp, idxmap) = linearProgToLP prog + -- lpe = convert lp + -- in assertBool + -- (unlines [show $ branchAndBound lp]) + -- $ case branchAndBound lp of + -- Nothing -> False + -- Just (z, sol) -> + -- and + -- [ z `approxEq` (10 :: Double) + -- ] + + testCase "10" $ + let prog = + LinearProg + { optType = Minimize, + objective = var "R2" ~+~ var "M3", + constraints = + [ var "artifical4" ~==~ constant 1 ~+~ var "t0", + constant 1 ~+~ var "num1" ~==~ constant 1 ~+~ var "t0", + var "b_R2" ~<=~ constant 1, + var "b_M3" ~<=~ constant 1, + var "R2" ~<=~ 1000 ~*~ var "b_R2", + var "M3" ~<=~ 1000 ~*~ var "b_M3", + var "b_R2" ~+~ var "b_M3" ~<=~ constant 1 + ] + } + (lp, _idxmap) = linearProgToLP prog + in assertBool + (unlines [show $ branchAndBound lp]) + $ case branchAndBound lp of + Nothing -> False + Just (z, _sol) -> + and + [ z `approxEq` (0 :: Double) + ] + ] + +approxEq :: (Fractional a, Ord a) => a -> a -> Bool +approxEq x1 x2 = abs (x1 - x2) < 10 ^^ (-10 :: Int) diff --git a/unittests/Futhark/Solve/SimplexTests.hs b/unittests/Futhark/Solve/SimplexTests.hs new file mode 100644 index 0000000000..c29bd10a93 --- /dev/null +++ b/unittests/Futhark/Solve/SimplexTests.hs @@ -0,0 +1,221 @@ +{-# OPTIONS_GHC -fno-warn-type-defaults #-} + +module Futhark.Solve.SimplexTests + ( tests, + ) +where + +import Data.Vector.Unboxed qualified as V +import Futhark.Solve.LP +import Futhark.Solve.Matrix qualified as M +import Futhark.Solve.Simplex +import Test.Tasty +import Test.Tasty.HUnit +import Prelude hiding (or) + +tests :: TestTree +tests = + testGroup + "SimplexTests" + [ testCase "1" $ + let lpe = + LPE + { pc = V.fromList [1, 1, 0, 0, 0], + pA = + M.fromLists + [ [-1, 1, 1, 0, 0], + [1, 0, 0, 1, 0], + [0, 1, 0, 0, 1] + ], + pd = V.fromList [1, 3, 2] + } + in simplex lpe @?= Just (5 :: Double, V.fromList [3, 2, 2, 0, 0]), + testCase "2" $ + let lp = + LP + { lpc = V.fromList [40, 30], + lpA = + M.fromLists + [ [1, 1], + [2, 1] + ], + lpd = V.fromList [12, 16] + } + in simplexLP lp @?= Just (400 :: Double, V.fromList [4, 8]), + testCase "3" $ + let lp = + LP + { lpc = V.fromList [1, 2, 3], + lpA = + M.fromLists + [ [1, 1, 1], + [2, 1, 3] + ], + lpd = V.fromList [12, 18] + } + in simplexLP lp @?= Just (27 :: Double, V.fromList [0, 9, 3]), + testCase "4" $ + let lp = + LP + { lpc = V.fromList [5.5, 2.1], + lpA = + M.fromLists + [ [-1, 1], + [8, 2] + ], + lpd = V.fromList [2, 17] + } + in assertBool (show $ simplexLP lp) $ + case simplexLP lp of + Nothing -> False + Just (z, sol) -> + (z `approxEq` (14.08 :: Double)) + && and (zipWith approxEq (V.toList sol) [1.3, 3.3]), + testCase "5" $ + let lp = + LP + { lpc = V.fromList [0], + lpA = + M.fromLists + [ [1], + [-1] + ], + lpd = V.fromList [0, 0] + } + in assertBool (show $ simplexLP lp) $ + case simplexLP lp of + Nothing -> False + Just (z, sol) -> + (z `approxEq` (0 :: Double)) + && and (zipWith approxEq (V.toList sol) [0]), + testCase "6" $ + let lp = + LP + { lpc = V.fromList [1], + lpA = + M.fromLists + [ [1], + [-1] + ], + lpd = V.fromList [5, 5] + } + in assertBool (show $ simplexLP lp) $ + case simplexLP lp of + Nothing -> False + Just (z, sol) -> + z `approxEq` (5 :: Double) + && and (zipWith approxEq (V.toList sol) [5]), + testCase "7" $ + let prog = + LinearProg + { optType = Maximize, + objective = var "x1", + constraints = + [ var "x1" ~<=~ 10 ~*~ var "b1", + var "b1" ~+~ var "b2" ~<=~ constant 1 + ] + } + (lp, _idxmap) = linearProgToLP prog + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + (z `approxEq` (10 :: Double)) + && and (zipWith (==) (V.toList sol) [1, 0, 10]), + testCase "8" $ + let prog = + LinearProg + { optType = Maximize, + objective = var "x1" ~+~ var "x2", + constraints = + [ var "x1" ~<=~ constant 10, + var "x2" ~<=~ constant 5 + ] + <> oneIsZero ("b1", "x1") ("b2", "x2") + } + (lp, _idxmap) = linearProgToLP prog + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, _sol) -> + and + [ z `approxEq` (15 :: Double) + ], + -- testCase "9" $ + -- let prog = + -- LinearProg + -- { optType = Maximize, + -- objective = var "x1" ~+~ var "x2", + -- constraints = + -- [ var "x1" ~<=~ constant 10, + -- var "x2" ~<=~ constant 5 + -- ] + -- <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) + -- } + -- (lp, idxmap) = linearProgToLP prog + -- lpe = convert lp + -- in trace + -- (unlines [show prog, show lp, show idxmap, show lpe]) + -- ( assertBool + -- (unlines [show $ simplexLP lp]) + -- $ case simplexLP lp of + -- Nothing -> False + -- Just (z, sol) -> + -- and + -- [ z `approxEq` (15 :: Double) + -- ] + -- ), + testCase "10" $ + let prog = + LinearProg + { optType = Minimize, + objective = var "R2" ~+~ var "M3", + constraints = + [ var "artifical4" ~==~ constant 1 ~+~ var "t0", + constant 1 ~+~ var "num1" ~==~ constant 1 ~+~ var "t0", + var "b_R2" ~<=~ constant 1, + var "b_M3" ~<=~ constant 1, + var "R2" ~<=~ 1000 ~*~ var "b_R2", + var "M3" ~<=~ 1000 ~*~ var "b_M3", + var "b_R2" ~+~ var "b_M3" ~<=~ constant 1 + ] + } + (lp, _idxmap) = linearProgToLP prog + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, _sol) -> + and + [ z `approxEq` (0 :: Double) + ], + testCase "11" $ + let prog = + LinearProg + { optType = Minimize, + objective = var "4R" ~+~ var "5M", + constraints = + [ var "6artifical" ~==~ constant 1 ~+~ var "2t", + constant 1 ~+~ var "3num" ~==~ constant 1 ~+~ var "2t", + var "0b_R" ~<=~ constant 1, + var "1b_M" ~<=~ constant 1, + var "4R" ~<=~ 1000 ~*~ var "0b_R", + var "5M" ~<=~ 1000 ~*~ var "1b_M", + var "0b_R" ~+~ var "1b_M" ~<=~ constant 1 + ] + } + (lp, _idxmap) = linearProgToLP prog + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, _sol) -> + and + [ z `approxEq` (0 :: Double) + ] + ] + +approxEq :: (Fractional a, Ord a) => a -> a -> Bool +approxEq x1 x2 = abs (x1 - x2) < 10 ^^ (-10 :: Int) diff --git a/unittests/futhark_tests.hs b/unittests/futhark_tests.hs index 10373e7c3f..f11596cb50 100644 --- a/unittests/futhark_tests.hs +++ b/unittests/futhark_tests.hs @@ -11,6 +11,8 @@ import Futhark.Internalise.TypesValuesTests qualified import Futhark.Optimise.ArrayLayoutTests qualified import Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests qualified import Futhark.Pkg.SolveTests qualified +import Futhark.Solve.BranchAndBoundTests qualified +import Futhark.Solve.SimplexTests qualified import Language.Futhark.PrimitiveTests qualified import Language.Futhark.SemanticTests qualified import Language.Futhark.SyntaxTests qualified @@ -35,6 +37,8 @@ allTests = Futhark.Analysis.AlgSimplifyTests.tests, Language.Futhark.TypeCheckerTests.tests, Language.Futhark.SemanticTests.tests, + Futhark.Solve.SimplexTests.tests, + Futhark.Solve.BranchAndBoundTests.tests, Futhark.Optimise.ArrayLayoutTests.tests ]