Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new simplification rule for invariant loop parameters. #1990

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/Futhark/Optimise/Simplify/Engine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ module Futhark.Optimise.Simplify.Engine
bindLParams,
simplifyBody,
ST.SymbolTable,
hoistStms,
blockIf,
blockMigrated,
enterLoop,
Expand Down Expand Up @@ -206,6 +205,18 @@ asksEngineEnv f = f <$> askEngineEnv
askVtable :: SimpleM rep (ST.SymbolTable (Wise rep))
askVtable = asksEngineEnv envVtable

mkSubSimplify :: SimplifiableRep rep => SimpleM rep (SubSimplify (Wise rep))
mkSubSimplify = do
(ops, env) <- ask
pure $ \body -> do
scope <- askScope
let env' = env {envVtable = ST.fromScope scope}
(x, _) <- modifyNameSource $ runSimpleM (f body) ops env'
pure x
where
f body =
simplifyBodyNoHoisting mempty (map (const mempty) (bodyResult body)) body

localVtable ::
(ST.SymbolTable (Wise rep) -> ST.SymbolTable (Wise rep)) ->
SimpleM rep a ->
Expand Down Expand Up @@ -486,7 +497,8 @@ hoistStms rules block orig_stms final = do

process usageInStm stm stms usage x = do
vtable <- askVtable
res <- bottomUpSimplifyStm rules (vtable, usage) stm
ss <- mkSubSimplify
res <- bottomUpSimplifyStm ss rules (vtable, usage) stm
case res of
Nothing -- Nothing to optimise - see if hoistable.
| block vtable usage stm ->
Expand Down Expand Up @@ -517,7 +529,8 @@ hoistStms rules block orig_stms final = do
stms_h' <- nonrecSimplifyStm stms_h

vtable <- askVtable
simplified <- topDownSimplifyStm rules vtable stms_h'
ss <- mkSubSimplify
simplified <- topDownSimplifyStm ss rules vtable stms_h'

case simplified of
Just newstms -> do
Expand Down
88 changes: 55 additions & 33 deletions src/Futhark/Optimise/Simplify/Rule.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ module Futhark.Optimise.Simplify.Rule
RuleM,
cannotSimplify,
liftMaybe,
SubSimplify,
subSimplify,

-- * Rule definition
Rule (..),
Expand Down Expand Up @@ -53,19 +55,26 @@ module Futhark.Optimise.Simplify.Rule
)
where

import Control.Monad.Reader
import Control.Monad.State
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Builder
import Futhark.IR

-- | An action for recursively simplifying a body.
type SubSimplify rep = Body rep -> RuleM rep (Body rep)

newtype RuleEnv rep = RuleEnv {envSubSimplify :: SubSimplify rep}

-- | The monad in which simplification rules are evaluated.
newtype RuleM rep a = RuleM (BuilderT rep (StateT VNameSource Maybe) a)
newtype RuleM rep a = RuleM (BuilderT rep (StateT VNameSource (ReaderT (RuleEnv rep) Maybe)) a)
deriving
( Functor,
Applicative,
Monad,
MonadFreshNames,
MonadReader (RuleEnv rep),
HasScope rep,
LocalScope rep
)
Expand All @@ -84,19 +93,29 @@ instance (BuilderOps rep) => MonadBuilder (RuleM rep) where
simplify ::
Scope rep ->
VNameSource ->
RuleEnv rep ->
Rule rep ->
Maybe (Stms rep, VNameSource)
simplify _ _ Skip = Nothing
simplify scope src (Simplify (RuleM m)) =
runStateT (runBuilderT_ m scope) src
simplify _ _ _ Skip = Nothing
simplify scope src env (Simplify (RuleM m)) =
runReaderT (runStateT (runBuilderT_ m scope) src) env

-- | Abort the current attempt at simplification.
cannotSimplify :: RuleM rep a
cannotSimplify = RuleM $ lift $ lift Nothing
cannotSimplify = RuleM $ lift $ lift $ lift Nothing

liftMaybe :: Maybe a -> RuleM rep a
liftMaybe Nothing = cannotSimplify
liftMaybe (Just x) = pure x

-- | Recursively apply the simplifier on this body, using the current
-- rulebook. This can be quite costly, so think carefully before
-- doing this.
subSimplify :: SubSimplify rep
subSimplify body = do
s <- asks envSubSimplify
s body

