diff --git a/futhark.cabal b/futhark.cabal index 4882940dcd..c840f8d2d5 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -274,14 +274,14 @@ library Futhark.Construct Futhark.Doc.Generator Futhark.Error + Futhark.SoP.Convert + Futhark.SoP.Expression Futhark.SoP.Monad Futhark.SoP.Refine Futhark.SoP.RefineEquivs Futhark.SoP.FourierMotzkin - Futhark.SoP.PrimExp Futhark.SoP.RefineRanges Futhark.SoP.SoP - Futhark.SoP.ToFromSoP Futhark.SoP.Util Futhark.FreshNames Futhark.IR diff --git a/src/Futhark/Internalise/Refinement.hs b/src/Futhark/Internalise/Refinement.hs index 04d8908028..297e4cb733 100644 --- a/src/Futhark/Internalise/Refinement.hs +++ b/src/Futhark/Internalise/Refinement.hs @@ -7,12 +7,11 @@ import Futhark.Analysis.PrimExp (PrimExp) import Futhark.Analysis.PrimExp qualified as PE import Futhark.Internalise.TypesValues (internalisePrimType, internalisePrimValue) import Futhark.MonadFreshNames +import Futhark.SoP.Convert import Futhark.SoP.FourierMotzkin import Futhark.SoP.Monad -import Futhark.SoP.PrimExp import Futhark.SoP.Refine import Futhark.SoP.SoP -import Futhark.SoP.ToFromSoP import Futhark.SoP.Util import Futhark.Util.Pretty import Language.Futhark @@ -22,127 +21,25 @@ import Language.Futhark.Semantic hiding (Env) type Env = () newtype RefineM a - = RefineM (SoPMT VName (RWS Env () VNameSource) a) + = RefineM (SoPMT VName Exp (RWS Env () VNameSource) a) deriving ( Functor, Applicative, Monad, MonadReader Env, - MonadSoP VName + MonadSoP VName Exp ) instance MonadFreshNames RefineM where getNameSource = RefineM $ getNameSource putNameSource = RefineM . putNameSource -convertBinOp :: BinOp -> PrimExp VName -> PrimExp VName -> PrimType -> PrimType -> Maybe (PrimExp VName) -convertBinOp LogAnd x y Bool _ = - simpleBinOp PE.LogAnd x y -convertBinOp LogOr x y Bool _ = - simpleBinOp PE.LogOr x y -convertBinOp Plus x y (Signed t) _ = - simpleBinOp (PE.Add t PE.OverflowWrap) x y -convertBinOp Plus x y (Unsigned t) _ = - simpleBinOp (PE.Add t PE.OverflowWrap) x y -convertBinOp Plus x y (FloatType t) _ = - simpleBinOp (PE.FAdd t) x y -convertBinOp Minus x y (Signed t) _ = - simpleBinOp (PE.Sub t PE.OverflowWrap) x y -convertBinOp Minus x y (Unsigned t) _ = - simpleBinOp (PE.Sub t PE.OverflowWrap) x y -convertBinOp Minus x y (FloatType t) _ = - simpleBinOp (PE.FSub t) x y -convertBinOp Times x y (Signed t) _ = - simpleBinOp (PE.Mul t PE.OverflowWrap) x y -convertBinOp Times x y (Unsigned t) _ = - simpleBinOp (PE.Mul t PE.OverflowWrap) x y -convertBinOp Times x y (FloatType t) _ = - simpleBinOp (PE.FMul t) x y -convertBinOp Equal x y t _ = - simpleCmpOp (PE.CmpEq $ internalisePrimType t) x y -convertBinOp NotEqual x y t _ = do - Just $ PE.UnOpExp PE.Not $ PE.CmpOpExp (PE.CmpEq $ internalisePrimType t) x y -convertBinOp Less x y (Signed t) _ = - simpleCmpOp (PE.CmpSlt t) x y -convertBinOp Less x y (Unsigned t) _ = - simpleCmpOp (PE.CmpUlt t) x y -convertBinOp Leq x y (Signed t) _ = - simpleCmpOp (PE.CmpSle t) x y -convertBinOp Leq x y (Unsigned t) _ = - simpleCmpOp (PE.CmpUle t) x y -convertBinOp Greater x y (Signed t) _ = - simpleCmpOp (PE.CmpSlt t) y x -- Note the swapped x and y -convertBinOp Greater x y (Unsigned t) _ = - simpleCmpOp (PE.CmpUlt t) y x -- Note the swapped x and y -convertBinOp Geq x y (Signed t) _ = - simpleCmpOp (PE.CmpSle t) y x -- Note the swapped x and y -convertBinOp Geq x y (Unsigned t) _ = - simpleCmpOp (PE.CmpUle t) y x -- Note the swapped x and y -convertBinOp Less x y (FloatType t) _ = - simpleCmpOp (PE.FCmpLt t) x y -convertBinOp Leq x y (FloatType t) _ = - simpleCmpOp (PE.FCmpLe t) x y -convertBinOp Greater x y (FloatType t) _ = - simpleCmpOp (PE.FCmpLt t) y x -- Note the swapped x and y -convertBinOp Geq x y (FloatType t) _ = - simpleCmpOp (PE.FCmpLe t) y x -- Note the swapped x and y -convertBinOp Less x y Bool _ = - simpleCmpOp PE.CmpLlt x y -convertBinOp Leq x y Bool _ = - simpleCmpOp PE.CmpLle x y -convertBinOp Greater x y Bool _ = - simpleCmpOp PE.CmpLlt y x -- Note the swapped x and y -convertBinOp Geq x y Bool _ = - simpleCmpOp PE.CmpLle y x -- Note the swapped x and y -convertBinOp _ _ _ _ _ = Nothing - -simpleBinOp op x y = Just $ PE.BinOpExp op x y - -simpleCmpOp op x y = Just $ PE.CmpOpExp op x y - -expToPrimExp :: Exp -> Maybe (PrimExp VName) -expToPrimExp (Literal v _) = Just $ PE.ValueExp $ internalisePrimValue v -expToPrimExp (IntLit v (Info t) _) = - case t of - Scalar (Prim (Signed it)) -> Just $ PE.ValueExp $ PE.IntValue $ PE.intValue it v - Scalar (Prim (Unsigned it)) -> Just $ PE.ValueExp $ PE.IntValue $ PE.intValue it v - Scalar (Prim (FloatType ft)) -> Just $ PE.ValueExp $ PE.FloatValue $ PE.floatValue ft v - _ -> Nothing -expToPrimExp (FloatLit v (Info t) _) = - case t of - Scalar (Prim (FloatType ft)) -> Just $ PE.ValueExp $ PE.FloatValue $ PE.floatValue ft v - _ -> Nothing -expToPrimExp (AppExp (BinOp (op, _) _ (e_x, _) (e_y, _) _) _) = do - x <- expToPrimExp e_x - y <- expToPrimExp e_y - guard $ baseTag (qualLeaf op) <= maxIntrinsicTag - let name = baseString $ qualLeaf op - bop <- find ((name ==) . prettyString) [minBound .. maxBound :: BinOp] - t_x <- getPrimType $ typeOf e_x - t_y <- getPrimType $ typeOf e_y - convertBinOp bop x y t_x t_y - where - getPrimType (Scalar (Prim t)) = Just t - getPrimType _ = Nothing -expToPrimExp _ = Nothing - checkExp :: Exp -> RefineM Bool -checkExp e = - case expToPrimExp e of - Just pe -> checkPrimExp pe - Nothing -> pure False - -checkPrimExp :: PrimExp VName -> RefineM Bool -checkPrimExp (PE.BinOpExp PE.LogAnd x y) = - (&&) <$> checkPrimExp x <*> checkPrimExp y -checkPrimExp (PE.BinOpExp PE.LogOr x y) = - (||) <$> checkPrimExp x <*> checkPrimExp y -checkPrimExp pe@(PE.CmpOpExp cop x y) = do - (_, sop) <- toNumSoPCmp pe +checkExp e = do + (_, sop) <- toSoPCmp e sop $>=$ zeroSoP -checkPrimExp pe = pure False -runRefineM :: VNameSource -> RefineM a -> (a, AlgEnv VName, VNameSource) +runRefineM :: VNameSource -> RefineM a -> (a, AlgEnv VName Exp, VNameSource) runRefineM src (RefineM m) = let ((a, algenv), src', _) = runRWS (runSoPMT_ m) mempty src in (a, algenv, src') diff --git a/src/Futhark/SoP/Convert.hs b/src/Futhark/SoP/Convert.hs new file mode 100644 index 0000000000..c0767a4e54 --- /dev/null +++ b/src/Futhark/SoP/Convert.hs @@ -0,0 +1,233 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} + +-- | Translating to-and-from PrimExp to the sum-of-product representation. +module Futhark.SoP.Convert + ( FromSoP (..), + ToSoP (..), + ) +where + +import Control.Monad.State +import Data.List (find) +import Data.Set (Set) +import Data.Set qualified as S +import Futhark.Analysis.PrimExp (PrimExp, PrimType, (~*~), (~+~), (~-~), (~/~), (~==~)) +import Futhark.Analysis.PrimExp qualified as PE +import Futhark.SoP.Monad +import Futhark.SoP.SoP +import Futhark.SoP.Util +import Futhark.Util.Pretty +import Language.Futhark.Core +import Language.Futhark.Prop +import Language.Futhark.Syntax (VName) +import Language.Futhark.Syntax qualified as E + +-- | Conversion from 'SoP's to other representations. +class FromSoP u a where + fromSoP :: SoP u -> a + +instance Ord u => FromSoP u (PrimExp u) where + fromSoP sop = + foldr ((~+~) . fromTerm) (PE.ValueExp $ PE.IntValue $ PE.intValue PE.Int64 (0 :: Integer)) (sopToLists sop) + where + fromTerm (term, n) = + foldl (~*~) (PE.ValueExp $ PE.IntValue $ PE.intValue PE.Int64 n) $ + map fromSym term + fromSym sym = PE.LeafExp sym $ PE.IntType PE.Int64 + +-- instance FromSoP VName Exp where +-- fromSoP sop = undefined +-- where +-- -- foldr ((~+~) . fromTerm) (PE.ValueExp $ PE.IntValue $ PE.intValue PE.Int64 (0 :: Integer)) (sopToLists sop) +-- mult = (E.AppExp (E.Var (E.QualName [] (VName "*" 0)) (E.Info $ i64) mempty) (E.Info $ E.AppRes i64 [])) +-- fromTerm (term, n) = +-- foldl mult (E.Literal $ E.SignedValue $ PE.intValue PE.Int64 n) $ +-- map fromSym term +-- fromSym sym = E.Var (E.QualName [] sym) (E.Info i64) mempty +-- i64 = E.Scalar $ E.Prim $ E.Signed $ PE.Int64 + +-- | Conversion from some expressions to +-- 'SoP's. Monadic because it may involve look-ups in the +-- untranslatable expression environment. +-- +-- Separating into two functions is to make clearer the fact that +-- 'toSoPCmp' returns SoPs @sop@ implicitly in the relation @sop >= +-- 0@. Maybe this should be enforced at the constructor level +-- instead; i.e. have constructors for numeric SoPs and SoPs in +-- relations. +class ToSoP u e where + toSoPNum :: MonadSoP u e m => e -> m (Integer, SoP u) + + -- | Translates a 'PrimExp' containing a (top-level) comparison + -- operator into a 'SoP' representation such that @sop >= 0@. + toSoPCmp :: MonadSoP u e m => e -> m (Integer, SoP u >= 0) + +instance (Nameable u, Ord u, Show u, Pretty u) => ToSoP u (PrimExp u) where + toSoPNum primExp = do + (f, sop) <- toSoPNum' 1 primExp + pure (abs f, signum f `scaleSoP` sop) + where + notIntType :: PrimType -> Bool + notIntType (PE.IntType _) = False + notIntType _ = True + + divideIsh :: PE.BinOp -> Bool + divideIsh (PE.UDiv _ _) = True + divideIsh (PE.UDivUp _ _) = True + divideIsh (PE.SDiv _ _) = True + divideIsh (PE.SDivUp _ _) = True + divideIsh (PE.FDiv _) = True + divideIsh _ = False + toSoPNum' _ pe + | notIntType (PE.primExpType pe) = + error "toSoPNum' applied to a PrimExp whose prim type is not Integer" + toSoPNum' f (PE.LeafExp vnm _) = + pure (f, sym2SoP vnm) + toSoPNum' f (PE.ValueExp (PE.IntValue iv)) = + pure (1, int2SoP $ getIntVal iv `div` f) + where + getIntVal :: PE.IntValue -> Integer + getIntVal (PE.Int8Value v) = fromIntegral v + getIntVal (PE.Int16Value v) = fromIntegral v + getIntVal (PE.Int32Value v) = fromIntegral v + getIntVal (PE.Int64Value v) = fromIntegral v + toSoPNum' f (PE.UnOpExp PE.Complement {} x) = do + (f', x_sop) <- toSoPNum' f x + pure (f', negSoP x_sop) + toSoPNum' f (PE.BinOpExp PE.Add {} x y) = do + (x_f, x_sop) <- toSoPNum x + (y_f, y_sop) <- toSoPNum y + let l_c_m = lcm x_f y_f + (x_m, y_m) = (l_c_m `div` x_f, l_c_m `div` y_f) + x_sop' = mulSoPs (int2SoP x_m) x_sop + y_sop' = mulSoPs (int2SoP y_m) y_sop + pure (f * l_c_m, addSoPs x_sop' y_sop') + toSoPNum' f (PE.BinOpExp PE.Sub {} x y) = do + (x_f, x_sop) <- toSoPNum x + (y_f, y_sop) <- toSoPNum y + let l_c_m = lcm x_f y_f + (x_m, y_m) = (l_c_m `div` x_f, l_c_m `div` y_f) + x_sop' = mulSoPs (int2SoP x_m) x_sop + n_y_sop' = mulSoPs (int2SoP (-y_m)) y_sop + pure (f * l_c_m, addSoPs x_sop' n_y_sop') + toSoPNum' f pe@(PE.BinOpExp PE.Mul {} x y) = do + (x_f, x_sop) <- toSoPNum x + (y_f, y_sop) <- toSoPNum y + case (x_f, y_f) of + (1, 1) -> pure (f, mulSoPs x_sop y_sop) + _ -> do + x' <- lookupUntransPE pe + toSoPNum' f $ PE.LeafExp x' $ PE.primExpType pe + -- pe / 1 == pe + toSoPNum' f (PE.BinOpExp divish pe q) + | divideIsh divish && PE.oneIshExp q = + toSoPNum' f pe + -- evaluate `val_x / val_y` + toSoPNum' f (PE.BinOpExp divish x y) + | divideIsh divish, + PE.ValueExp v_x <- x, + PE.ValueExp v_y <- y = do + let f' = v_x `vdiv` v_y + toSoPNum' f $ PE.ValueExp f' + -- Trivial simplifications: + -- (y * v) / y = v and (u * y) / y = u + | divideIsh divish, + PE.BinOpExp (PE.Mul _ _) u v <- x, + (is_fst, is_snd) <- (u == y, v == y), + is_fst || is_snd = do + toSoPNum' f $ if is_fst then v else u + where + vdiv (PE.IntValue (PE.Int64Value x')) (PE.IntValue (PE.Int64Value y')) = + PE.IntValue $ PE.Int64Value (x' `div` y') + vdiv (PE.IntValue (PE.Int32Value x')) (PE.IntValue (PE.Int32Value y')) = + PE.IntValue $ PE.Int32Value (x' `div` y') + vdiv (PE.IntValue (PE.Int16Value x')) (PE.IntValue (PE.Int16Value y')) = + PE.IntValue $ PE.Int16Value (x' `div` y') + vdiv (PE.IntValue (PE.Int8Value x')) (PE.IntValue (PE.Int8Value y')) = + PE.IntValue $ PE.Int8Value (x' `div` y') + -- vdiv (FloatValue (Float32Value x)) (FloatValue (Float32Value y)) = + -- FloatValue $ Float32Value $ x / y + -- vdiv (FloatValue (Float64Value x)) (FloatValue (Float64Value y)) = + -- FloatValue $ Float64Value $ x / y + vdiv _ _ = error "In vdiv: illegal type for division!" + -- try heuristic for exact division + toSoPNum' f pe@(PE.BinOpExp divish x y) + | divideIsh divish = do + (x_f, x_sop) <- toSoPNum x + (y_f, y_sop) <- toSoPNum y + case (x_f, y_f, divSoPs x_sop y_sop) of + (1, 1, Just res) -> pure (f, res) + _ -> do + x' <- lookupUntransPE pe + toSoPNum' f $ PE.LeafExp x' $ PE.primExpType pe + -- Anything that is not handled by specific cases of toSoPNum' + -- is handled by this default procedure: + -- If the target `pe` is in the unknwon `env` + -- Then return thecorresponding binding + -- Else make a fresh symbol `v`, bind it in the environment + -- and return it. + toSoPNum' f pe = do + x <- lookupUntransPE pe + toSoPNum' f $ PE.LeafExp x $ PE.primExpType pe + + toSoPCmp (PE.CmpOpExp (PE.CmpEq ptp) x y) + -- x = y => x - y = 0 + | PE.IntType {} <- ptp = toSoPNum $ x ~-~ y + toSoPCmp (PE.CmpOpExp lessop x y) + -- x < y => x + 1 <= y => y >= x + 1 => y - (x+1) >= 0 + | Just itp <- lthishType lessop = + toSoPNum $ y ~-~ (x ~+~ PE.ValueExp (PE.IntValue $ PE.intValue itp (1 :: Integer))) + -- x <= y => y >= x => y - x >= 0 + | Just _ <- leqishType lessop = + toSoPNum $ y ~-~ x + where + lthishType (PE.CmpSlt itp) = Just itp + lthishType (PE.CmpUlt itp) = Just itp + lthishType _ = Nothing + leqishType (PE.CmpUle itp) = Just itp + leqishType (PE.CmpSle itp) = Just itp + leqishType _ = Nothing + toSoPCmp pe = error $ "toSoPCmp: not a comparison " <> prettyString pe + +instance (Nameable u, Ord u, Show u, Pretty u) => ToSoP u Exp where + toSoPNum (E.Literal v _) = + (pure . (1,)) $ + case v of + E.SignedValue x -> int2SoP $ PE.valueIntegral x + E.UnsignedValue x -> int2SoP $ PE.valueIntegral x + _ -> error "" + toSoPNum e = do + x <- lookupUntransPE e + pure (1, sym2SoP x) + + -- expToPrimExp (IntLit v (Info t) _) = + + toSoPCmp (E.AppExp (E.BinOp (op, _) _ (e_x, _) (e_y, _) _) _) + | E.baseTag (E.qualLeaf op) <= maxIntrinsicTag, + name <- E.baseString $ E.qualLeaf op, + Just bop <- find ((name ==) . prettyString) [minBound .. maxBound :: E.BinOp] = do + (_, x) <- toSoPNum e_x + (_, y) <- toSoPNum e_y + (1,) + <$> case bop of + E.Equal -> pure $ x .-. y + E.Less -> pure $ y .-. (x .+. int2SoP 1) + E.Leq -> pure $ y .-. x + E.Greater -> pure $ x .-. (y .+. int2SoP 1) + E.Geq -> pure $ x .-. y + +-- +-- {-- +---- This is a more refined treatment, but probably +---- an overkill (harmful if you get the type wrong) +-- fromSym unknowns sym +-- | Nothing <- M.lookup sym (dir unknowns) = +-- LeafExp sym $ IntType Integer +-- | Just pe1 <- M.lookup sym (dir unknowns), +-- IntType Integer <- PE.primExpType pe1 = +-- pe1 +-- fromSym unknowns sym = +-- error ("Type error in fromSym: type of " ++ +-- show sym ++ " is not Integer") +----} diff --git a/src/Futhark/SoP/Expression.hs b/src/Futhark/SoP/Expression.hs new file mode 100644 index 0000000000..c2d93d7d8a --- /dev/null +++ b/src/Futhark/SoP/Expression.hs @@ -0,0 +1,106 @@ +{-# LANGUAGE DataKinds #-} + +module Futhark.SoP.Expression + ( Expression (..), + processExps, + ) +where + +import Data.Set (Set) +import Data.Set qualified as S +import Futhark.Analysis.PrimExp +import Futhark.SoP.Util +import Language.Futhark.Prop + +class Expression e where + -- -- | Is this 'PrimType' not integral? + -- notIntType :: PrimType -> Bool + + -- | Is this expression @mod@? + moduloIsh :: e -> Maybe (e, e) + + -- -- | Is this 'PrimExp' @<@? + -- lthishType :: CmpOp -> Maybe IntType + + -- -- | Is this 'PrimExp' @<=@? + -- leqishType :: CmpOp -> Maybe IntType + + -- | Rewrite a mod expression into division. + divInsteadOfMod :: e -> e + + -- | Algebraically manipulates an 'e' into a set of equality + -- and inequality constraints. + processExp :: e -> (Set (e == 0), Set (e >= 0)) + +processExps :: (Ord e, Expression e, Foldable t) => t e -> (Set (e == 0), Set (e >= 0)) +processExps = foldMap processExp + +instance Expression Exp + +instance Ord u => Expression (PrimExp u) where + moduloIsh :: PrimExp u -> Maybe (PrimExp u, PrimExp u) + moduloIsh (BinOpExp (SMod _ _) pe1 pe2) = Just (pe1, pe2) + moduloIsh (BinOpExp (UMod _ _) pe1 pe2) = Just (pe1, pe2) + moduloIsh _ = Nothing + + processExp :: PrimExp u -> (Set (PrimExp u == 0), Set (PrimExp u >= 0)) + processExp (CmpOpExp (CmpEq ptp) x y) + -- x = y => x - y = 0 + | IntType {} <- ptp = + (S.singleton (x ~-~ y), mempty) + processExp (CmpOpExp lessop x y) + -- x < y => x + 1 <= y => y >= x + 1 => y - (x+1) >= 0 + | Just itp <- lthishType lessop = + let pe = y ~-~ (x ~+~ ValueExp (IntValue $ intValue itp (1 :: Integer))) + in (mempty, S.singleton pe) + -- x <= y => y >= x => y - x >= 0 + | Just _ <- leqishType lessop = + (mempty, S.singleton $ y ~-~ x) + where + -- Is this 'PrimExp' @<@? + lthishType :: CmpOp -> Maybe IntType + lthishType (CmpSlt itp) = Just itp + lthishType (CmpUlt itp) = Just itp + lthishType _ = Nothing + + -- Is this 'PrimExp' @<=@? + leqishType :: CmpOp -> Maybe IntType + leqishType (CmpUle itp) = Just itp + leqishType (CmpSle itp) = Just itp + leqishType _ = Nothing + processExp (BinOpExp LogAnd x y) = + processExps [x, y] + processExp (CmpOpExp CmpEq {} pe1 pe2) = + case (pe1, pe2) of + -- (x && y) == True => x && y + (BinOpExp LogAnd _ _, ValueExp (BoolValue True)) -> + processExp pe1 + -- True == (x && y) => x && y + (ValueExp (BoolValue True), BinOpExp LogAnd _ _) -> + processExp pe2 + -- (x || y) == False => !x && !y + (BinOpExp LogOr x y, ValueExp (BoolValue False)) -> + processExps [UnOpExp Not x, UnOpExp Not y] + -- False == (x || y) => !x && !y + (ValueExp (BoolValue False), BinOpExp LogOr x y) -> + processExps [UnOpExp Not x, UnOpExp Not y] + _ -> mempty + processExp (UnOpExp Not pe) = + case pe of + -- !(!x) => x + UnOpExp Not x -> + processExp x + -- !(x < y) => y <= x + CmpOpExp (CmpSlt itp) x y -> + processExp $ CmpOpExp (CmpSle itp) y x + -- !(x <= y) => y < x + CmpOpExp (CmpSle itp) x y -> + processExp $ CmpOpExp (CmpSlt itp) y x + -- !(x < y) => y <= x + CmpOpExp (CmpUlt itp) x y -> + processExp $ CmpOpExp (CmpUle itp) y x + -- !(x <= y) => y < x + CmpOpExp (CmpUle itp) x y -> + processExp $ CmpOpExp (CmpUlt itp) y x + _ -> mempty + processExp _ = mempty diff --git a/src/Futhark/SoP/FourierMotzkin.hs b/src/Futhark/SoP/FourierMotzkin.hs index 7d36e10775..1b13977390 100644 --- a/src/Futhark/SoP/FourierMotzkin.hs +++ b/src/Futhark/SoP/FourierMotzkin.hs @@ -22,10 +22,10 @@ module Futhark.SoP.FourierMotzkin fmSolveLEq0, fmSolveGTh0, fmSolveGEq0, - fmSolveLTh0_, - fmSolveGTh0_, - fmSolveGEq0_, - fmSolveLEq0_, + -- fmSolveLTh0_, + -- fmSolveGTh0_, + -- fmSolveGEq0_, + -- fmSolveLEq0_, ($<$), ($<=$), ($>$), @@ -51,19 +51,19 @@ import Futhark.Util.Pretty -- | Solves the inequation `sop < 0` by reducing it to -- `sop + 1 <= 0`, where `sop` denotes an expression -- in sum-of-product form. -fmSolveLTh0 :: MonadSoP u m => SoP u -> m Bool +fmSolveLTh0 :: MonadSoP u e m => SoP u -> m Bool fmSolveLTh0 = fmSolveLEq0 . (.+. int2SoP 1) -- | Solves the inequation `sop > 0` by reducing it to -- `(-1)*sop < 0`, where `sop` denotes an expression -- in sum-of-product form. -fmSolveGTh0 :: MonadSoP u m => SoP u -> m Bool +fmSolveGTh0 :: MonadSoP u e m => SoP u -> m Bool fmSolveGTh0 = fmSolveLTh0 . negSoP -- | Solves the inequation `sop >= 0` by reducing it to -- `(-1)*sop <= 0`, where `sop` denotes an expression -- in sum-of-product form. -fmSolveGEq0 :: MonadSoP u m => SoP u -> m Bool +fmSolveGEq0 :: MonadSoP u e m => SoP u -> m Bool fmSolveGEq0 = fmSolveLEq0 . negSoP -- | Assuming `sop` an expression in sum-of-products (SoP) form, @@ -81,7 +81,7 @@ fmSolveGEq0 = fmSolveLEq0 . negSoP -- (i) `True` if the inequality is found to always holds; -- (ii) `False` if there is an `i` for which the inequality does -- not hold or if the answer is unknown. -fmSolveLEq0 :: MonadSoP u m => SoP u -> m Bool +fmSolveLEq0 :: MonadSoP u e m => SoP u -> m Bool fmSolveLEq0 sop | Just v <- justConstant sop = pure (v <= 0) | not (null syms) = do @@ -115,29 +115,29 @@ fmSolveLEq0 sop where syms = S.toList $ free sop -($<$) :: MonadSoP u m => SoP u -> SoP u -> m Bool +($<$) :: MonadSoP u e m => SoP u -> SoP u -> m Bool x $<$ y = fmSolveLTh0 $ x .-. y -($<=$) :: MonadSoP u m => SoP u -> SoP u -> m Bool +($<=$) :: MonadSoP u e m => SoP u -> SoP u -> m Bool x $<=$ y = fmSolveLEq0 $ x .-. y -($>$) :: MonadSoP u m => SoP u -> SoP u -> m Bool +($>$) :: MonadSoP u e m => SoP u -> SoP u -> m Bool x $>$ y = fmSolveGTh0 $ x .-. y -($>=$) :: MonadSoP u m => SoP u -> SoP u -> m Bool +($>=$) :: MonadSoP u e m => SoP u -> SoP u -> m Bool x $>=$ y = fmSolveGEq0 $ x .-. y -($==$) :: MonadSoP u m => SoP u -> SoP u -> m Bool +($==$) :: MonadSoP u e m => SoP u -> SoP u -> m Bool x $==$ y = (&&) <$> (x $<=$ y) <*> (x $>=$ y) -fmSolveLTh0_ :: (Ord u, Nameable u, Show u, Pretty u) => RangeEnv u -> SoP u -> Bool -fmSolveLTh0_ rs = evalSoPM mempty {ranges = rs} . fmSolveLTh0 - -fmSolveGTh0_ :: (Ord u, Nameable u, Show u, Pretty u) => RangeEnv u -> SoP u -> Bool -fmSolveGTh0_ rs = evalSoPM mempty {ranges = rs} . fmSolveGTh0 - -fmSolveGEq0_ :: (Ord u, Nameable u, Show u, Pretty u) => RangeEnv u -> SoP u -> Bool -fmSolveGEq0_ rs = evalSoPM mempty {ranges = rs} . fmSolveGEq0 - -fmSolveLEq0_ :: (Ord u, Nameable u, Show u, Pretty u) => RangeEnv u -> SoP u -> Bool -fmSolveLEq0_ rs = evalSoPM mempty {ranges = rs} . fmSolveLEq0 +-- fmSolveLTh0_ :: (Ord u, Nameable u, Show u, Pretty u) => RangeEnv u -> SoP u -> Bool +-- fmSolveLTh0_ rs = evalSoPM mempty {ranges = rs} . fmSolveLTh0 +-- +-- fmSolveGTh0_ :: (Ord u, Nameable u, Show u, Pretty u) => RangeEnv u -> SoP u -> Bool +-- fmSolveGTh0_ rs = evalSoPM mempty {ranges = rs} . fmSolveGTh0 +-- +-- fmSolveGEq0_ :: (Ord u, Nameable u, Show u, Pretty u) => RangeEnv u -> SoP u -> Bool +-- fmSolveGEq0_ rs = evalSoPM mempty {ranges = rs} . fmSolveGEq0 +-- +-- fmSolveLEq0_ :: (Ord u, Nameable u, Show u, Pretty u) => RangeEnv u -> SoP u -> Bool +-- fmSolveLEq0_ rs = evalSoPM mempty {ranges = rs} . fmSolveLEq0 diff --git a/src/Futhark/SoP/Monad.hs b/src/Futhark/SoP/Monad.hs index a8369fdf51..dfd7f19e06 100644 --- a/src/Futhark/SoP/Monad.hs +++ b/src/Futhark/SoP/Monad.hs @@ -1,6 +1,4 @@ -{-# LANGUAGE DataKinds #-} {-# LANGUAGE FunctionalDependencies #-} -{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -- | The Algebraic Environment, which is in principle @@ -13,8 +11,6 @@ module Futhark.SoP.Monad EquivEnv, UntransEnv (..), AlgEnv (..), - type (>=), - type (==), addUntrans, transClosInRanges, lookupUntransPE, @@ -45,9 +41,9 @@ import Data.Set qualified as S import Futhark.Analysis.PrimExp import Futhark.FreshNames import Futhark.MonadFreshNames +import Futhark.SoP.Expression import Futhark.SoP.SoP import Futhark.Util.Pretty -import GHC.TypeLits (Natural) import Language.Futhark.Syntax hiding (Range) -------------------------------------------------------------------------------- @@ -77,40 +73,44 @@ mkNameM = modifyNameSource mkName class ( Ord u, + Ord e, Nameable u, Show u, -- To be removed Pretty u, -- To be removed - MonadFreshNames m + MonadFreshNames m, + Substitute u e e, + Expression e ) => - MonadSoP u m - | m -> u + MonadSoP u e m + | m -> u, + m -> e where - getUntrans :: m (UntransEnv u) + getUntrans :: m (UntransEnv u e) getRanges :: m (RangeEnv u) getEquivs :: m (EquivEnv u) - modifyEnv :: (AlgEnv u -> AlgEnv u) -> m () + modifyEnv :: (AlgEnv u e -> AlgEnv u e) -> m () -- | The algebraic monad; consists of a an algebraic -- environment along with a fresh variable source. -newtype SoPMT u m a = SoPMT (StateT (AlgEnv u) m a) +newtype SoPMT u e m a = SoPMT (StateT (AlgEnv u e) m a) deriving ( Functor, Applicative, Monad ) -instance MonadTrans (SoPMT u) where +instance MonadTrans (SoPMT u e) where lift = SoPMT . lift -instance MonadFreshNames m => MonadFreshNames (SoPMT u m) where +instance MonadFreshNames m => MonadFreshNames (SoPMT u e m) where getNameSource = lift getNameSource putNameSource = lift . putNameSource -instance (MonadFreshNames m) => MonadFreshNames (StateT (AlgEnv u) m) where +instance (MonadFreshNames m) => MonadFreshNames (StateT (AlgEnv u e) m) where getNameSource = lift getNameSource putNameSource = lift . putNameSource -instance MonadReader r m => MonadReader r (SoPMT u m) where +instance MonadReader r m => MonadReader r (SoPMT u e m) where ask = SoPMT $ lift ask local f (SoPMT m) = SoPMT $ do @@ -119,44 +119,47 @@ instance MonadReader r m => MonadReader r (SoPMT u m) where put env' pure a -instance MonadState s m => MonadState s (SoPMT u m) where +instance MonadState s m => MonadState s (SoPMT u e m) where get = SoPMT $ lift get put = SoPMT . lift . put -type SoPM u = SoPMT u (State VNameSource) +type SoPM u e = SoPMT u e (State VNameSource) -runSoPMT :: MonadFreshNames m => AlgEnv u -> SoPMT u m a -> m (a, AlgEnv u) +runSoPMT :: MonadFreshNames m => AlgEnv u e -> SoPMT u e m a -> m (a, AlgEnv u e) runSoPMT env (SoPMT sm) = runStateT sm env -runSoPMT_ :: (Ord u, MonadFreshNames m) => SoPMT u m a -> m (a, AlgEnv u) +runSoPMT_ :: (Ord u, Ord e, MonadFreshNames m) => SoPMT u e m a -> m (a, AlgEnv u e) runSoPMT_ = runSoPMT mempty -runSoPM :: Ord u => AlgEnv u -> SoPM u a -> (a, AlgEnv u) +runSoPM :: (Ord u, Ord e) => AlgEnv u e -> SoPM u e a -> (a, AlgEnv u e) runSoPM env = flip evalState mempty . runSoPMT env -runSoPM_ :: Ord u => SoPM u a -> (a, AlgEnv u) +runSoPM_ :: (Ord u, Ord e) => SoPM u e a -> (a, AlgEnv u e) runSoPM_ = runSoPM mempty -evalSoPMT :: MonadFreshNames m => AlgEnv u -> SoPMT u m a -> m a +evalSoPMT :: MonadFreshNames m => AlgEnv u e -> SoPMT u e m a -> m a evalSoPMT env m = fst <$> runSoPMT env m -evalSoPMT_ :: (Ord u, MonadFreshNames m) => SoPMT u m a -> m a +evalSoPMT_ :: (Ord u, Ord e, MonadFreshNames m) => SoPMT u e m a -> m a evalSoPMT_ = evalSoPMT mempty -evalSoPM :: Ord u => AlgEnv u -> SoPM u a -> a +evalSoPM :: (Ord u, Ord e) => AlgEnv u e -> SoPM u e a -> a evalSoPM env = fst . runSoPM env -evalSoPM_ :: Ord u => SoPM u a -> a +evalSoPM_ :: (Ord u, Ord e) => SoPM u e a -> a evalSoPM_ = evalSoPM mempty instance ( Ord u, + Ord e, Nameable u, Show u, Pretty u, - MonadFreshNames m + MonadFreshNames m, + Substitute u e e, + Expression e ) => - MonadSoP u (SoPMT u m) + MonadSoP u e (SoPMT u e m) where getUntrans = SoPMT $ gets untrans @@ -167,7 +170,7 @@ instance modifyEnv f = SoPMT $ modify f -- \| Insert a symbol equal to an untranslatable 'PrimExp'. -addUntrans :: MonadSoP u m => u -> PrimExp u -> m () +addUntrans :: MonadSoP u e m => u -> e -> m () addUntrans sym pe = modifyEnv $ \env -> env @@ -179,12 +182,12 @@ addUntrans sym pe = } -- \| Look-up the sum-of-products representation of a symbol. -lookupSoP :: MonadSoP u m => u -> m (Maybe (SoP u)) +lookupSoP :: MonadSoP u e m => u -> m (Maybe (SoP u)) lookupSoP x = (M.!? x) <$> getEquivs -- \| Look-up the symbol for a 'PrimExp'. If no symbol is bound -- to the expression, bind a new one. -lookupUntransPE :: MonadSoP u m => PrimExp u -> m u +lookupUntransPE :: MonadSoP u e m => e -> m u lookupUntransPE pe = do inv_map <- inv <$> getUntrans case inv_map M.!? pe of @@ -195,12 +198,12 @@ lookupUntransPE pe = do Just x -> pure x -- \| Look-up the untranslatable 'PrimExp' bound to the given symbol. -lookupUntransSym :: MonadSoP u m => u -> m (Maybe (PrimExp u)) +lookupUntransSym :: MonadSoP u e m => u -> m (Maybe e) lookupUntransSym sym = ((M.!? sym) . dir) <$> getUntrans -- \| Look-up the range of a symbol. If no such range exists, -- return the empty range (and add it to the environment). -lookupRange :: MonadSoP u m => u -> m (Range u) +lookupRange :: MonadSoP u e m => u -> m (Range u) lookupRange sym = do mr <- (M.!? sym) <$> getRanges case mr of @@ -214,7 +217,7 @@ lookupRange sym = do -- \| Add range information for a symbol; augments the existing -- range. -addRange :: MonadSoP u m => u -> Range u -> m () +addRange :: MonadSoP u e m => u -> Range u -> m () addRange sym r = modifyEnv $ \env -> env {ranges = M.insertWith (<>) sym r (ranges env)} @@ -223,28 +226,22 @@ addRange sym r = -- Environment -------------------------------------------------------------------------------- --- | A type label to indicate @a >= 0@. -type a >= (b :: Natural) = a - --- | A type label to indicate @a = 0@. -type a == (b :: Natural) = a - -- | The environment of untranslatable 'PrimeExp's. It maps both -- ways: -- -- 1. A fresh symbol is generated and mapped to the -- corresponding 'PrimeExp' @pe@ in 'dir'. -- 2. The target @pe@ is mapped backed to the corresponding symbol in 'inv'. -data UntransEnv u = Unknowns - { dir :: Map u (PrimExp u), - inv :: Map (PrimExp u) u +data UntransEnv u e = Unknowns + { dir :: Map u e, + inv :: Map e u } deriving (Eq, Show, Ord) -instance Ord u => Semigroup (UntransEnv u) where +instance (Ord u, Ord e) => Semigroup (UntransEnv u e) where Unknowns d1 i1 <> Unknowns d2 i2 = Unknowns (d1 <> d2) (i1 <> i2) -instance Ord u => Monoid (UntransEnv u) where +instance (Ord u, Ord e) => Monoid (UntransEnv u e) where mempty = Unknowns mempty mempty -- | The equivalence environment binds a variable name to @@ -258,9 +255,9 @@ instance Pretty u => Pretty (RangeEnv u) where pretty = pretty . M.toList -- | The main algebraic environment. -data AlgEnv u = AlgEnv +data AlgEnv u e = AlgEnv { -- | Binds untranslatable PrimExps to fresh symbols. - untrans :: UntransEnv u, + untrans :: UntransEnv u e, -- | Binds symbols to their sum-of-product representation.. equivs :: EquivEnv u, -- | Binds symbols to ranges (in sum-of-product form). @@ -268,11 +265,11 @@ data AlgEnv u = AlgEnv } deriving (Ord, Show, Eq) -instance Ord u => Semigroup (AlgEnv u) where +instance (Ord u, Ord e) => Semigroup (AlgEnv u e) where AlgEnv u1 s1 r1 <> AlgEnv u2 s2 r2 = AlgEnv (u1 <> u2) (s1 <> s2) (r1 <> r2) -instance Ord u => Monoid (AlgEnv u) where +instance (Ord u, Ord e) => Monoid (AlgEnv u e) where mempty = AlgEnv mempty mempty mempty transClosInRanges :: (Ord u) => RangeEnv u -> Set u -> Set u diff --git a/src/Futhark/SoP/PrimExp.hs b/src/Futhark/SoP/PrimExp.hs deleted file mode 100644 index c1bcd82917..0000000000 --- a/src/Futhark/SoP/PrimExp.hs +++ /dev/null @@ -1,116 +0,0 @@ -{-# LANGUAGE DataKinds #-} - --- | Basic 'PrimExp' functions. -module Futhark.SoP.PrimExp - ( notIntType, - divideIsh, - moduloIsh, - divInsteadOfMod, - processPE, - processPEs, - ) -where - -import Data.Set (Set) -import Data.Set qualified as S -import Futhark.Analysis.PrimExp -import Futhark.SoP.Monad - --- | Is this 'PrimType' not integral? -notIntType :: PrimType -> Bool -notIntType (IntType _) = False -notIntType _ = True - --- | Is this 'BinOp' division? -divideIsh :: BinOp -> Bool -divideIsh (UDiv _ _) = True -divideIsh (UDivUp _ _) = True -divideIsh (SDiv _ _) = True -divideIsh (SDivUp _ _) = True -divideIsh (FDiv _) = True -divideIsh _ = False - --- | Is this 'PrimExp' @mod@? -moduloIsh :: PrimExp u -> Maybe (PrimExp u, PrimExp u) -moduloIsh (BinOpExp (SMod _ _) pe1 pe2) = Just (pe1, pe2) -moduloIsh (BinOpExp (UMod _ _) pe1 pe2) = Just (pe1, pe2) -moduloIsh _ = Nothing - --- | Is this 'PrimExp' @<@? -lthishType :: CmpOp -> Maybe IntType -lthishType (CmpSlt itp) = Just itp -lthishType (CmpUlt itp) = Just itp -lthishType _ = Nothing - --- | Is this 'PrimExp' @<=@? -leqishType :: CmpOp -> Maybe IntType -leqishType (CmpUle itp) = Just itp -leqishType (CmpSle itp) = Just itp -leqishType _ = Nothing - --- | Rewrite a mod expression into division. -divInsteadOfMod :: Show u => PrimExp u -> PrimExp u -divInsteadOfMod (BinOpExp (UMod itp saf) pe1 pe2) = - BinOpExp (UDiv itp saf) pe1 pe2 -divInsteadOfMod (BinOpExp (SMod itp saf) pe1 pe2) = - BinOpExp (SDiv itp saf) pe1 pe2 -divInsteadOfMod pe = error ("Impossible case reached in divInsteadOfMod!" ++ show pe) - --- | Algebraically manipualtes a 'PrimExp' into a set of equality --- and inequality constraints. -processPE :: (Ord u, Nameable u) => PrimExp u -> (Set (PrimExp u == 0), Set (PrimExp u >= 0)) -processPE (CmpOpExp (CmpEq ptp) x y) - -- x = y => x - y = 0 - | IntType {} <- ptp = - (S.singleton (x ~-~ y), mempty) -processPE (CmpOpExp lessop x y) - -- x < y => x + 1 <= y => y >= x + 1 => y - (x+1) >= 0 - | Just itp <- lthishType lessop = - let pe = y ~-~ (x ~+~ ValueExp (IntValue $ intValue itp (1 :: Integer))) - in (mempty, S.singleton pe) - -- x <= y => y >= x => y - x >= 0 - | Just _ <- leqishType lessop = - (mempty, S.singleton $ y ~-~ x) -processPE (BinOpExp LogAnd x y) = - processPEs [x, y] -processPE (CmpOpExp CmpEq {} pe1 pe2) = - case (pe1, pe2) of - -- (x && y) == True => x && y - (BinOpExp LogAnd _ _, ValueExp (BoolValue True)) -> - processPE pe1 - -- True == (x && y) => x && y - (ValueExp (BoolValue True), BinOpExp LogAnd _ _) -> - processPE pe2 - -- (x || y) == False => !x && !y - (BinOpExp LogOr x y, ValueExp (BoolValue False)) -> - processPEs [UnOpExp Not x, UnOpExp Not y] - -- False == (x || y) => !x && !y - (ValueExp (BoolValue False), BinOpExp LogOr x y) -> - processPEs [UnOpExp Not x, UnOpExp Not y] - _ -> mempty -processPE (UnOpExp Not pe) = - case pe of - -- !(!x) => x - UnOpExp Not x -> - processPE x - -- !(x < y) => y <= x - CmpOpExp (CmpSlt itp) x y -> - processPE $ CmpOpExp (CmpSle itp) y x - -- !(x <= y) => y < x - CmpOpExp (CmpSle itp) x y -> - processPE $ CmpOpExp (CmpSlt itp) y x - -- !(x < y) => y <= x - CmpOpExp (CmpUlt itp) x y -> - processPE $ CmpOpExp (CmpUle itp) y x - -- !(x <= y) => y < x - CmpOpExp (CmpUle itp) x y -> - processPE $ CmpOpExp (CmpUlt itp) y x - _ -> mempty -processPE _ = mempty - --- | Process multiple `PrimExp`s at once. -processPEs :: - (Ord u, Nameable u, Foldable t) => - t (PrimExp u) -> - (Set (PrimExp u == 0), Set (PrimExp u >= 0)) -processPEs = foldMap processPE diff --git a/src/Futhark/SoP/Refine.hs b/src/Futhark/SoP/Refine.hs index 0122e84b3e..bffc620931 100644 --- a/src/Futhark/SoP/Refine.hs +++ b/src/Futhark/SoP/Refine.hs @@ -8,23 +8,21 @@ module Futhark.SoP.Refine where import Data.Set (Set) +import Data.Set qualified as S import Futhark.Analysis.PrimExp +import Futhark.SoP.Convert +import Futhark.SoP.Expression import Futhark.SoP.Monad -import Futhark.SoP.PrimExp import Futhark.SoP.RefineEquivs import Futhark.SoP.RefineRanges -import Futhark.SoP.ToFromSoP -refineAlgEnv :: - MonadSoP u m => - Set (PrimExp u) -> - m () +refineAlgEnv :: (FromSoP u e, ToSoP u e, MonadSoP u e m) => Set e -> m () refineAlgEnv candidates = do -- Split candidates into equality and inequality sets. - let (eqZs, ineqZs) = processPEs candidates + let (eqZs, ineqZs) = processExps candidates -- Refine the environment with the equality set. extra_ineqZs <- addEqZeroPEs eqZs -- Refine the environment with the extended inequality set. - addIneqZeroPEs $ ineqZs <> fromSoP extra_ineqZs + addIneqZeroPEs $ ineqZs <> S.map fromSoP extra_ineqZs diff --git a/src/Futhark/SoP/RefineEquivs.hs b/src/Futhark/SoP/RefineEquivs.hs index 0ccceee6c3..37feb00359 100644 --- a/src/Futhark/SoP/RefineEquivs.hs +++ b/src/Futhark/SoP/RefineEquivs.hs @@ -13,20 +13,21 @@ import Data.MultiSet qualified as MS import Data.Set (Set) import Data.Set qualified as S import Futhark.Analysis.PrimExp -import Futhark.Analysis.PrimExp.Convert +import Futhark.SoP.Convert +import Futhark.SoP.Expression import Futhark.SoP.FourierMotzkin import Futhark.SoP.Monad -import Futhark.SoP.PrimExp +-- import Futhark.SoP.PrimExp import Futhark.SoP.SoP -import Futhark.SoP.ToFromSoP +import Futhark.SoP.Util -- | Refine the environment with a set of 'PrimExp's with the assertion that @pe = 0@ -- for each 'PrimExp' in the set. -addEqZeroPEs :: MonadSoP u m => Set (PrimExp u == 0) -> m (Set (SoP u >= 0)) +addEqZeroPEs :: forall u e m. (ToSoP u e, FromSoP u e, MonadSoP u e m) => Set (e == 0) -> m (Set (SoP u >= 0)) addEqZeroPEs pes = do -- Substitute already known equivalences in the equality set. - equivs_pes <- fromSoP <$> getEquivs - let pes' = S.map (substituteInPrimExp equivs_pes) pes + equivs_pes <- (fmap . fmap) (fromSoP :: SoP u -> e) getEquivs + let pes' = S.map (substitute equivs_pes) pes -- Make equivalence candidates along with any extra constraints. (extra_inEqZs :: Set (SoP u >= 0), equiv_cands) <- mconcat <$> mapM addEquiv2CandSet (S.toList pes') @@ -62,7 +63,7 @@ instance Ord u => Substitute u (SoP u) (EquivCand u) where -- ToDo: try to give common factor first, e.g., -- nx - nbq - n = 0 => n*(x-bq-1) = 0 => x = bq+1, -- if we can prove that n != 0 -mkEquivCands :: MonadSoP u m => SoP u -> m (Set (EquivCand u)) +mkEquivCands :: MonadSoP u e m => SoP u -> m (Set (EquivCand u)) mkEquivCands sop = M.foldrWithKey mkEquivCand (pure mempty) $ getTerms sop where mkEquivCand (Term term) v mcands @@ -99,14 +100,14 @@ mkEquivCands sop = M.foldrWithKey mkEquivCand (pure mempty) $ getTerms sop -- * Possibly add the constraints @0 <= sop <= pe2 - 1@. -- -- 2: TODO: try to give common factors and get simpler. -refineEquivCand :: MonadSoP u m => EquivCand u -> m (Set (SoP u >= 0), EquivCand u) +refineEquivCand :: forall u e m. (ToSoP u e, MonadSoP u e m) => EquivCand u -> m (Set (SoP u >= 0), EquivCand u) refineEquivCand cand@(EquivCand sym sop) = do mpe <- lookupUntransSym sym case mpe of Just pe | Just (pe1, pe2) <- moduloIsh pe -> do - (f1, sop1) <- toNumSoP pe1 - (f2, sop2) <- toNumSoP pe2 + (f1, sop1) <- toSoPNum pe1 + (f2, sop2) <- toSoPNum pe2 is_pos <- fmSolveGEq0 sop2 case (f1, f2, justSym sop1, is_pos) of (1, 1, Just sym1, True) -> do @@ -132,17 +133,17 @@ refineEquivCand cand@(EquivCand sym sop) = do -- creation/refinement of the mapping. -- * @cands@: set of equivalence candidates. addEquiv2CandSet :: - MonadSoP u m => - PrimExp u == 0 -> + (ToSoP u e, MonadSoP u e m) => + e == 0 -> m (Set (SoP u >= 0), Set (EquivCand u)) addEquiv2CandSet pe = do - (_, sop) <- toNumSoP pe + (_, sop) <- toSoPNum pe cands <- mkEquivCands sop (ineqss, cands') <- mapAndUnzipM refineEquivCand $ S.toList cands pure (mconcat ineqss, S.fromList cands') -- | Add legal equivalence candidates to the environment. -addLegalCands :: MonadSoP u m => Set (EquivCand u) -> m () +addLegalCands :: MonadSoP u e m => Set (EquivCand u) -> m () addLegalCands cand_set | S.null cand_set = pure () addLegalCands cand_set = do diff --git a/src/Futhark/SoP/RefineRanges.hs b/src/Futhark/SoP/RefineRanges.hs index 8f9de24cbd..a563d9e14f 100644 --- a/src/Futhark/SoP/RefineRanges.hs +++ b/src/Futhark/SoP/RefineRanges.hs @@ -13,22 +13,22 @@ import Data.Set (Set) import Data.Set qualified as S import Futhark.Analysis.PrimExp import Futhark.Analysis.PrimExp.Convert +import Futhark.SoP.Convert import Futhark.SoP.FourierMotzkin import Futhark.SoP.Monad import Futhark.SoP.SoP -import Futhark.SoP.ToFromSoP import Futhark.SoP.Util -- | Refine the environment with a set of 'PrimExp's with the assertion that @pe >= 0@ -- for each 'PrimExp' in the set. -addIneqZeroPEs :: MonadSoP u m => Set (PrimExp u >= 0) -> m () +addIneqZeroPEs :: forall u e m. (ToSoP u e, FromSoP u e, MonadSoP u e m) => Set (e >= 0) -> m () addIneqZeroPEs pes = do -- Substitute equivalence env. - pes' <- (flip S.map pes . substituteInPrimExp . fmap fromSoP) <$> getEquivs + pes' <- (flip S.map pes . substitute . fmap (fromSoP :: SoP u -> e)) <$> getEquivs ineq_cands <- mconcat - <$> mapM (fmap (mkRangeCands . snd) . toNumSoP) (S.toList pes') + <$> mapM (fmap (mkRangeCands . snd) . toSoPNum) (S.toList pes') addRangeCands ineq_cands @@ -88,7 +88,7 @@ mkRangeCands sop = M.foldrWithKey mkRangeCand mempty $ getTerms sop -- these are @lbs' <= -j_z * sop@ (@j_z * sop <= ubs'@) where @lbs'@ -- (@ubs'@) are the refined bounds from the previous step. refineRangeInEnv :: - MonadSoP u m => + MonadSoP u e m => RangeCand u -> m (Set (RangeCand u)) refineRangeInEnv (RangeCand j sym sop) = do @@ -133,7 +133,7 @@ data CandRank | SymNotBound deriving (Ord, Eq) -addRangeCands :: MonadSoP u m => Set (RangeCand u) -> m () +addRangeCands :: MonadSoP u e m => Set (RangeCand u) -> m () addRangeCands cand_set | S.null cand_set = pure () addRangeCands cand_set = do diff --git a/src/Futhark/SoP/SoP.hs b/src/Futhark/SoP/SoP.hs index 6d8e42627f..a33674c239 100644 --- a/src/Futhark/SoP/SoP.hs +++ b/src/Futhark/SoP/SoP.hs @@ -45,6 +45,8 @@ import Data.Set qualified as S import Futhark.Analysis.PrimExp.Convert import Futhark.SoP.Util import Futhark.Util.Pretty +import Language.Futhark.Core +import Language.Futhark.Prop -- | A 'Term' is a product of symbols. newtype Term u = Term {getTerm :: MultiSet u} @@ -386,3 +388,6 @@ instance Ord u => Substitute u (SoP u) (Range u) where instance Ord u => Substitute u (PrimExp u) (PrimExp u) where substitute = substituteInPrimExp + +instance Substitute VName Exp Exp where + substitute = undefined diff --git a/src/Futhark/SoP/ToFromSoP.hs b/src/Futhark/SoP/ToFromSoP.hs deleted file mode 100644 index 8ae6296dcc..0000000000 --- a/src/Futhark/SoP/ToFromSoP.hs +++ /dev/null @@ -1,205 +0,0 @@ -{-# LANGUAGE AllowAmbiguousTypes #-} -{-# LANGUAGE DataKinds #-} - --- | Translating to-and-from PrimExp to the sum-of-product representation. -module Futhark.SoP.ToFromSoP - ( toNumSoP, - toNumSoPCmp, - fromNumSoP, - FromSoP (..), - ToSoPM (..), - ) -where - -import Control.Monad.State -import Data.Set (Set) -import Data.Set qualified as S -import Futhark.Analysis.PrimExp -import Futhark.SoP.Monad -import Futhark.SoP.PrimExp -import Futhark.SoP.SoP -import Futhark.Util.Pretty - --- | Conversion from (structures of) 'SoP's to other representations --- (e.g., 'PrimExp's). -class FromSoP a b where - fromSoP :: a -> b - -instance FromSoP (SoP u) (SoP u) where - fromSoP = id - -instance Ord u => FromSoP (SoP u) (PrimExp u) where - fromSoP = fromNumSoP - -instance (Functor t, FromSoP a b) => FromSoP (t a) (t b) where - fromSoP = fmap fromSoP - -instance {-# OVERLAPS #-} (Ord b, FromSoP a b) => FromSoP (Set a) (Set b) where - fromSoP = S.fromList . fromSoP . S.toList - --- | Conversion from (strctures of) some expressions to --- 'SoP's. Monadic because it may involve look-ups in the --- untranslatable expression environment. -class Monad m => ToSoPM m a b where - toSoP :: a -> m b - -instance (Nameable u, Ord u, Show u, Pretty u) => ToSoPM (SoPM u) (PrimExp u) (Integer, SoP u) where - toSoP = toNumSoP - -instance (Traversable t, ToSoPM m a b) => ToSoPM m (t a) (t b) where - toSoP = traverse toSoP - -instance {-# OVERLAPS #-} (Ord b, ToSoPM m a b) => ToSoPM m (Set a) (Set b) where - toSoP = fmap S.fromList . toSoP . S.toList - --- | Translates 'PrimExp's to a 'SoP' representation, scaled by the --- returned integer. --- --- TODO: please extend to return also an integral --- quotient, e.g., in order to support, e.g., @i <= (n+1)/16 + 3@. -toNumSoP :: MonadSoP u m => PrimExp u -> m (Integer, SoP u) -toNumSoP primExp = do - (f, sop) <- toNumSoP' 1 primExp - pure (abs f, signum f `scaleSoP` sop) - where - toNumSoP' :: (Ord u, Nameable u, MonadSoP u m) => Integer -> PrimExp u -> m (Integer, SoP u) - toNumSoP' _ pe - | notIntType (primExpType pe) = - error "toNumSoP applied to a PrimExp whose prim type is not Integer" - toNumSoP' f (LeafExp vnm _) = - pure (f, sym2SoP vnm) - toNumSoP' f (ValueExp (IntValue iv)) = - pure (1, int2SoP $ getIntVal iv `div` f) - where - getIntVal :: IntValue -> Integer - getIntVal (Int8Value v) = fromIntegral v - getIntVal (Int16Value v) = fromIntegral v - getIntVal (Int32Value v) = fromIntegral v - getIntVal (Int64Value v) = fromIntegral v - toNumSoP' f (UnOpExp Complement {} x) = do - (f', x_sop) <- toNumSoP' f x - pure (f', negSoP x_sop) - toNumSoP' f (BinOpExp Add {} x y) = do - (x_f, x_sop) <- toNumSoP x - (y_f, y_sop) <- toNumSoP y - let l_c_m = lcm x_f y_f - (x_m, y_m) = (l_c_m `div` x_f, l_c_m `div` y_f) - x_sop' = mulSoPs (int2SoP x_m) x_sop - y_sop' = mulSoPs (int2SoP y_m) y_sop - pure (f * l_c_m, addSoPs x_sop' y_sop') - toNumSoP' f (BinOpExp Sub {} x y) = do - (x_f, x_sop) <- toNumSoP x - (y_f, y_sop) <- toNumSoP y - let l_c_m = lcm x_f y_f - (x_m, y_m) = (l_c_m `div` x_f, l_c_m `div` y_f) - x_sop' = mulSoPs (int2SoP x_m) x_sop - n_y_sop' = mulSoPs (int2SoP (-y_m)) y_sop - pure (f * l_c_m, addSoPs x_sop' n_y_sop') - toNumSoP' f pe@(BinOpExp Mul {} x y) = do - (x_f, x_sop) <- toNumSoP x - (y_f, y_sop) <- toNumSoP y - case (x_f, y_f) of - (1, 1) -> pure (f, mulSoPs x_sop y_sop) - _ -> do - x' <- lookupUntransPE pe - toNumSoP' f $ LeafExp x' $ primExpType pe - -- pe / 1 == pe - toNumSoP' f (BinOpExp divish pe q) - | divideIsh divish && oneIshExp q = - toNumSoP' f pe - -- evaluate `val_x / val_y` - toNumSoP' f (BinOpExp divish x y) - | divideIsh divish, - ValueExp v_x <- x, - ValueExp v_y <- y = do - let f' = v_x `vdiv` v_y - toNumSoP' f $ ValueExp f' - -- Trivial simplifications: - -- (y * v) / y = v and (u * y) / y = u - | divideIsh divish, - BinOpExp (Mul _ _) u v <- x, - (is_fst, is_snd) <- (u == y, v == y), - is_fst || is_snd = do - toNumSoP' f $ if is_fst then v else u - where - vdiv (IntValue (Int64Value x')) (IntValue (Int64Value y')) = - IntValue $ Int64Value (x' `div` y') - vdiv (IntValue (Int32Value x')) (IntValue (Int32Value y')) = - IntValue $ Int32Value (x' `div` y') - vdiv (IntValue (Int16Value x')) (IntValue (Int16Value y')) = - IntValue $ Int16Value (x' `div` y') - vdiv (IntValue (Int8Value x')) (IntValue (Int8Value y')) = - IntValue $ Int8Value (x' `div` y') - -- vdiv (FloatValue (Float32Value x)) (FloatValue (Float32Value y)) = - -- FloatValue $ Float32Value $ x / y - -- vdiv (FloatValue (Float64Value x)) (FloatValue (Float64Value y)) = - -- FloatValue $ Float64Value $ x / y - vdiv _ _ = error "In vdiv: illegal type for division!" - -- try heuristic for exact division - toNumSoP' f pe@(BinOpExp divish x y) - | divideIsh divish = do - (x_f, x_sop) <- toNumSoP x - (y_f, y_sop) <- toNumSoP y - case (x_f, y_f, divSoPs x_sop y_sop) of - (1, 1, Just res) -> pure (f, res) - _ -> do - x' <- lookupUntransPE pe - toNumSoP' f $ LeafExp x' $ primExpType pe - -- Anything that is not handled by specific cases of toNumSoP - -- is handled by this default procedure: - -- If the target `pe` is in the unknwon `env` - -- Then return thecorresponding binding - -- Else make a fresh symbol `v`, bind it in the environment - -- and return it. - toNumSoP' f pe = do - x <- lookupUntransPE pe - toNumSoP' f $ LeafExp x $ primExpType pe - --- | Translates from a 'SoP' representation to a 'PrimExp' representation. -fromNumSoP :: Ord u => SoP u -> PrimExp u -fromNumSoP sop = - foldr ((~+~) . fromTerm) (ValueExp $ IntValue $ intValue Int64 (0 :: Integer)) (sopToLists sop) - where - fromTerm (term, n) = - foldl (~*~) (ValueExp $ IntValue $ intValue Int64 n) $ - map fromSym term - --- | Translates a symbol into a 'PrimExp'. -fromSym :: u -> PrimExp u -fromSym sym = LeafExp sym $ IntType Int64 - --- | Translates a 'PrimExp' containing a (top-level) comparison --- operator into a 'SoP' representation such that @sop >= 0@. -toNumSoPCmp :: MonadSoP u m => PrimExp u -> m (Integer, SoP u >= 0) -toNumSoPCmp (CmpOpExp (CmpEq ptp) x y) - -- x = y => x - y = 0 - | IntType {} <- ptp = toNumSoP $ x ~-~ y -toNumSoPCmp (CmpOpExp lessop x y) - -- x < y => x + 1 <= y => y >= x + 1 => y - (x+1) >= 0 - | Just itp <- lthishType lessop = - toNumSoP $ y ~-~ (x ~+~ ValueExp (IntValue $ intValue itp (1 :: Integer))) - -- x <= y => y >= x => y - x >= 0 - | Just _ <- leqishType lessop = - toNumSoP $ y ~-~ x - where - lthishType (CmpSlt itp) = Just itp - lthishType (CmpUlt itp) = Just itp - lthishType _ = Nothing - leqishType (CmpUle itp) = Just itp - leqishType (CmpSle itp) = Just itp - leqishType _ = Nothing -toNumSoPCmp pe = toNumSoP pe - -{-- --- This is a more refined treatment, but probably --- an overkill (harmful if you get the type wrong) -fromSym unknowns sym - | Nothing <- M.lookup sym (dir unknowns) = - LeafExp sym $ IntType Integer - | Just pe1 <- M.lookup sym (dir unknowns), - IntType Integer <- primExpType pe1 = - pe1 -fromSym unknowns sym = - error ("Type error in fromSym: type of " ++ - show sym ++ " is not Integer") ---} diff --git a/src/Futhark/SoP/Util.hs b/src/Futhark/SoP/Util.hs index 8d4b9e6b0d..8118cb7911 100644 --- a/src/Futhark/SoP/Util.hs +++ b/src/Futhark/SoP/Util.hs @@ -1,9 +1,14 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeOperators #-} + module Futhark.SoP.Util ( anyM, allM, ifM, toMS, localS, + type (>=), + type (==), ) where @@ -11,6 +16,7 @@ import Control.Monad.State import Data.Foldable import Data.MultiSet (MultiSet) import Data.MultiSet qualified as MS +import GHC.TypeLits (Natural) ifM :: Monad m => m Bool -> m a -> m a -> m a ifM mb mt mf = do @@ -32,3 +38,9 @@ localS m = do a <- m put env pure a + +-- | A type label to indicate @a >= 0@. +type a >= (b :: Natural) = a + +-- | A type label to indicate @a = 0@. +type a == (b :: Natural) = a