Skip to content

Commit

Permalink
Add initial constant folding to the compiler pipeline (#800)
Browse files Browse the repository at this point in the history
* Remove naive static simplification implementation

* Add operations to list.mc

* Type fixes in eval

* New constant fold and constant propagation implementation

* Make tests more stable

* Add constant folding and constant propagation to compiler pipeline

* Test with compiler optimizations
  • Loading branch information
br4sco authored Jan 3, 2024
1 parent d0f596b commit 4030b35
Show file tree
Hide file tree
Showing 9 changed files with 699 additions and 667 deletions.
10 changes: 10 additions & 0 deletions src/main/compile.mc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ include "mexpr/shallow-patterns.mc"
include "mexpr/symbolize.mc"
include "mexpr/type-check.mc"
include "mexpr/utest-generate.mc"
include "mexpr/constant-fold.mc"
include "ocaml/ast.mc"
include "ocaml/external-includes.mc"
include "ocaml/mcore.mc"
Expand All @@ -35,6 +36,7 @@ lang MCoreCompile =
MExprUtestGenerate + MExprRuntimeCheck + MExprProfileInstrument +
MExprPrettyPrint +
MExprLowerNestedPatterns +
MExprConstantFold +
OCamlTryWithWrap + MCoreCompileLang + PhaseStats +
SpecializeCompile +
PprintTyAnnot + HtmlAnnotator
Expand Down Expand Up @@ -92,6 +94,14 @@ let compileWithUtests = lam options : Options. lam sourcePath. lam ast.
let ast = generateUtest options.runTests ast in
endPhaseStats log "generate-utest" ast;

let ast =
if and (options.enableConstantFold) (not options.disableOptimizations)
then constantFold ast else ast
in
endPhaseStats log "constant folding" ast;
(if options.debugConstantFold then
printLn (expr2str ast) else ());

let ast = lowerAll ast in
endPhaseStats log "pattern-lowering" ast;
(if options.debugShallow then
Expand Down
8 changes: 8 additions & 0 deletions src/main/options-config.mc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ let optionsConfig : ParseConfig Options = [
"Print the AST after lowering nested patterns to shallow ones",
lam p: ArgPart Options.
let o: Options = p.options in {o with debugShallow = true}),
([("--debug-constant-fold", "", "")],
"Print the AST after constant folding and constant propagation",
lam p: ArgPart Options.
let o: Options = p.options in {o with debugConstantFold = true}),
([("--debug-phases", "", "")],
"Show debug and profiling information about each pass",
lam p: ArgPart Options.
Expand Down Expand Up @@ -60,6 +64,10 @@ let optionsConfig : ParseConfig Options = [
"Disables optimizations to decrease compilation time",
lam p: ArgPart Options.
let o: Options = p.options in {o with disableOptimizations = true}),
([("--enable-constant-fold", "", "")],
"Enables constant folding and constant propagation",
lam p: ArgPart Options.
let o: Options = p.options in {o with enableConstantFold = true}),
([("--tuned", "", "")],
"Use tuned values when compiling, or as defaults when tuning",
lam p: ArgPart Options.
Expand Down
2 changes: 2 additions & 0 deletions src/main/options-type.mc
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ type Options = {
debugTypeCheck : Bool,
debugProfile : Bool,
debugShallow : Bool,
debugConstantFold : Bool,
debugPhases : Bool,
exitBefore : Bool,
disablePruneExternalUtests : Bool,
disablePruneExternalUtestsWarning : Bool,
runTests : Bool,
runtimeChecks : Bool,
disableOptimizations : Bool,
enableConstantFold : Bool,
useTuned : Bool,
compileAfterTune : Bool,
accelerate : Bool,
Expand Down
2 changes: 2 additions & 0 deletions src/main/options.mc
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ let optionsDefault : Options = {
debugTypeCheck = false,
debugProfile = false,
debugShallow = false,
debugConstantFold = false,
debugPhases = false,
exitBefore = false,
disablePruneExternalUtests = false,
disablePruneExternalUtestsWarning = false,
runTests = false,
runtimeChecks = false,
disableOptimizations = false,
enableConstantFold = false,
useTuned = false,
compileAfterTune = false,
accelerate = false,
Expand Down
52 changes: 27 additions & 25 deletions stdlib/ad/dualnum-lift.mc
Original file line number Diff line number Diff line change
Expand Up @@ -362,30 +362,32 @@ utest geqn _num2 (_dnum _e2 _dnum112 _num3) with true
-- ARITHMETIC OPERATORS --
---------------------------

let _eqfApprox = eqfApprox 1.e-6

-- lifted addition
utest addn _num1 _num2 with _num3 using dualnumEq eqf
utest addn _dnum010 _num2 with _dnum0 _num3 _num0 using dualnumEq eqf
utest addn _dnum011 _num2 with _dnum031 using dualnumEq eqf
utest addn _dnum011 _dnum011 with _dnum022 using dualnumEq eqf
utest addn _dnum011 _dnum111 with _dnum1 _dnum021 _num1 using dualnumEq eqf
utest addn _num1 _num2 with _num3 using dualnumEq _eqfApprox
utest addn _dnum010 _num2 with _dnum0 _num3 _num0 using dualnumEq _eqfApprox
utest addn _dnum011 _num2 with _dnum031 using dualnumEq _eqfApprox
utest addn _dnum011 _dnum011 with _dnum022 using dualnumEq _eqfApprox
utest addn _dnum011 _dnum111 with _dnum1 _dnum021 _num1 using dualnumEq _eqfApprox

-- lifted multiplication
utest muln _num1 _num2 with _num2 using dualnumEq eqf
utest muln _dnum010 _num2 with _dnum0 _num2 _num0 using dualnumEq eqf
utest muln _dnum011 _num2 with _dnum022 using dualnumEq eqf
utest muln _dnum012 _dnum034 with _dnum0 _num3 _num10 using dualnumEq eqf
utest muln _dnum012 _dnum134 with _dnum1 _dnum036 _dnum048 using dualnumEq eqf
utest muln _num1 _num2 with _num2 using dualnumEq _eqfApprox
utest muln _dnum010 _num2 with _dnum0 _num2 _num0 using dualnumEq _eqfApprox
utest muln _dnum011 _num2 with _dnum022 using dualnumEq _eqfApprox
utest muln _dnum012 _dnum034 with _dnum0 _num3 _num10 using dualnumEq _eqfApprox
utest muln _dnum012 _dnum134 with _dnum1 _dnum036 _dnum048 using dualnumEq _eqfApprox

-- lifted negation
let negn = lam p. _lift1 negf (lam. Primal (negf 1.)) p

utest negn _num1 with Primal (negf 1.) using dualnumEq eqf
utest negn _num0 with Primal (negf 0.) using dualnumEq eqf
utest negn _dnum010 with _dnum0 (Primal (negf 1.)) _num0 using dualnumEq eqf
utest negn _num1 with Primal (negf 1.) using dualnumEq _eqfApprox
utest negn _num0 with Primal (negf 0.) using dualnumEq _eqfApprox
utest negn _dnum010 with _dnum0 (Primal (negf 1.)) _num0 using dualnumEq _eqfApprox
utest negn _dnum012 with _dnum0 (Primal (negf 1.)) (Primal (negf 2.))
using dualnumEq eqf
using dualnumEq _eqfApprox

utest der negn _num1 with negn _num1 using dualnumEq eqf
utest der negn _num1 with negn _num1 using dualnumEq _eqfApprox

-- lifted subtraction
let subn = lam p1. lam p2.
Expand All @@ -395,15 +397,15 @@ let subn = lam p1. lam p2.
(lam. lam. negn (Primal 1.))
p1 p2

utest subn _num2 _num1 with _num1 using dualnumEq eqf
utest subn _dnum020 _num1 with _dnum0 _num1 _num0 using dualnumEq eqf
utest subn _dnum021 _num1 with _dnum011 using dualnumEq eqf
utest subn _dnum022 _dnum011 with _dnum011 using dualnumEq eqf
utest subn _num2 _num1 with _num1 using dualnumEq _eqfApprox
utest subn _dnum020 _num1 with _dnum0 _num1 _num0 using dualnumEq _eqfApprox
utest subn _dnum021 _num1 with _dnum011 using dualnumEq _eqfApprox
utest subn _dnum022 _dnum011 with _dnum011 using dualnumEq _eqfApprox

utest
let r = subn _dnum122 _dnum011 in
dualnumPrimal _e1 r
with _dnum0 _num1 (Primal (negf 1.)) using dualnumEq eqf
with _dnum0 _num1 (Primal (negf 1.)) using dualnumEq _eqfApprox


-- lifted abs
Expand All @@ -425,16 +427,16 @@ recursive
p1 p2
end

utest divn _num4 _num2 with _num2 using dualnumEq eqf
utest divn _dnum040 _num2 with _dnum0 _num2 _num0 using dualnumEq eqf
utest divn _dnum044 _num2 with _dnum022 using dualnumEq eqf
utest divn _num4 _num2 with _num2 using dualnumEq _eqfApprox
utest divn _dnum040 _num2 with _dnum0 _num2 _num0 using dualnumEq _eqfApprox
utest divn _dnum044 _num2 with _dnum022 using dualnumEq _eqfApprox

utest divn _dnum012 _dnum034
with _dnum0 (Primal (divf 1. 3.)) (Primal (divf 2. 9.)) using dualnumEq eqf
with _dnum0 (Primal (divf 1. 3.)) (Primal (divf 2. 9.)) using dualnumEq _eqfApprox

utest divn _dnum012 _dnum134
with _dnum1 (_dnum0 (Primal (divf 1. 3.))
(Primal (divf 2. 3.)))
(_dnum0 (Primal (divf (negf 4.) 9.))
(Primal (divf (negf 8.) 9.)))
using dualnumEq eqf
using dualnumEq _eqfApprox
37 changes: 37 additions & 0 deletions stdlib/list.mc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,23 @@ let listEq : all a. all b. (a -> b -> Bool) -> List a -> List b -> Bool =
end
in work

let listAll : all a. (a -> Bool) -> List a -> Bool = lam p. lam li.
recursive let forAll = lam li.
switch li
case Cons (x, li) then
if p x then forAll li else false
case Nil _ then true
end
in
forAll li

let listFilter : all a. (a -> Bool) -> List a -> List a = lam p. lam li.
listReverse
(listFoldl (lam acc. lam x. if p x then Cons (x, acc) else acc) (Nil ()) li)

let listConcat : all a. List a -> List a -> List a = lam lhs. lam rhs.
listFoldl (lam acc. lam x. listCons x acc) rhs (listReverse lhs)

mexpr

let l1 = listEmpty in
Expand Down Expand Up @@ -119,4 +136,24 @@ utest l6 with Cons (3, Cons (4, Cons (5, Nil ()))) in
utest listFoldl addi 0 l4 with 9 in
utest listFoldl addi 0 l6 with 12 in

utest listAll (lti 2) (listFromSeq [4, 4, 5, 3]) with true in
utest listAll (gti 3) (listFromSeq [4, 4, 5, 3]) with false in

utest listFilter (lti 2) (listFromSeq [4, 3, 5, 3])
with listFromSeq [4, 3, 5, 3]
in
utest listFilter (lti 3) (listFromSeq [4, 3, 5, 3])
with listFromSeq [4, 5]
in

utest listConcat (listFromSeq [1, 2, 3]) (listFromSeq [4, 5])
with listFromSeq [1, 2, 3, 4, 5]
in
utest listConcat (listFromSeq []) (listFromSeq [4, 5])
with listFromSeq [4, 5]
in
utest let l : [Int] = [] in listConcat (listFromSeq l) (listFromSeq l)
with listFromSeq []
in

()
Loading

0 comments on commit 4030b35

Please sign in to comment.