-- | An efficient way of encoding whether a simplification rule should even be attempted.
data Rule rep
= -- | Give it a shot.
Expand Down Expand Up @@ -252,31 +271,6 @@ ruleBook topdowns bottomups =
forOp RuleGeneric {} = True
forOp _ = False

-- | @simplifyStm lookup stm@ performs simplification of the
-- binding @stm@. If simplification is possible, a replacement list
-- of bindings is returned, that bind at least the same names as the
-- original binding (and possibly more, for intermediate results).
topDownSimplifyStm ::
(MonadFreshNames m, HasScope rep m, PrettyRep rep) =>
RuleBook rep ->
ST.SymbolTable rep ->
Stm rep ->
m (Maybe (Stms rep))
topDownSimplifyStm = applyRules . bookTopDownRules

-- | @simplifyStm uses stm@ performs simplification of the binding
-- @stm@. If simplification is possible, a replacement list of
-- bindings is returned, that bind at least the same names as the
-- original binding (and possibly more, for intermediate results).
-- The first argument is the set of names used after this binding.
bottomUpSimplifyStm ::
(MonadFreshNames m, HasScope rep m, PrettyRep rep) =>
RuleBook rep ->
(ST.SymbolTable rep, UT.UsageTable) ->
Stm rep ->
m (Maybe (Stms rep))
bottomUpSimplifyStm = applyRules . bookBottomUpRules

rulesForStm :: Stm rep -> Rules rep a -> [SimplificationRule rep a]
rulesForStm stm = case stmExp stm of
BasicOp {} -> rulesBasicOp
Expand All @@ -299,19 +293,47 @@ applyRule _ _ _ =

applyRules ::
(MonadFreshNames m, HasScope rep m, PrettyRep rep) =>
SubSimplify rep ->
Rules rep a ->
a ->
Stm rep ->
m (Maybe (Stms rep))
applyRules all_rules context stm = do
applyRules ss all_rules context stm = do
scope <- askScope

let env = RuleEnv ss
modifyNameSource $ \src ->
let applyRules' [] = Nothing
applyRules' (rule : rules) =
case simplify scope src (applyRule rule context stm) of
case simplify scope src env (applyRule rule context stm) of
Just x -> Just x
Nothing -> applyRules' rules
in case applyRules' $ rulesForStm stm all_rules of
Just (stms, src') -> (Just stms, src')
Nothing -> (Nothing, src)

-- | @simplifyStm lookup stm@ performs simplification of the
-- binding @stm@. If simplification is possible, a replacement list
-- of bindings is returned, that bind at least the same names as the
-- original binding (and possibly more, for intermediate results).
topDownSimplifyStm ::
(MonadFreshNames m, HasScope rep m, PrettyRep rep) =>
SubSimplify rep ->
RuleBook rep ->
ST.SymbolTable rep ->
Stm rep ->
m (Maybe (Stms rep))
topDownSimplifyStm ss = applyRules ss . bookTopDownRules

-- | @simplifyStm uses stm@ performs simplification of the binding
-- @stm@. If simplification is possible, a replacement list of
-- bindings is returned, that bind at least the same names as the
-- original binding (and possibly more, for intermediate results).
-- The first argument is the set of names used after this binding.
bottomUpSimplifyStm ::
(MonadFreshNames m, HasScope rep m, PrettyRep rep) =>
SubSimplify rep ->
RuleBook rep ->
(ST.SymbolTable rep, UT.UsageTable) ->
Stm rep ->
m (Maybe (Stms rep))
bottomUpSimplifyStm ss = applyRules ss . bookBottomUpRules
79 changes: 77 additions & 2 deletions src/Futhark/Optimise/Simplify/Rules/Loop.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
module Futhark.Optimise.Simplify.Rules.Loop (loopRules) where

import Control.Monad
import Data.Bifunctor (second)
import Data.Bifunctor (first, second)
import Data.List (partition)
import Data.Map qualified as M
import Data.Maybe
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.PrimExp.Convert
Expand Down Expand Up @@ -83,6 +84,79 @@ removeRedundantMergeVariables (_, used) pat aux (merge, form, body)
removeRedundantMergeVariables _ _ _ _ =
Skip

