Skip to content

Commit

Permalink
Generalize SoP system to be able to work on different expression
Browse files Browse the repository at this point in the history
representations and misc progress.
  • Loading branch information
zfnmxt committed Jun 14, 2023
1 parent 6fe09cc commit 088201e
Show file tree
Hide file tree
Showing 13 changed files with 460 additions and 532 deletions.
4 changes: 2 additions & 2 deletions futhark.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -274,14 +274,14 @@ library
Futhark.Construct
Futhark.Doc.Generator
Futhark.Error
Futhark.SoP.Convert
Futhark.SoP.Expression
Futhark.SoP.Monad
Futhark.SoP.Refine
Futhark.SoP.RefineEquivs
Futhark.SoP.FourierMotzkin
Futhark.SoP.PrimExp
Futhark.SoP.RefineRanges
Futhark.SoP.SoP
Futhark.SoP.ToFromSoP
Futhark.SoP.Util
Futhark.FreshNames
Futhark.IR
Expand Down
115 changes: 6 additions & 109 deletions src/Futhark/Internalise/Refinement.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@ import Futhark.Analysis.PrimExp (PrimExp)
import Futhark.Analysis.PrimExp qualified as PE
import Futhark.Internalise.TypesValues (internalisePrimType, internalisePrimValue)
import Futhark.MonadFreshNames
import Futhark.SoP.Convert
import Futhark.SoP.FourierMotzkin
import Futhark.SoP.Monad
import Futhark.SoP.PrimExp
import Futhark.SoP.Refine
import Futhark.SoP.SoP
import Futhark.SoP.ToFromSoP
import Futhark.SoP.Util
import Futhark.Util.Pretty
import Language.Futhark
Expand All @@ -22,127 +21,25 @@ import Language.Futhark.Semantic hiding (Env)
type Env = ()

newtype RefineM a
= RefineM (SoPMT VName (RWS Env () VNameSource) a)
= RefineM (SoPMT VName Exp (RWS Env () VNameSource) a)
deriving
( Functor,
Applicative,
Monad,
MonadReader Env,
MonadSoP VName
MonadSoP VName Exp
)

instance MonadFreshNames RefineM where
getNameSource = RefineM $ getNameSource
putNameSource = RefineM . putNameSource

