Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Well-formedness #12

Merged
merged 23 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ def main (args : List String) : IO UInt32 := do
let content := StableHLO.Parsing.parse content
IO.print s!"Parsing {file}... "
match content with
| .ok _ =>
| .ok p =>
passed := file :: passed
let fpReport : FilePath := System.mkFilePath ["Tests", file ++ ".report"]
for msg in p.2.report do
writeFile fpReport s!"File {file}, {msg}\n"
IO.println "success"
| .error _ =>
failed := file :: failed
Expand All @@ -42,7 +45,10 @@ def main (args : List String) : IO UInt32 := do
let content := StableHLO.Parsing.parse content
match content with
| .ok p =>
IO.println s!"{repr p}"
let fpAST : FilePath := System.mkFilePath ["Tests", file ++ ".ast"]
let fpReport : FilePath := System.mkFilePath ["Tests", file ++ ".report"]
writeFile fpAST s!"{repr p.1}\n"
writeFile fpReport s!"{p.2.report}\n"
return 0
| .error e =>
IO.println s!"{e.2.2}"
Expand Down
117 changes: 117 additions & 0 deletions SHerLOC/AST1.lean
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,115 @@ structure FuncInput where
typ : ValueType
deriving Repr, Inhabited, Nonempty

inductive OpCode where
| abs
| add
| afterAll
| allGather
| allReduce
| allToAll
| and
| atan2
| batchNormGrad
| batchNormInference
| batchNormTraining
| bitcastConvert
| broadcastInDim
| case
| cbrt
| ceil
| cholesky
| clamp
| collectiveBroadcast
| collectivePermute
| compare
| complex
| composite
| concatenate
| constant
| convert
| convolution
| cosine
| countLeadingZeros
| customCall
| divide
| dotGeneral
| dynamicBroadcastInDim
| dynamicConv
| dynamicGather
| dynamicIota
| dynamicPad
| dynamicReshape
| dynamicSlice
| dynamicUpdateSlice
| exponential
| exponentialMinusOne
| fft
| floor
| gather
| getDimensionSize
| getTupleElement
| if
| imag
| infeed
| iota
| isFinite
| log
| logPlusOne
| logistic
| map
| maximum
| minimum
| multiply
| negate
| not
| optimizationBarrier
| or
| outfeed
| pad
| partitionId
| popcnt
| power
| real
| realDynamicSlice
| recv
| reduce
| reducePrecision
| reduceScatter
| reduceWindow
| remainder
| replicaId
| reshape
| reverse
| rng
| rngBitGenerator
| roundNearestAfz
| roundNearestEven
| rsqrt
| scatter
| select
| selectAndScatter
| send
| shiftLeft
| shiftRightArithmetic
| shiftRightLogical
| sign
| sine
| slice
| sort
| sqrt
| subtract
| tan
| tanh
| transpose
| triangularSolve
| tuple
| uniformDequantize
| uniformQuantize
| while
| xor
deriving Repr, Inhabited, Nonempty

mutual

inductive InputFunc where
Expand All @@ -304,6 +413,14 @@ mutual

inductive Operation where
| stablehlo
(opCode : OpCode)
(inputValues : List ValueId)
(inputFunctions : List InputFunc)
(inputAttributes : List Attribute)
(outputs : List ValueId)
(signature : FunctionType)
| tanh (operand : ValueId) (typ : FunctionType)
| other
(name : String)
(inputValues : List ValueId)
(inputFunctions : List InputFunc)
Expand Down
6 changes: 6 additions & 0 deletions SHerLOC/Parsing/Intermediate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,12 @@ mutual
if ← isChar '"' then
return Literal.string <| ← parseStringLiteral
if ← isChar 'a' then
report "literal array"
return Literal.array <| ← parseArrayLiteral

if ← isParse "#stablehlo" then {
if (← isParse ".") then {
report "literal record"
if ← isParse "conv" then return Literal.convolution <| ← parseConvolution
if ← isParse "dot_algorithm" then return Literal.stableHLORecord <| ← parseRecord
if ← isParse "dot" then return Literal.stableHLORecord <| ← parseRecord
Expand All @@ -69,12 +71,15 @@ mutual
}

if ← isChar '[' then
report "literal list"
return Literal.list <| ← parseList "[" "]" "," parseLiteral

if ← isChar '{' then
report "literal attribute"
return Literal.dictionary <| ← parseAttributes

if ← isChar '@' then
report "literal function"
return Literal.func <| ← parseFuncId

throw <| (← error "literal")
Expand All @@ -92,6 +97,7 @@ mutual
partial def parseAttribute : PState Attribute := do
push "parseAttribute"
if ← isParse "use_global_device_ids" then
report "literal use_global_device_ids"
pop "parseAttribute"
return Attribute.mk "use_global_device_ids" <| Constant.mk (Literal.element (ElementLiteral.booleanLiteral BooleanLiteral.true)) none
else
Expand Down
Loading