-- For a loop of the form
--
-- loop p = x ...
-- ...stms...
-- in res
--
-- we construct and simplify the body
--
-- let p = x
-- ...stms...
-- in res
--
-- and if that simplifies to 'x', then we conclude that the loop
-- parameter 'p' must be invariant to the loop and simply bind it (and
-- the loop result) to 'x'.
--
-- Complication: for multi-parameter loops, we must also check that
-- the *original* computation of 'res' does *only* depends on other
-- invariant loop parameters. See tests/loops/invariant1.fut for an
-- example.
simplifyInvariantParams :: BuilderOps rep => TopDownRuleLoop rep
simplifyInvariantParams _vtable pat aux (params, form, loopbody)
| consts <- filter constInit params,
not $ null consts = Simplify . auxing aux $
localScope (scopeOfFParams (map fst params) <> scopeOf form) $ do
loopbody_simpl <- subSimplify <=< buildBody_ $ do
mapM_ bindParam consts
bodyBind loopbody
let inv_pnames = determineInvariant $ bodyResult loopbody_simpl
invariant (_, (p, _), _) = paramName p `elem` inv_pnames
(inv, var) =
partition invariant $
zip3 (patElems pat) params (bodyResult loopbody)
(var_pes, var_params, var_res) = unzip3 var
when (null inv) cannotSimplify
mapM_ bindInv inv
loopbody' <- mkBodyM (bodyStms loopbody) var_res
letBind (Pat var_pes) $ Loop var_params form loopbody'
| otherwise = Skip
where
loopbody_deps = dataDependencies loopbody
resDep (Var v) = oneName v <> fromMaybe mempty (M.lookup v loopbody_deps)
resDep _ = mempty
res_deps = map (resDep . resSubExp) $ bodyResult loopbody

constInit (_, Constant {}) = True
constInit _ = False

bindParam (p, se) = letBindNames [paramName p] $ BasicOp $ SubExp se

bindInv (pe, (p, se), _) = do
letBindNames [patElemName pe] $ BasicOp $ SubExp se
letBindNames [paramName p] $ BasicOp $ SubExp se

resIsInvariant ((_, x), x') = x == resSubExp x'

depOnVar var (_, deps) = any (`nameIn` deps) var

noInvDepOnVar inv var
| (inv_var, inv') <- partition (depOnVar var) inv,
not $ null inv_var =
noInvDepOnVar inv' $ map fst inv_var <> var
| otherwise =
map fst inv

determineInvariant simpl_res =
let (inv, var) =
partition (resIsInvariant . fst) $
zip (zip (map (first paramName) params) simpl_res) res_deps
in noInvDepOnVar
(map (first (fst . fst)) inv)
(map (fst . fst . fst) var)

-- We may change the type of the loop if we hoist out a shape
-- annotation, in which case we also need to tweak the bound pattern.
hoistLoopInvariantMergeVariables :: BuilderOps rep => TopDownRuleLoop rep
Expand Down Expand Up @@ -290,7 +364,8 @@ topDownRules =
[ RuleLoop hoistLoopInvariantMergeVariables,
RuleLoop simplifyClosedFormLoop,
RuleLoop simplifyKnownIterationLoop,
RuleLoop simplifyLoopVariables
RuleLoop simplifyLoopVariables,
RuleLoop simplifyInvariantParams
]

bottomUpRules :: BuilderOps rep => [BottomUpRule rep]
Expand Down
12 changes: 12 additions & 0 deletions tests/loops/invariant0.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-- Removal of invariant of invariant loop parameter (and eventually entire loop).
-- ==
-- structure { DoLoop 0 }

entry main [n] (bs: [n]bool) =
let res =
loop (x, y) = (0i32, false)
for i < n do
let y' = bs[i] && y
let x' = x + (i32.bool y')
in (x', y')
in res
13 changes: 13 additions & 0 deletions tests/loops/invariant1.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- Not actually invariant if you look carefully!
-- ==
-- input { 0 } output { 0 false }
-- input { 4 } output { 3 true }

entry main (n: i32) =
let res =
loop (x, y) = (0i32, false)
for _i < n do
let x' = if y then x + 1 else x
let y' = y || true
in (x', y')
in res
6 changes: 2 additions & 4 deletions tests/fibloop.fut → tests/loops/invariant2.fut
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
-- Also not actually invariant.
-- ==
-- input { 0 } output { 1 }
-- input { 10 } output { 89 }


def fib(n: i32): i32 =
entry main (n: i32) =
let (x,_) = loop (x, y) = (1,1) for _i < n do (y, x+y)
in x

def main(n: i32): i32 = fib(n)
Loading