diff --git a/SHerLOC/AST1.lean b/SHerLOC/AST1.lean index ea9741a..6cf38ca 100644 --- a/SHerLOC/AST1.lean +++ b/SHerLOC/AST1.lean @@ -145,16 +145,6 @@ inductive ArrayLiteral where | array1 (literal : List BooleanLiteral) deriving Repr, Inhabited, Nonempty -inductive StableHLORecordFieldValue where - | one (literal : Nat) - | many (literal : List Nat) - deriving Repr, Inhabited, Nonempty - -structure StableHLORecordField where - name : String - value : StableHLORecordFieldValue - deriving Repr, Inhabited, Nonempty - inductive ConvolutionMode where | i | o @@ -265,6 +255,17 @@ inductive SType where mutual + inductive StableHLORecordFieldValue where + | one (literal : Nat) + | many (literal : List Nat) + | type (literal : FloatType) + | bool (literal : Bool) + deriving Repr, Inhabited, Nonempty + + inductive StableHLORecordField where + | mk (name : String) (value : StableHLORecordFieldValue) + deriving Repr, Inhabited, Nonempty + inductive Literal where | enum (literal : EnumLiteral) | element (literal : ElementLiteral) @@ -276,7 +277,6 @@ mutual | list (literal : List Literal) | dictionary (literal : List Attribute) | array (literal : ArrayLiteral) - deriving Repr, Inhabited, Nonempty inductive Constant where diff --git a/SHerLOC/Parsing/Intermediate.lean b/SHerLOC/Parsing/Intermediate.lean index f4abc76..c5f1ac8 100644 --- a/SHerLOC/Parsing/Intermediate.lean +++ b/SHerLOC/Parsing/Intermediate.lean @@ -11,6 +11,35 @@ import SHerLOC.Parsing.Constants namespace StableHLO.Parsing +def parseStableHLORecordFieldValue : PState (StableHLORecordFieldValue) := do + if (← is "[") then + let value ← parseDecimals + return StableHLORecordFieldValue.many value + else if (← isDigit) then + let value ← parseDecimal + return StableHLORecordFieldValue.one value + else if (← isParse "true") then + return StableHLORecordFieldValue.bool true + else if (← isParse "false") then + return StableHLORecordFieldValue.bool false + else + let type ← parseFloatType + return StableHLORecordFieldValue.type type + +def parseStableHLORecordField : PState (StableHLORecordField) := do + push "parseStableHLORecordField" + let name ← parseId + parseItem "=" + let value ← parseStableHLORecordFieldValue + pop "parseStableHLORecordField" + return StableHLORecordField.mk name value + +def parseRecord : PState (List StableHLORecordField) := do + push "parseRecord" + let r ← parseList "<" ">" "," parseStableHLORecordField + pop "parseRecord" + return r + mutual partial def parseLiteral : PState Literal := do @@ -19,7 +48,7 @@ mutual return Literal.element <| ElementLiteral.floatLiteral <| ← parseFloatLiteral if ← isChar 'd' then return Literal.tensor <| ← parseTensorLiteral - if (← isChar 't') || (← isChar 'f') then + if (← is "tr") || (← is "fa") then return Literal.element <| ElementLiteral.booleanLiteral <| ← parseBooleanLiteral if (← isChar '(') then return Literal.element <| ElementLiteral.complexLiteral <| ← parseComplexLiteral @@ -48,7 +77,7 @@ mutual if ← isChar '@' then return Literal.func <| ← parseFuncId - throw <| ← error "literal" + throw <| (← error "literal") partial def parseConstant : PState Constant := do push "parseConstant" diff --git a/SHerLOC/Parsing/Numbers.lean b/SHerLOC/Parsing/Numbers.lean index f6847c0..7b892ce 100644 --- a/SHerLOC/Parsing/Numbers.lean +++ b/SHerLOC/Parsing/Numbers.lean @@ -248,26 +248,6 @@ def parseArrayLiteral : PState ArrayLiteral := do return ArrayLiteral.array1 r throw <| ← error "array literal" -def parseStableHLORecordFieldValue : PState (StableHLORecordFieldValue) := do - if (← is "[") then - let value ← parseDecimals - return StableHLORecordFieldValue.many value - else - let value ← parseDecimal - return StableHLORecordFieldValue.one value - -def parseStableHLORecordField : PState (StableHLORecordField) := do - let name ← parseId - parseItem "=" - let value ← parseStableHLORecordFieldValue - return { name, value} - -def parseRecord : PState (List StableHLORecordField) := do - push "parseRecord" - let r ← parseList "<" ">" "," parseStableHLORecordField - pop "parseRecord" - return r - def parseConvolutionMode : PState ConvolutionMode := do push "parseConvolutionMode" let mut r := none diff --git a/SHerLOC/Parsing/Types.lean b/SHerLOC/Parsing/Types.lean index 60c0c9d..b319097 100644 --- a/SHerLOC/Parsing/Types.lean +++ b/SHerLOC/Parsing/Types.lean @@ -50,6 +50,7 @@ def tryParseFloatType : PState (Option FloatType) := do if ← isParse "f8E5M2" then r := some FloatType.f8E5M2 } if ← isParse "bf16" then r := some FloatType.bf16 + if ← isParse "tf32" then r := some FloatType.tf32 pop "tryParseFloatType" return r