convertBinOp :: BinOp -> PrimExp VName -> PrimExp VName -> PrimType -> PrimType -> Maybe (PrimExp VName)
convertBinOp LogAnd x y Bool _ =
simpleBinOp PE.LogAnd x y
convertBinOp LogOr x y Bool _ =
simpleBinOp PE.LogOr x y
convertBinOp Plus x y (Signed t) _ =
simpleBinOp (PE.Add t PE.OverflowWrap) x y
convertBinOp Plus x y (Unsigned t) _ =
simpleBinOp (PE.Add t PE.OverflowWrap) x y
convertBinOp Plus x y (FloatType t) _ =
simpleBinOp (PE.FAdd t) x y
convertBinOp Minus x y (Signed t) _ =
simpleBinOp (PE.Sub t PE.OverflowWrap) x y
convertBinOp Minus x y (Unsigned t) _ =
simpleBinOp (PE.Sub t PE.OverflowWrap) x y
convertBinOp Minus x y (FloatType t) _ =
simpleBinOp (PE.FSub t) x y
convertBinOp Times x y (Signed t) _ =
simpleBinOp (PE.Mul t PE.OverflowWrap) x y
convertBinOp Times x y (Unsigned t) _ =
simpleBinOp (PE.Mul t PE.OverflowWrap) x y
convertBinOp Times x y (FloatType t) _ =
simpleBinOp (PE.FMul t) x y
convertBinOp Equal x y t _ =
simpleCmpOp (PE.CmpEq $ internalisePrimType t) x y
convertBinOp NotEqual x y t _ = do
Just $ PE.UnOpExp PE.Not $ PE.CmpOpExp (PE.CmpEq $ internalisePrimType t) x y
convertBinOp Less x y (Signed t) _ =
simpleCmpOp (PE.CmpSlt t) x y
convertBinOp Less x y (Unsigned t) _ =
simpleCmpOp (PE.CmpUlt t) x y
convertBinOp Leq x y (Signed t) _ =
simpleCmpOp (PE.CmpSle t) x y
convertBinOp Leq x y (Unsigned t) _ =
simpleCmpOp (PE.CmpUle t) x y
convertBinOp Greater x y (Signed t) _ =
simpleCmpOp (PE.CmpSlt t) y x -- Note the swapped x and y
convertBinOp Greater x y (Unsigned t) _ =
simpleCmpOp (PE.CmpUlt t) y x -- Note the swapped x and y
convertBinOp Geq x y (Signed t) _ =
simpleCmpOp (PE.CmpSle t) y x -- Note the swapped x and y
convertBinOp Geq x y (Unsigned t) _ =
simpleCmpOp (PE.CmpUle t) y x -- Note the swapped x and y
convertBinOp Less x y (FloatType t) _ =
simpleCmpOp (PE.FCmpLt t) x y
convertBinOp Leq x y (FloatType t) _ =
simpleCmpOp (PE.FCmpLe t) x y
convertBinOp Greater x y (FloatType t) _ =
simpleCmpOp (PE.FCmpLt t) y x -- Note the swapped x and y
convertBinOp Geq x y (FloatType t) _ =
simpleCmpOp (PE.FCmpLe t) y x -- Note the swapped x and y
convertBinOp Less x y Bool _ =
simpleCmpOp PE.CmpLlt x y
convertBinOp Leq x y Bool _ =
simpleCmpOp PE.CmpLle x y
convertBinOp Greater x y Bool _ =
simpleCmpOp PE.CmpLlt y x -- Note the swapped x and y
convertBinOp Geq x y Bool _ =
simpleCmpOp PE.CmpLle y x -- Note the swapped x and y
convertBinOp _ _ _ _ _ = Nothing

simpleBinOp op x y = Just $ PE.BinOpExp op x y

simpleCmpOp op x y = Just $ PE.CmpOpExp op x y

expToPrimExp :: Exp -> Maybe (PrimExp VName)
expToPrimExp (Literal v _) = Just $ PE.ValueExp $ internalisePrimValue v
expToPrimExp (IntLit v (Info t) _) =
case t of
Scalar (Prim (Signed it)) -> Just $ PE.ValueExp $ PE.IntValue $ PE.intValue it v
Scalar (Prim (Unsigned it)) -> Just $ PE.ValueExp $ PE.IntValue $ PE.intValue it v
Scalar (Prim (FloatType ft)) -> Just $ PE.ValueExp $ PE.FloatValue $ PE.floatValue ft v
_ -> Nothing
expToPrimExp (FloatLit v (Info t) _) =
case t of
Scalar (Prim (FloatType ft)) -> Just $ PE.ValueExp $ PE.FloatValue $ PE.floatValue ft v
_ -> Nothing
expToPrimExp (AppExp (BinOp (op, _) _ (e_x, _) (e_y, _) _) _) = do
x <- expToPrimExp e_x
y <- expToPrimExp e_y
guard $ baseTag (qualLeaf op) <= maxIntrinsicTag
let name = baseString $ qualLeaf op
bop <- find ((name ==) . prettyString) [minBound .. maxBound :: BinOp]
t_x <- getPrimType $ typeOf e_x
t_y <- getPrimType $ typeOf e_y
convertBinOp bop x y t_x t_y
where
getPrimType (Scalar (Prim t)) = Just t
getPrimType _ = Nothing
expToPrimExp _ = Nothing

checkExp :: Exp -> RefineM Bool
checkExp e =
case expToPrimExp e of
Just pe -> checkPrimExp pe
Nothing -> pure False

checkPrimExp :: PrimExp VName -> RefineM Bool
checkPrimExp (PE.BinOpExp PE.LogAnd x y) =
(&&) <$> checkPrimExp x <*> checkPrimExp y
checkPrimExp (PE.BinOpExp PE.LogOr x y) =
(||) <$> checkPrimExp x <*> checkPrimExp y
checkPrimExp pe@(PE.CmpOpExp cop x y) = do
(_, sop) <- toNumSoPCmp pe
checkExp e = do
(_, sop) <- toSoPCmp e
sop $>=$ zeroSoP
checkPrimExp pe = pure False

