From ff4969e43b592774fb36f7a10d471af4c96a1ece Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 20 Jul 2023 17:04:09 +0200 Subject: [PATCH] Add new simplification rule for invariant loop parameters. This was suggested by Cosmin to address some of the code produced by AD. The idea is that for a loop of the form loop p = x ... ...stms... in res we construct and simplify the body let p = x ...stms... in res and if that simplifies to 'x', then we conclude that the loop parameter 'p' must be invariant to the loop and simply bind it (and the loop result) to 'x'. Complication: for multi-parameter loops, we must also check that the *original* computation of 'res' does *only* depends on other invariant loop parameters. Currently we do this only for loops that have a constant as one of their initial loop parameter values. The main downside of this rule is that doing recursive simplification is quite expensive. Especially after sequentialisation, pretty much every 'reduce' will have been turned into a loop that triggers this rule (although the rule itself will fail in most cases, after doing the simplification). Therefore I'm a bit hesitant to enable it as is. Sure, the Futhark compiler is slow and it was never meant to be fast, but it is still quite easy for the compiler to become *uselessly slow* if we are not careful. E.g. on OptionPricing, this rule itself makes compilation 10% slower (and does not actually optimise anything - this is purely the cost of failing checks). --- src/Futhark/Optimise/Simplify/Engine.hs | 19 ++++- src/Futhark/Optimise/Simplify/Rule.hs | 88 +++++++++++++-------- src/Futhark/Optimise/Simplify/Rules/Loop.hs | 79 +++++++++++++++++- tests/loops/invariant0.fut | 12 +++ tests/loops/invariant1.fut | 13 +++ tests/{fibloop.fut => loops/invariant2.fut} | 6 +- 6 files changed, 175 insertions(+), 42 deletions(-) create mode 100644 tests/loops/invariant0.fut create mode 100644 tests/loops/invariant1.fut rename tests/{fibloop.fut => loops/invariant2.fut} (70%) diff --git a/src/Futhark/Optimise/Simplify/Engine.hs b/src/Futhark/Optimise/Simplify/Engine.hs index 65c703d498..a14a5adb43 100644 --- a/src/Futhark/Optimise/Simplify/Engine.hs +++ b/src/Futhark/Optimise/Simplify/Engine.hs @@ -54,7 +54,6 @@ module Futhark.Optimise.Simplify.Engine bindLParams, simplifyBody, ST.SymbolTable, - hoistStms, blockIf, blockMigrated, enterLoop, @@ -206,6 +205,18 @@ asksEngineEnv f = f <$> askEngineEnv askVtable :: SimpleM rep (ST.SymbolTable (Wise rep)) askVtable = asksEngineEnv envVtable +mkSubSimplify :: SimplifiableRep rep => SimpleM rep (SubSimplify (Wise rep)) +mkSubSimplify = do + (ops, env) <- ask + pure $ \body -> do + scope <- askScope + let env' = env {envVtable = ST.fromScope scope} + (x, _) <- modifyNameSource $ runSimpleM (f body) ops env' + pure x + where + f body = + simplifyBodyNoHoisting mempty (map (const mempty) (bodyResult body)) body + localVtable :: (ST.SymbolTable (Wise rep) -> ST.SymbolTable (Wise rep)) -> SimpleM rep a -> @@ -486,7 +497,8 @@ hoistStms rules block orig_stms final = do process usageInStm stm stms usage x = do vtable <- askVtable - res <- bottomUpSimplifyStm rules (vtable, usage) stm + ss <- mkSubSimplify + res <- bottomUpSimplifyStm ss rules (vtable, usage) stm case res of Nothing -- Nothing to optimise - see if hoistable. | block vtable usage stm -> @@ -517,7 +529,8 @@ hoistStms rules block orig_stms final = do stms_h' <- nonrecSimplifyStm stms_h vtable <- askVtable - simplified <- topDownSimplifyStm rules vtable stms_h' + ss <- mkSubSimplify + simplified <- topDownSimplifyStm ss rules vtable stms_h' case simplified of Just newstms -> do diff --git a/src/Futhark/Optimise/Simplify/Rule.hs b/src/Futhark/Optimise/Simplify/Rule.hs index c72d9ab557..3a6faf703b 100644 --- a/src/Futhark/Optimise/Simplify/Rule.hs +++ b/src/Futhark/Optimise/Simplify/Rule.hs @@ -16,6 +16,8 @@ module Futhark.Optimise.Simplify.Rule RuleM, cannotSimplify, liftMaybe, + SubSimplify, + subSimplify, -- * Rule definition Rule (..), @@ -53,19 +55,26 @@ module Futhark.Optimise.Simplify.Rule ) where +import Control.Monad.Reader import Control.Monad.State import Futhark.Analysis.SymbolTable qualified as ST import Futhark.Analysis.UsageTable qualified as UT import Futhark.Builder import Futhark.IR +-- | An action for recursively simplifying a body. +type SubSimplify rep = Body rep -> RuleM rep (Body rep) + +newtype RuleEnv rep = RuleEnv {envSubSimplify :: SubSimplify rep} + -- | The monad in which simplification rules are evaluated. -newtype RuleM rep a = RuleM (BuilderT rep (StateT VNameSource Maybe) a) +newtype RuleM rep a = RuleM (BuilderT rep (StateT VNameSource (ReaderT (RuleEnv rep) Maybe)) a) deriving ( Functor, Applicative, Monad, MonadFreshNames, + MonadReader (RuleEnv rep), HasScope rep, LocalScope rep ) @@ -84,19 +93,29 @@ instance (BuilderOps rep) => MonadBuilder (RuleM rep) where simplify :: Scope rep -> VNameSource -> + RuleEnv rep -> Rule rep -> Maybe (Stms rep, VNameSource) -simplify _ _ Skip = Nothing -simplify scope src (Simplify (RuleM m)) = - runStateT (runBuilderT_ m scope) src +simplify _ _ _ Skip = Nothing +simplify scope src env (Simplify (RuleM m)) = + runReaderT (runStateT (runBuilderT_ m scope) src) env +-- | Abort the current attempt at simplification. cannotSimplify :: RuleM rep a -cannotSimplify = RuleM $ lift $ lift Nothing +cannotSimplify = RuleM $ lift $ lift $ lift Nothing liftMaybe :: Maybe a -> RuleM rep a liftMaybe Nothing = cannotSimplify liftMaybe (Just x) = pure x +-- | Recursively apply the simplifier on this body, using the current +-- rulebook. This can be quite costly, so think carefully before +-- doing this. +subSimplify :: SubSimplify rep +subSimplify body = do + s <- asks envSubSimplify + s body + -- | An efficient way of encoding whether a simplification rule should even be attempted. data Rule rep = -- | Give it a shot. @@ -252,31 +271,6 @@ ruleBook topdowns bottomups = forOp RuleGeneric {} = True forOp _ = False --- | @simplifyStm lookup stm@ performs simplification of the --- binding @stm@. If simplification is possible, a replacement list --- of bindings is returned, that bind at least the same names as the --- original binding (and possibly more, for intermediate results). -topDownSimplifyStm :: - (MonadFreshNames m, HasScope rep m, PrettyRep rep) => - RuleBook rep -> - ST.SymbolTable rep -> - Stm rep -> - m (Maybe (Stms rep)) -topDownSimplifyStm = applyRules . bookTopDownRules - --- | @simplifyStm uses stm@ performs simplification of the binding --- @stm@. If simplification is possible, a replacement list of --- bindings is returned, that bind at least the same names as the --- original binding (and possibly more, for intermediate results). --- The first argument is the set of names used after this binding. -bottomUpSimplifyStm :: - (MonadFreshNames m, HasScope rep m, PrettyRep rep) => - RuleBook rep -> - (ST.SymbolTable rep, UT.UsageTable) -> - Stm rep -> - m (Maybe (Stms rep)) -bottomUpSimplifyStm = applyRules . bookBottomUpRules - rulesForStm :: Stm rep -> Rules rep a -> [SimplificationRule rep a] rulesForStm stm = case stmExp stm of BasicOp {} -> rulesBasicOp @@ -299,19 +293,47 @@ applyRule _ _ _ = applyRules :: (MonadFreshNames m, HasScope rep m, PrettyRep rep) => + SubSimplify rep -> Rules rep a -> a -> Stm rep -> m (Maybe (Stms rep)) -applyRules all_rules context stm = do +applyRules ss all_rules context stm = do scope <- askScope - + let env = RuleEnv ss modifyNameSource $ \src -> let applyRules' [] = Nothing applyRules' (rule : rules) = - case simplify scope src (applyRule rule context stm) of + case simplify scope src env (applyRule rule context stm) of Just x -> Just x Nothing -> applyRules' rules in case applyRules' $ rulesForStm stm all_rules of Just (stms, src') -> (Just stms, src') Nothing -> (Nothing, src) + +-- | @simplifyStm lookup stm@ performs simplification of the +-- binding @stm@. If simplification is possible, a replacement list +-- of bindings is returned, that bind at least the same names as the +-- original binding (and possibly more, for intermediate results). +topDownSimplifyStm :: + (MonadFreshNames m, HasScope rep m, PrettyRep rep) => + SubSimplify rep -> + RuleBook rep -> + ST.SymbolTable rep -> + Stm rep -> + m (Maybe (Stms rep)) +topDownSimplifyStm ss = applyRules ss . bookTopDownRules + +-- | @simplifyStm uses stm@ performs simplification of the binding +-- @stm@. If simplification is possible, a replacement list of +-- bindings is returned, that bind at least the same names as the +-- original binding (and possibly more, for intermediate results). +-- The first argument is the set of names used after this binding. +bottomUpSimplifyStm :: + (MonadFreshNames m, HasScope rep m, PrettyRep rep) => + SubSimplify rep -> + RuleBook rep -> + (ST.SymbolTable rep, UT.UsageTable) -> + Stm rep -> + m (Maybe (Stms rep)) +bottomUpSimplifyStm ss = applyRules ss . bookBottomUpRules diff --git a/src/Futhark/Optimise/Simplify/Rules/Loop.hs b/src/Futhark/Optimise/Simplify/Rules/Loop.hs index 1f0d08c953..b1e59f4695 100644 --- a/src/Futhark/Optimise/Simplify/Rules/Loop.hs +++ b/src/Futhark/Optimise/Simplify/Rules/Loop.hs @@ -2,8 +2,9 @@ module Futhark.Optimise.Simplify.Rules.Loop (loopRules) where import Control.Monad -import Data.Bifunctor (second) +import Data.Bifunctor (first, second) import Data.List (partition) +import Data.Map qualified as M import Data.Maybe import Futhark.Analysis.DataDependencies import Futhark.Analysis.PrimExp.Convert @@ -83,6 +84,79 @@ removeRedundantMergeVariables (_, used) pat aux (merge, form, body) removeRedundantMergeVariables _ _ _ _ = Skip +-- For a loop of the form +-- +-- loop p = x ... +-- ...stms... +-- in res +-- +-- we construct and simplify the body +-- +-- let p = x +-- ...stms... +-- in res +-- +-- and if that simplifies to 'x', then we conclude that the loop +-- parameter 'p' must be invariant to the loop and simply bind it (and +-- the loop result) to 'x'. +-- +-- Complication: for multi-parameter loops, we must also check that +-- the *original* computation of 'res' does *only* depends on other +-- invariant loop parameters. See tests/loops/invariant1.fut for an +-- example. +simplifyInvariantParams :: BuilderOps rep => TopDownRuleDoLoop rep +simplifyInvariantParams _vtable pat aux (params, form, loopbody) + | consts <- filter constInit params, + not $ null consts = Simplify . auxing aux $ + localScope (scopeOfFParams (map fst params) <> scopeOf form) $ do + loopbody_simpl <- subSimplify <=< buildBody_ $ do + mapM_ bindParam consts + bodyBind loopbody + let inv_pnames = determineInvariant $ bodyResult loopbody_simpl + invariant (_, (p, _), _) = paramName p `elem` inv_pnames + (inv, var) = + partition invariant $ + zip3 (patElems pat) params (bodyResult loopbody) + (var_pes, var_params, var_res) = unzip3 var + when (null inv) cannotSimplify + mapM_ bindInv inv + loopbody' <- mkBodyM (bodyStms loopbody) var_res + letBind (Pat var_pes) $ DoLoop var_params form loopbody' + | otherwise = Skip + where + loopbody_deps = dataDependencies loopbody + resDep (Var v) = oneName v <> fromMaybe mempty (M.lookup v loopbody_deps) + resDep _ = mempty + res_deps = map (resDep . resSubExp) $ bodyResult loopbody + + constInit (_, Constant {}) = True + constInit _ = False + + bindParam (p, se) = letBindNames [paramName p] $ BasicOp $ SubExp se + + bindInv (pe, (p, se), _) = do + letBindNames [patElemName pe] $ BasicOp $ SubExp se + letBindNames [paramName p] $ BasicOp $ SubExp se + + resIsInvariant ((_, x), x') = x == resSubExp x' + + depOnVar var (_, deps) = any (`nameIn` deps) var + + noInvDepOnVar inv var + | (inv_var, inv') <- partition (depOnVar var) inv, + not $ null inv_var = + noInvDepOnVar inv' $ map fst inv_var <> var + | otherwise = + map fst inv + + determineInvariant simpl_res = + let (inv, var) = + partition (resIsInvariant . fst) $ + zip (zip (map (first paramName) params) simpl_res) res_deps + in noInvDepOnVar + (map (first (fst . fst)) inv) + (map (fst . fst . fst) var) + -- We may change the type of the loop if we hoist out a shape -- annotation, in which case we also need to tweak the bound pattern. hoistLoopInvariantMergeVariables :: BuilderOps rep => TopDownRuleDoLoop rep @@ -290,7 +364,8 @@ topDownRules = [ RuleDoLoop hoistLoopInvariantMergeVariables, RuleDoLoop simplifyClosedFormLoop, RuleDoLoop simplifyKnownIterationLoop, - RuleDoLoop simplifyLoopVariables + RuleDoLoop simplifyLoopVariables, + RuleDoLoop simplifyInvariantParams ] bottomUpRules :: BuilderOps rep => [BottomUpRule rep] diff --git a/tests/loops/invariant0.fut b/tests/loops/invariant0.fut new file mode 100644 index 0000000000..c74da05f36 --- /dev/null +++ b/tests/loops/invariant0.fut @@ -0,0 +1,12 @@ +-- Removal of invariant of invariant loop parameter (and eventually entire loop). +-- == +-- structure { DoLoop 0 } + +entry main [n] (bs: [n]bool) = + let res = + loop (x, y) = (0i32, false) + for i < n do + let y' = bs[i] && y + let x' = x + (i32.bool y') + in (x', y') + in res diff --git a/tests/loops/invariant1.fut b/tests/loops/invariant1.fut new file mode 100644 index 0000000000..404acea9c3 --- /dev/null +++ b/tests/loops/invariant1.fut @@ -0,0 +1,13 @@ +-- Not actually invariant if you look carefully! +-- == +-- input { 0 } output { 0 false } +-- input { 4 } output { 3 true } + +entry main (n: i32) = + let res = + loop (x, y) = (0i32, false) + for _i < n do + let x' = if y then x + 1 else x + let y' = y || true + in (x', y') + in res diff --git a/tests/fibloop.fut b/tests/loops/invariant2.fut similarity index 70% rename from tests/fibloop.fut rename to tests/loops/invariant2.fut index 6441ceaa8e..49805d6f84 100644 --- a/tests/fibloop.fut +++ b/tests/loops/invariant2.fut @@ -1,10 +1,8 @@ +-- Also not actually invariant. -- == -- input { 0 } output { 1 } -- input { 10 } output { 89 } - -def fib(n: i32): i32 = +entry main (n: i32) = let (x,_) = loop (x, y) = (1,1) for _i < n do (y, x+y) in x - -def main(n: i32): i32 = fib(n)