Skip to content

Commit

Permalink
Well-formedness (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtristan authored Sep 19, 2024
1 parent 940d59d commit 9355881
Show file tree
Hide file tree
Showing 6 changed files with 477 additions and 64 deletions.
10 changes: 8 additions & 2 deletions Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ def main (args : List String) : IO UInt32 := do
let content := StableHLO.Parsing.parse content
IO.print s!"Parsing {file}... "
match content with
| .ok _ =>
| .ok p =>
passed := file :: passed
let fpReport : FilePath := System.mkFilePath ["Tests", file ++ ".report"]
for msg in p.2.report do
writeFile fpReport s!"File {file}, {msg}\n"
IO.println "success"
| .error _ =>
failed := file :: failed
Expand All @@ -42,7 +45,10 @@ def main (args : List String) : IO UInt32 := do
let content := StableHLO.Parsing.parse content
match content with
| .ok p =>
IO.println s!"{repr p}"
let fpAST : FilePath := System.mkFilePath ["Tests", file ++ ".ast"]
let fpReport : FilePath := System.mkFilePath ["Tests", file ++ ".report"]
writeFile fpAST s!"{repr p.1}\n"
writeFile fpReport s!"{p.2.report}\n"
return 0
| .error e =>
IO.println s!"{e.2.2}"
Expand Down
117 changes: 117 additions & 0 deletions SHerLOC/AST1.lean
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,115 @@ structure FuncInput where
typ : ValueType
deriving Repr, Inhabited, Nonempty

inductive OpCode where
| abs
| add
| afterAll
| allGather
| allReduce
| allToAll
| and
| atan2
| batchNormGrad
| batchNormInference
| batchNormTraining
| bitcastConvert
| broadcastInDim
| case
| cbrt
| ceil
| cholesky
| clamp
| collectiveBroadcast
| collectivePermute
| compare
| complex
| composite
| concatenate
| constant
| convert
| convolution
| cosine
| countLeadingZeros
| customCall
| divide
| dotGeneral
| dynamicBroadcastInDim
| dynamicConv
| dynamicGather
| dynamicIota
| dynamicPad
| dynamicReshape
| dynamicSlice
| dynamicUpdateSlice
| exponential
| exponentialMinusOne
| fft
| floor
| gather
| getDimensionSize
| getTupleElement
| if
| imag
| infeed
| iota
| isFinite
| log
| logPlusOne
| logistic
| map
| maximum
| minimum
| multiply
| negate
| not
| optimizationBarrier
| or
| outfeed
| pad
| partitionId
| popcnt
| power
| real
| realDynamicSlice
| recv
| reduce
| reducePrecision
| reduceScatter
| reduceWindow
| remainder
| replicaId
| reshape
| reverse
| rng
| rngBitGenerator
| roundNearestAfz
| roundNearestEven
| rsqrt
| scatter
| select
| selectAndScatter
| send
| shiftLeft
| shiftRightArithmetic
| shiftRightLogical
| sign
| sine
| slice
| sort
| sqrt
| subtract
| tan
| tanh
| transpose
| triangularSolve
| tuple
| uniformDequantize
| uniformQuantize
| while
| xor
deriving Repr, Inhabited, Nonempty

mutual

inductive InputFunc where
Expand All @@ -304,6 +413,14 @@ mutual

inductive Operation where
| stablehlo
(opCode : OpCode)
(inputValues : List ValueId)
(inputFunctions : List InputFunc)
(inputAttributes : List Attribute)
(outputs : List ValueId)
(signature : FunctionType)
| tanh (operand : ValueId) (typ : FunctionType)
| other
(name : String)
(inputValues : List ValueId)
(inputFunctions : List InputFunc)
Expand Down
6 changes: 6 additions & 0 deletions SHerLOC/Parsing/Intermediate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,12 @@ mutual
if ← isChar '"' then
return Literal.string <| ← parseStringLiteral
if ← isChar 'a' then
report "literal array"
return Literal.array <| ← parseArrayLiteral

if ← isParse "#stablehlo" then {
if (← isParse ".") then {
report "literal record"
if ← isParse "conv" then return Literal.convolution <| ← parseConvolution
if ← isParse "dot_algorithm" then return Literal.stableHLORecord <| ← parseRecord
if ← isParse "dot" then return Literal.stableHLORecord <| ← parseRecord
Expand All @@ -69,12 +71,15 @@ mutual
}

if ← isChar '[' then
report "literal list"
return Literal.list <| ← parseList "[" "]" "," parseLiteral

if ← isChar '{' then
report "literal attribute"
return Literal.dictionary <| ← parseAttributes

if ← isChar '@' then
report "literal function"
return Literal.func <| ← parseFuncId

throw <| (← error "literal")
Expand All @@ -92,6 +97,7 @@ mutual
partial def parseAttribute : PState Attribute := do
push "parseAttribute"
if ← isParse "use_global_device_ids" then
report "literal use_global_device_ids"
pop "parseAttribute"
return Attribute.mk "use_global_device_ids" <| Constant.mk (Literal.element (ElementLiteral.booleanLiteral BooleanLiteral.true)) none
else
Expand Down
Loading

0 comments on commit 9355881

Please sign in to comment.