runRefineM :: VNameSource -> RefineM a -> (a, AlgEnv VName, VNameSource)
runRefineM :: VNameSource -> RefineM a -> (a, AlgEnv VName Exp, VNameSource)
runRefineM src (RefineM m) =
let ((a, algenv), src', _) = runRWS (runSoPMT_ m) mempty src
in (a, algenv, src')
Expand Down
233 changes: 233 additions & 0 deletions src/Futhark/SoP/Convert.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}

-- | Translating to-and-from PrimExp to the sum-of-product representation.
module Futhark.SoP.Convert
( FromSoP (..),
ToSoP (..),
)
where

import Control.Monad.State
import Data.List (find)
import Data.Set (Set)
import Data.Set qualified as S
import Futhark.Analysis.PrimExp (PrimExp, PrimType, (~*~), (~+~), (~-~), (~/~), (~==~))
import Futhark.Analysis.PrimExp qualified as PE
import Futhark.SoP.Monad
import Futhark.SoP.SoP
import Futhark.SoP.Util
import Futhark.Util.Pretty
import Language.Futhark.Core
import Language.Futhark.Prop
import Language.Futhark.Syntax (VName)
import Language.Futhark.Syntax qualified as E

-- | Conversion from 'SoP's to other representations.
class FromSoP u a where
fromSoP :: SoP u -> a

instance Ord u => FromSoP u (PrimExp u) where
fromSoP sop =
foldr ((~+~) . fromTerm) (PE.ValueExp $ PE.IntValue $ PE.intValue PE.Int64 (0 :: Integer)) (sopToLists sop)
where
fromTerm (term, n) =
foldl (~*~) (PE.ValueExp $ PE.IntValue $ PE.intValue PE.Int64 n) $
map fromSym term
fromSym sym = PE.LeafExp sym $ PE.IntType PE.Int64

-- instance FromSoP VName Exp where
-- fromSoP sop = undefined
-- where
-- -- foldr ((~+~) . fromTerm) (PE.ValueExp $ PE.IntValue $ PE.intValue PE.Int64 (0 :: Integer)) (sopToLists sop)
-- mult = (E.AppExp (E.Var (E.QualName [] (VName "*" 0)) (E.Info $ i64) mempty) (E.Info $ E.AppRes i64 []))
-- fromTerm (term, n) =
-- foldl mult (E.Literal $ E.SignedValue $ PE.intValue PE.Int64 n) $
-- map fromSym term
-- fromSym sym = E.Var (E.QualName [] sym) (E.Info i64) mempty
-- i64 = E.Scalar $ E.Prim $ E.Signed $ PE.Int64

-- | Conversion from some expressions to
-- 'SoP's. Monadic because it may involve look-ups in the
-- untranslatable expression environment.
--
-- Separating into two functions is to make clearer the fact that
-- 'toSoPCmp' returns SoPs @sop@ implicitly in the relation @sop >=
-- 0@. Maybe this should be enforced at the constructor level
-- instead; i.e. have constructors for numeric SoPs and SoPs in
-- relations.
class ToSoP u e where
toSoPNum :: MonadSoP u e m => e -> m (Integer, SoP u)

-- | Translates a 'PrimExp' containing a (top-level) comparison
-- operator into a 'SoP' representation such that @sop >= 0@.
toSoPCmp :: MonadSoP u e m => e -> m (Integer, SoP u >= 0)

instance (Nameable u, Ord u, Show u, Pretty u) => ToSoP u (PrimExp u) where
toSoPNum primExp = do
(f, sop) <- toSoPNum' 1 primExp
pure (abs f, signum f `scaleSoP` sop)
where
notIntType :: PrimType -> Bool
notIntType (PE.IntType _) = False
notIntType _ = True

divideIsh :: PE.BinOp -> Bool
divideIsh (PE.UDiv _ _) = True
divideIsh (PE.UDivUp _ _) = True
divideIsh (PE.SDiv _ _) = True
divideIsh (PE.SDivUp _ _) = True
divideIsh (PE.FDiv _) = True
divideIsh _ = False
toSoPNum' _ pe
| notIntType (PE.primExpType pe) =
error "toSoPNum' applied to a PrimExp whose prim type is not Integer"
toSoPNum' f (PE.LeafExp vnm _) =
pure (f, sym2SoP vnm)
toSoPNum' f (PE.ValueExp (PE.IntValue iv)) =
pure (1, int2SoP $ getIntVal iv `div` f)
where
getIntVal :: PE.IntValue -> Integer
getIntVal (PE.Int8Value v) = fromIntegral v
getIntVal (PE.Int16Value v) = fromIntegral v
getIntVal (PE.Int32Value v) = fromIntegral v
getIntVal (PE.Int64Value v) = fromIntegral v
toSoPNum' f (PE.UnOpExp PE.Complement {} x) = do
(f', x_sop) <- toSoPNum' f x
pure (f', negSoP x_sop)
toSoPNum' f (PE.BinOpExp PE.Add {} x y) = do
(x_f, x_sop) <- toSoPNum x
(y_f, y_sop) <- toSoPNum y
let l_c_m = lcm x_f y_f
(x_m, y_m) = (l_c_m `div` x_f, l_c_m `div` y_f)
x_sop' = mulSoPs (int2SoP x_m) x_sop
y_sop' = mulSoPs (int2SoP y_m) y_sop
pure (f * l_c_m, addSoPs x_sop' y_sop')
toSoPNum' f (PE.BinOpExp PE.Sub {} x y) = do
(x_f, x_sop) <- toSoPNum x
(y_f, y_sop) <- toSoPNum y
let l_c_m = lcm x_f y_f
(x_m, y_m) = (l_c_m `div` x_f, l_c_m `div` y_f)
x_sop' = mulSoPs (int2SoP x_m) x_sop
n_y_sop' = mulSoPs (int2SoP (-y_m)) y_sop
pure (f * l_c_m, addSoPs x_sop' n_y_sop')
toSoPNum' f pe@(PE.BinOpExp PE.Mul {} x y) = do
(x_f, x_sop) <- toSoPNum x
(y_f, y_sop) <- toSoPNum y
case (x_f, y_f) of
(1, 1) -> pure (f, mulSoPs x_sop y_sop)
_ -> do
x' <- lookupUntransPE pe
toSoPNum' f $ PE.LeafExp x' $ PE.primExpType pe
-- pe / 1 == pe
toSoPNum' f (PE.BinOpExp divish pe q)
| divideIsh divish && PE.oneIshExp q =
toSoPNum' f pe
-- evaluate `val_x / val_y`
toSoPNum' f (PE.BinOpExp divish x y)
| divideIsh divish,
PE.ValueExp v_x <- x,
PE.ValueExp v_y <- y = do
let f' = v_x `vdiv` v_y
toSoPNum' f $ PE.ValueExp f'
-- Trivial simplifications:
-- (y * v) / y = v and (u * y) / y = u
| divideIsh divish,
PE.BinOpExp (PE.Mul _ _) u v <- x,
(is_fst, is_snd) <- (u == y, v == y),
is_fst || is_snd = do
toSoPNum' f $ if is_fst then v else u
where
vdiv (PE.IntValue (PE.Int64Value x')) (PE.IntValue (PE.Int64Value y')) =
PE.IntValue $ PE.Int64Value (x' `div` y')
vdiv (PE.IntValue (PE.Int32Value x')) (PE.IntValue (PE.Int32Value y')) =
PE.IntValue $ PE.Int32Value (x' `div` y')
vdiv (PE.IntValue (PE.Int16Value x')) (PE.IntValue (PE.Int16Value y')) =
PE.IntValue $ PE.Int16Value (x' `div` y')
vdiv (PE.IntValue (PE.Int8Value x')) (PE.IntValue (PE.Int8Value y')) =
PE.IntValue $ PE.Int8Value (x' `div` y')
-- vdiv (FloatValue (Float32Value x)) (FloatValue (Float32Value y)) =
-- FloatValue $ Float32Value $ x / y
-- vdiv (FloatValue (Float64Value x)) (FloatValue (Float64Value y)) =
-- FloatValue $ Float64Value $ x / y
vdiv _ _ = error "In vdiv: illegal type for division!"
-- try heuristic for exact division
toSoPNum' f pe@(PE.BinOpExp divish x y)
| divideIsh divish = do
(x_f, x_sop) <- toSoPNum x
(y_f, y_sop) <- toSoPNum y
case (x_f, y_f, divSoPs x_sop y_sop) of
(1, 1, Just res) -> pure (f, res)
_ -> do
x' <- lookupUntransPE pe
toSoPNum' f $ PE.LeafExp x' $ PE.primExpType pe
-- Anything that is not handled by specific cases of toSoPNum'
-- is handled by this default procedure:
-- If the target `pe` is in the unknwon `env`
-- Then return thecorresponding binding
-- Else make a fresh symbol `v`, bind it in the environment
-- and return it.
toSoPNum' f pe = do
x <- lookupUntransPE pe
toSoPNum' f $ PE.LeafExp x $ PE.primExpType pe

toSoPCmp (PE.CmpOpExp (PE.CmpEq ptp) x y)
-- x = y => x - y = 0
| PE.IntType {} <- ptp = toSoPNum $ x ~-~ y
toSoPCmp (PE.CmpOpExp lessop x y)
-- x < y => x + 1 <= y => y >= x + 1 => y - (x+1) >= 0
| Just itp <- lthishType lessop =
toSoPNum $ y ~-~ (x ~+~ PE.ValueExp (PE.IntValue $ PE.intValue itp (1 :: Integer)))
-- x <= y => y >= x => y - x >= 0
| Just _ <- leqishType lessop =
toSoPNum $ y ~-~ x
where
lthishType (PE.CmpSlt itp) = Just itp
lthishType (PE.CmpUlt itp) = Just itp
lthishType _ = Nothing
leqishType (PE.CmpUle itp) = Just itp
leqishType (PE.CmpSle itp) = Just itp
leqishType _ = Nothing
toSoPCmp pe = error $ "toSoPCmp: not a comparison " <> prettyString pe

instance (Nameable u, Ord u, Show u, Pretty u) => ToSoP u Exp where
toSoPNum (E.Literal v _) =
(pure . (1,)) $
case v of
E.SignedValue x -> int2SoP $ PE.valueIntegral x
E.UnsignedValue x -> int2SoP $ PE.valueIntegral x
_ -> error ""
toSoPNum e = do
x <- lookupUntransPE e
pure (1, sym2SoP x)

-- expToPrimExp (IntLit v (Info t) _) =

toSoPCmp (E.AppExp (E.BinOp (op, _) _ (e_x, _) (e_y, _) _) _)
| E.baseTag (E.qualLeaf op) <= maxIntrinsicTag,
name <- E.baseString $ E.qualLeaf op,
Just bop <- find ((name ==) . prettyString) [minBound .. maxBound :: E.BinOp] = do
(_, x) <- toSoPNum e_x
(_, y) <- toSoPNum e_y
(1,)
<$> case bop of
E.Equal -> pure $ x .-. y
E.Less -> pure $ y .-. (x .+. int2SoP 1)
E.Leq -> pure $ y .-. x
E.Greater -> pure $ x .-. (y .+. int2SoP 1)
E.Geq -> pure $ x .-. y

--
-- {--
---- This is a more refined treatment, but probably
---- an overkill (harmful if you get the type wrong)
-- fromSym unknowns sym
-- | Nothing <- M.lookup sym (dir unknowns) =
-- LeafExp sym $ IntType Integer
-- | Just pe1 <- M.lookup sym (dir unknowns),
-- IntType Integer <- PE.primExpType pe1 =
-- pe1
-- fromSym unknowns sym =
-- error ("Type error in fromSym: type of " ++
-- show sym ++ " is not Integer")
----}
Loading

0 comments on commit 088201e

Please sign in to comment.