diff --git a/src/Futhark/Optimise/Simplify/Engine.hs b/src/Futhark/Optimise/Simplify/Engine.hs index 65c703d498..a14a5adb43 100644 --- a/src/Futhark/Optimise/Simplify/Engine.hs +++ b/src/Futhark/Optimise/Simplify/Engine.hs @@ -54,7 +54,6 @@ module Futhark.Optimise.Simplify.Engine bindLParams, simplifyBody, ST.SymbolTable, - hoistStms, blockIf, blockMigrated, enterLoop, @@ -206,6 +205,18 @@ asksEngineEnv f = f <$> askEngineEnv askVtable :: SimpleM rep (ST.SymbolTable (Wise rep)) askVtable = asksEngineEnv envVtable +mkSubSimplify :: SimplifiableRep rep => SimpleM rep (SubSimplify (Wise rep)) +mkSubSimplify = do + (ops, env) <- ask + pure $ \body -> do + scope <- askScope + let env' = env {envVtable = ST.fromScope scope} + (x, _) <- modifyNameSource $ runSimpleM (f body) ops env' + pure x + where + f body = + simplifyBodyNoHoisting mempty (map (const mempty) (bodyResult body)) body + localVtable :: (ST.SymbolTable (Wise rep) -> ST.SymbolTable (Wise rep)) -> SimpleM rep a -> @@ -486,7 +497,8 @@ hoistStms rules block orig_stms final = do process usageInStm stm stms usage x = do vtable <- askVtable - res <- bottomUpSimplifyStm rules (vtable, usage) stm + ss <- mkSubSimplify + res <- bottomUpSimplifyStm ss rules (vtable, usage) stm case res of Nothing -- Nothing to optimise - see if hoistable. | block vtable usage stm -> @@ -517,7 +529,8 @@ hoistStms rules block orig_stms final = do stms_h' <- nonrecSimplifyStm stms_h vtable <- askVtable - simplified <- topDownSimplifyStm rules vtable stms_h' + ss <- mkSubSimplify + simplified <- topDownSimplifyStm ss rules vtable stms_h' case simplified of Just newstms -> do diff --git a/src/Futhark/Optimise/Simplify/Rule.hs b/src/Futhark/Optimise/Simplify/Rule.hs index c72d9ab557..3a6faf703b 100644 --- a/src/Futhark/Optimise/Simplify/Rule.hs +++ b/src/Futhark/Optimise/Simplify/Rule.hs @@ -16,6 +16,8 @@ module Futhark.Optimise.Simplify.Rule RuleM, cannotSimplify, liftMaybe, + SubSimplify, + subSimplify, -- * Rule definition Rule (..), @@ -53,19 +55,26 @@ module Futhark.Optimise.Simplify.Rule ) where +import Control.Monad.Reader import Control.Monad.State import Futhark.Analysis.SymbolTable qualified as ST import Futhark.Analysis.UsageTable qualified as UT import Futhark.Builder import Futhark.IR +-- | An action for recursively simplifying a body. +type SubSimplify rep = Body rep -> RuleM rep (Body rep) + +newtype RuleEnv rep = RuleEnv {envSubSimplify :: SubSimplify rep} + -- | The monad in which simplification rules are evaluated. -newtype RuleM rep a = RuleM (BuilderT rep (StateT VNameSource Maybe) a) +newtype RuleM rep a = RuleM (BuilderT rep (StateT VNameSource (ReaderT (RuleEnv rep) Maybe)) a) deriving ( Functor, Applicative, Monad, MonadFreshNames, + MonadReader (RuleEnv rep), HasScope rep, LocalScope rep ) @@ -84,19 +93,29 @@ instance (BuilderOps rep) => MonadBuilder (RuleM rep) where simplify :: Scope rep -> VNameSource -> + RuleEnv rep -> Rule rep -> Maybe (Stms rep, VNameSource) -simplify _ _ Skip = Nothing -simplify scope src (Simplify (RuleM m)) = - runStateT (runBuilderT_ m scope) src +simplify _ _ _ Skip = Nothing +simplify scope src env (Simplify (RuleM m)) = + runReaderT (runStateT (runBuilderT_ m scope) src) env +-- | Abort the current attempt at simplification. cannotSimplify :: RuleM rep a -cannotSimplify = RuleM $ lift $ lift Nothing +cannotSimplify = RuleM $ lift $ lift $ lift Nothing liftMaybe :: Maybe a -> RuleM rep a liftMaybe Nothing = cannotSimplify liftMaybe (Just x) = pure x +-- | Recursively apply the simplifier on this body, using the current +-- rulebook. This can be quite costly, so think carefully before +-- doing this. +subSimplify :: SubSimplify rep +subSimplify body = do + s <- asks envSubSimplify + s body + -- | An efficient way of encoding whether a simplification rule should even be attempted. data Rule rep = -- | Give it a shot. @@ -252,31 +271,6 @@ ruleBook topdowns bottomups = forOp RuleGeneric {} = True forOp _ = False --- | @simplifyStm lookup stm@ performs simplification of the --- binding @stm@. If simplification is possible, a replacement list --- of bindings is returned, that bind at least the same names as the --- original binding (and possibly more, for intermediate results). -topDownSimplifyStm :: - (MonadFreshNames m, HasScope rep m, PrettyRep rep) => - RuleBook rep -> - ST.SymbolTable rep -> - Stm rep -> - m (Maybe (Stms rep)) -topDownSimplifyStm = applyRules . bookTopDownRules - --- | @simplifyStm uses stm@ performs simplification of the binding --- @stm@. If simplification is possible, a replacement list of --- bindings is returned, that bind at least the same names as the --- original binding (and possibly more, for intermediate results). --- The first argument is the set of names used after this binding. -bottomUpSimplifyStm :: - (MonadFreshNames m, HasScope rep m, PrettyRep rep) => - RuleBook rep -> - (ST.SymbolTable rep, UT.UsageTable) -> - Stm rep -> - m (Maybe (Stms rep)) -bottomUpSimplifyStm = applyRules . bookBottomUpRules - rulesForStm :: Stm rep -> Rules rep a -> [SimplificationRule rep a] rulesForStm stm = case stmExp stm of BasicOp {} -> rulesBasicOp @@ -299,19 +293,47 @@ applyRule _ _ _ = applyRules :: (MonadFreshNames m, HasScope rep m, PrettyRep rep) => + SubSimplify rep -> Rules rep a -> a -> Stm rep -> m (Maybe (Stms rep)) -applyRules all_rules context stm = do +applyRules ss all_rules context stm = do scope <- askScope - + let env = RuleEnv ss modifyNameSource $ \src -> let applyRules' [] = Nothing applyRules' (rule : rules) = - case simplify scope src (applyRule rule context stm) of + case simplify scope src env (applyRule rule context stm) of Just x -> Just x Nothing -> applyRules' rules in case applyRules' $ rulesForStm stm all_rules of Just (stms, src') -> (Just stms, src') Nothing -> (Nothing, src) + +-- | @simplifyStm lookup stm@ performs simplification of the +-- binding @stm@. If simplification is possible, a replacement list +-- of bindings is returned, that bind at least the same names as the +-- original binding (and possibly more, for intermediate results). +topDownSimplifyStm :: + (MonadFreshNames m, HasScope rep m, PrettyRep rep) => + SubSimplify rep -> + RuleBook rep -> + ST.SymbolTable rep -> + Stm rep -> + m (Maybe (Stms rep)) +topDownSimplifyStm ss = applyRules ss . bookTopDownRules + +-- | @simplifyStm uses stm@ performs simplification of the binding +-- @stm@. If simplification is possible, a replacement list of +-- bindings is returned, that bind at least the same names as the +-- original binding (and possibly more, for intermediate results). +-- The first argument is the set of names used after this binding. +bottomUpSimplifyStm :: + (MonadFreshNames m, HasScope rep m, PrettyRep rep) => + SubSimplify rep -> + RuleBook rep -> + (ST.SymbolTable rep, UT.UsageTable) -> + Stm rep -> + m (Maybe (Stms rep)) +bottomUpSimplifyStm ss = applyRules ss . bookBottomUpRules diff --git a/src/Futhark/Optimise/Simplify/Rules/Loop.hs b/src/Futhark/Optimise/Simplify/Rules/Loop.hs index 1f0d08c953..b1e59f4695 100644 --- a/src/Futhark/Optimise/Simplify/Rules/Loop.hs +++ b/src/Futhark/Optimise/Simplify/Rules/Loop.hs @@ -2,8 +2,9 @@ module Futhark.Optimise.Simplify.Rules.Loop (loopRules) where import Control.Monad -import Data.Bifunctor (second) +import Data.Bifunctor (first, second) import Data.List (partition) +import Data.Map qualified as M import Data.Maybe import Futhark.Analysis.DataDependencies import Futhark.Analysis.PrimExp.Convert @@ -83,6 +84,79 @@ removeRedundantMergeVariables (_, used) pat aux (merge, form, body) removeRedundantMergeVariables _ _ _ _ = Skip +-- For a loop of the form +-- +-- loop p = x ... +-- ...stms... +-- in res +-- +-- we construct and simplify the body +-- +-- let p = x +-- ...stms... +-- in res +-- +-- and if that simplifies to 'x', then we conclude that the loop +-- parameter 'p' must be invariant to the loop and simply bind it (and +-- the loop result) to 'x'. +-- +-- Complication: for multi-parameter loops, we must also check that +-- the *original* computation of 'res' does *only* depends on other +-- invariant loop parameters. See tests/loops/invariant1.fut for an +-- example. +simplifyInvariantParams :: BuilderOps rep => TopDownRuleDoLoop rep +simplifyInvariantParams _vtable pat aux (params, form, loopbody) + | consts <- filter constInit params, + not $ null consts = Simplify . auxing aux $ + localScope (scopeOfFParams (map fst params) <> scopeOf form) $ do + loopbody_simpl <- subSimplify <=< buildBody_ $ do + mapM_ bindParam consts + bodyBind loopbody + let inv_pnames = determineInvariant $ bodyResult loopbody_simpl + invariant (_, (p, _), _) = paramName p `elem` inv_pnames + (inv, var) = + partition invariant $ + zip3 (patElems pat) params (bodyResult loopbody) + (var_pes, var_params, var_res) = unzip3 var + when (null inv) cannotSimplify + mapM_ bindInv inv + loopbody' <- mkBodyM (bodyStms loopbody) var_res + letBind (Pat var_pes) $ DoLoop var_params form loopbody' + | otherwise = Skip + where + loopbody_deps = dataDependencies loopbody + resDep (Var v) = oneName v <> fromMaybe mempty (M.lookup v loopbody_deps) + resDep _ = mempty + res_deps = map (resDep . resSubExp) $ bodyResult loopbody + + constInit (_, Constant {}) = True + constInit _ = False + + bindParam (p, se) = letBindNames [paramName p] $ BasicOp $ SubExp se + + bindInv (pe, (p, se), _) = do + letBindNames [patElemName pe] $ BasicOp $ SubExp se + letBindNames [paramName p] $ BasicOp $ SubExp se + + resIsInvariant ((_, x), x') = x == resSubExp x' + + depOnVar var (_, deps) = any (`nameIn` deps) var + + noInvDepOnVar inv var + | (inv_var, inv') <- partition (depOnVar var) inv, + not $ null inv_var = + noInvDepOnVar inv' $ map fst inv_var <> var + | otherwise = + map fst inv + + determineInvariant simpl_res = + let (inv, var) = + partition (resIsInvariant . fst) $ + zip (zip (map (first paramName) params) simpl_res) res_deps + in noInvDepOnVar + (map (first (fst . fst)) inv) + (map (fst . fst . fst) var) + -- We may change the type of the loop if we hoist out a shape -- annotation, in which case we also need to tweak the bound pattern. hoistLoopInvariantMergeVariables :: BuilderOps rep => TopDownRuleDoLoop rep @@ -290,7 +364,8 @@ topDownRules = [ RuleDoLoop hoistLoopInvariantMergeVariables, RuleDoLoop simplifyClosedFormLoop, RuleDoLoop simplifyKnownIterationLoop, - RuleDoLoop simplifyLoopVariables + RuleDoLoop simplifyLoopVariables, + RuleDoLoop simplifyInvariantParams ] bottomUpRules :: BuilderOps rep => [BottomUpRule rep] diff --git a/tests/loops/invariant0.fut b/tests/loops/invariant0.fut new file mode 100644 index 0000000000..c74da05f36 --- /dev/null +++ b/tests/loops/invariant0.fut @@ -0,0 +1,12 @@ +-- Removal of invariant of invariant loop parameter (and eventually entire loop). +-- == +-- structure { DoLoop 0 } + +entry main [n] (bs: [n]bool) = + let res = + loop (x, y) = (0i32, false) + for i < n do + let y' = bs[i] && y + let x' = x + (i32.bool y') + in (x', y') + in res diff --git a/tests/loops/invariant1.fut b/tests/loops/invariant1.fut new file mode 100644 index 0000000000..404acea9c3 --- /dev/null +++ b/tests/loops/invariant1.fut @@ -0,0 +1,13 @@ +-- Not actually invariant if you look carefully! +-- == +-- input { 0 } output { 0 false } +-- input { 4 } output { 3 true } + +entry main (n: i32) = + let res = + loop (x, y) = (0i32, false) + for _i < n do + let x' = if y then x + 1 else x + let y' = y || true + in (x', y') + in res diff --git a/tests/fibloop.fut b/tests/loops/invariant2.fut similarity index 70% rename from tests/fibloop.fut rename to tests/loops/invariant2.fut index 6441ceaa8e..49805d6f84 100644 --- a/tests/fibloop.fut +++ b/tests/loops/invariant2.fut @@ -1,10 +1,8 @@ +-- Also not actually invariant. -- == -- input { 0 } output { 1 } -- input { 10 } output { 89 } - -def fib(n: i32): i32 = +entry main (n: i32) = let (x,_) = loop (x, y) = (1,1) for _i < n do (y, x+y) in x - -def main(n: i32): i32 = fib(n)