diff --git a/Main.lean b/Main.lean index 8108943..3ede04d 100644 --- a/Main.lean +++ b/Main.lean @@ -40,13 +40,13 @@ def main (args : List String) : IO UInt32 := do return 0 else if args.length = 1 then let file := args[0]! - let fp : FilePath := System.mkFilePath ["Tests", file] + let fp : FilePath := System.mkFilePath [file] let content ← readFile fp let content := StableHLO.Parsing.parse content match content with | .ok p => - let fpAST : FilePath := System.mkFilePath ["Tests", file ++ ".ast"] - let fpReport : FilePath := System.mkFilePath ["Tests", file ++ ".report"] + let fpAST : FilePath := System.mkFilePath [file ++ ".ast"] + let fpReport : FilePath := System.mkFilePath [file ++ ".report"] writeFile fpAST s!"{repr p.1}\n" writeFile fpReport s!"{p.2.report}\n" return 0 diff --git a/SHerLOC/Parsing/Constants.lean b/SHerLOC/Parsing/Constants.lean deleted file mode 100644 index 0811996..0000000 --- a/SHerLOC/Parsing/Constants.lean +++ /dev/null @@ -1,12 +0,0 @@ -/- -Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Jean-Baptiste Tristan --/ -import SHerLOC.AST1 -import SHerLOC.Parsing.Parser -import SHerLOC.Parsing.Types - -namespace StableHLO.Parsing - -end StableHLO.Parsing diff --git a/SHerLOC/Parsing/Functions.lean b/SHerLOC/Parsing/Functions.lean index 4e32f7e..2005014 100644 --- a/SHerLOC/Parsing/Functions.lean +++ b/SHerLOC/Parsing/Functions.lean @@ -91,7 +91,6 @@ def parseFunctionDictionaryAttributes : PState (String × FunctionType × (List throw <| ← error "A6" def parseFunction : PState Function := do - push "parseFunction" parseItems ["\"func.func\"", "(", ")"] parseItem "<{" let (name,typ,argAttrs,resAttrs) ← parseFunctionDictionaryAttributes @@ -107,13 +106,9 @@ def parseFunction : PState Function := do parseItem "})" parseItems [":","(",")","->","(",")"] let r : Function := { funcId := name , funcArgAttrs := argAttrs , funcResAttrs := resAttrs , funcType := typ, funcBody := body } - pop "parseFunction" return r def parseFunctions : PState (List Function) := do - push "parseFunctions" - let r ← parseListAuxNoSep "}" parseFunction [] - pop "parseFunctions" - return r + parseListAuxNoSep "}" parseFunction [] end StableHLO.Parsing diff --git a/SHerLOC/Parsing/Identifiers.lean b/SHerLOC/Parsing/Identifiers.lean index 8197243..726cf14 100644 --- a/SHerLOC/Parsing/Identifiers.lean +++ b/SHerLOC/Parsing/Identifiers.lean @@ -9,52 +9,36 @@ import SHerLOC.Parsing.Parser namespace StableHLO.Parsing def parseValueId : PState String := do - push "parseValueId" parseItem "%" - let r ← parseId - pop "parseValueId" - return r + parseId def parseValueIdRes : PState String := do - push "parseValueIdRes" let r ← parseValueId let mut r' := "" if ← isParse ":" then r' ← parseId r' := ":" ++ r' let r := r ++ r' - pop "parseValueIdRes" return r def parseValueIdOpArg : PState String := do - push "parseValueOpArg" let r ← parseValueId let mut r' := "" if ← isParse "#" then r' ← parseId r' := "#" ++ r' let r := r ++ r' - pop "parseValueOpArg" return r def parseFuncId : PState String := do - push "parseFuncId" parseItem "@" - let r ← parseFId - pop "parseFuncId" - return r + parseFId def parseUnusedId : PState String := do - push "parseUnusedId" parseItem "^" - let r ← parseId - pop "parseUnusedId" - return r + parseId def parseAttrId : PState String := do - push "parseAttrId" - let r ← parseId - pop "parseAttrId" - return r + parseId end StableHLO.Parsing diff --git a/SHerLOC/Parsing/Intermediate.lean b/SHerLOC/Parsing/Intermediate.lean index be69df5..2e9f36f 100644 --- a/SHerLOC/Parsing/Intermediate.lean +++ b/SHerLOC/Parsing/Intermediate.lean @@ -7,7 +7,6 @@ import SHerLOC.AST1 import SHerLOC.Parsing.Parser import SHerLOC.Parsing.Identifiers import SHerLOC.Parsing.Types -import SHerLOC.Parsing.Constants namespace StableHLO.Parsing @@ -27,17 +26,13 @@ def parseStableHLORecordFieldValue : PState (StableHLORecordFieldValue) := do return StableHLORecordFieldValue.type type def parseStableHLORecordField : PState (StableHLORecordField) := do - push "parseStableHLORecordField" let name ← parseId parseItem "=" let value ← parseStableHLORecordFieldValue - pop "parseStableHLORecordField" return StableHLORecordField.mk name value def parseRecord : PState (List StableHLORecordField) := do - push "parseRecord" let r ← parseList "<" ">" "," parseStableHLORecordField - pop "parseRecord" return r mutual @@ -85,41 +80,30 @@ mutual throw <| (← error "literal") partial def parseConstant : PState Constant := do - push "parseConstant" let literal ← parseLiteral let mut typ : Option SType := none if ← isParse ":" then typ ← parseType let r : Constant := Constant.mk literal typ - pop "parseConstant" return r partial def parseAttribute : PState Attribute := do - push "parseAttribute" if ← isParse "use_global_device_ids" then report "literal use_global_device_ids" - pop "parseAttribute" return Attribute.mk "use_global_device_ids" <| Constant.mk (Literal.element (ElementLiteral.booleanLiteral BooleanLiteral.true)) none else let id ← parseId parseItem "=" let constant ← parseConstant - pop "parseAttribute" return Attribute.mk id constant partial def parseAttributes : PState (List Attribute) := do - push "parseAttributes" - let r ← parseList "{" "}" "," parseAttribute - pop "parseAttributes" - return r + parseList "{" "}" "," parseAttribute end def parseValueUseList : PState (List ValueId) := do - push "parseValueUseList" - let r ← parseList "(" ")" "," parseValueIdOpArg - pop "parseValueUseList" - return r + parseList "(" ")" "," parseValueIdOpArg def tryParseDictionaryEntry (name : String) (parser : PState T) : PState (Option T) := do if ← is name then @@ -130,9 +114,6 @@ def tryParseDictionaryEntry (name : String) (parser : PState T) : PState (Option else return none def parseDictionaryProperties : PState (List Attribute) := do - push "parseDictionaryProperties" - let r ← parseList "<{" "}>" "," parseAttribute - pop "parseDictionaryProperties" - return r + parseList "<{" "}>" "," parseAttribute end StableHLO.Parsing diff --git a/SHerLOC/Parsing/Modules.lean b/SHerLOC/Parsing/Modules.lean index 2f94ba6..35741eb 100644 --- a/SHerLOC/Parsing/Modules.lean +++ b/SHerLOC/Parsing/Modules.lean @@ -12,7 +12,6 @@ import SHerLOC.Parsing.Intermediate namespace StableHLO.Parsing def parseModule : PState Module := do - push "parseModule" parseItems ["\"builtin.module\"", "(", ")"] let mut name : Option FuncId := none if ← is "<{" then @@ -27,7 +26,6 @@ def parseModule : PState Module := do let r : Module := { modId := name, modAttrs := [], modFuncs := [] } parseItems ["}",")"] parseItems [":","(",")","->","(",")"] - pop "parseModule" return r let region ← parseFunctions parseItems ["}",")"] @@ -36,23 +34,18 @@ def parseModule : PState Module := do attributes ← parseAttributes parseItems [":","(",")","->","(",")"] let r : Module := { modId := name, modAttrs := attributes, modFuncs := region } - pop "parseModule" return r else let r : Module := { modId := name, modAttrs := [], modFuncs := [] } - pop "parseModule" return r partial def parseModules : PState (List Module) := do - push "parseModules" let done ← done? if done then - pop "parseModules" return [] else let mod ← parseModule let mods ← parseModules - pop "parseModules" return mod :: mods diff --git a/SHerLOC/Parsing/Numbers.lean b/SHerLOC/Parsing/Numbers.lean index 7b892ce..8464230 100644 --- a/SHerLOC/Parsing/Numbers.lean +++ b/SHerLOC/Parsing/Numbers.lean @@ -15,7 +15,6 @@ def parseBooleanLiteral : PState BooleanLiteral := do throw <| ← error "Boolean literal" def parseIntegerLiteral : PState IntegerLiteral := do - push "parseIntegerLiteral" let mut sign := Sign.plus if ← isParse "+" then sign := Sign.plus else if ← isParse "-" then sign := Sign.minus @@ -26,19 +25,16 @@ def parseIntegerLiteral : PState IntegerLiteral := do nat ← parseDecimal if let some v := nat then let parseResult := { sign := sign , decimal := v } - pop "parseIntegerLiteral" return parseResult else throw <| ← error "Integer literal" def parseFloatLiteral : PState FloatLiteral := do - push "parseFloatLiteral" let mut sign := Sign.plus if ← isParse "+" then sign := Sign.plus else if ← isParse "-" then sign := Sign.minus if ← is "0x" then let nat ← parseHexaDecimal - pop "parseFloatLiteral" return FloatLiteral.hexaDecimal nat else let nat ← parseDecimal @@ -58,18 +54,15 @@ def parseFloatLiteral : PState FloatLiteral := do fractionalPart := fractionalPart, scientificPart := scientificPart } - pop "parseFloatLiteral" return parseResult def parseComplexLiteral : PState ComplexLiteral := do - push "parseComplexLiteral" parseItem "(" let realPart ← parseFloatLiteral parseItem "," let imaginaryPart ← parseFloatLiteral parseItem ")" let parseResult := { real := realPart, imaginary := imaginaryPart } - pop "parseComplexLiteral" return parseResult def parseElementLiteral : PState ElementLiteral := do @@ -85,46 +78,33 @@ def parseElementLiteral : PState ElementLiteral := do throw <| ← error "Element literal" def parseDenseElements (closingMark : String) : PState (List ElementLiteral) := do - push "parseDenseElements" - let r ← parseListAux closingMark "," parseElementLiteral - pop "parseDenseElements" - return r + parseListAux closingMark "," parseElementLiteral partial def parseDenseLiteral : PState DenseLiteral := do - push "parseDenseLiteral" if ← is "[" then let denseDimension ← parseList "[" "]" "," parseDenseLiteral - pop "parseDenseLiteral" return DenseLiteral.denseDimension denseDimension else let denseElements ← parseDenseElements "]" - pop "parseDenseLiteral" return DenseLiteral.denseElements denseElements def parseTensorLiteral : PState TensorLiteral := do - push "parseTensorLiteral" parseItem "dense" parseItem "<" if ← is "[" then let denseLiteral ← parseDenseLiteral parseItem ">" - pop "parseTensorLiteral" return denseLiteral else let denseElements ← parseDenseElements ">" let denseLiteral := DenseLiteral.denseElements denseElements parseItem ">" - pop "parseTensorLiteral" return denseLiteral def parseStringLiteral : PState String := do - push "parseStringLiteral" - let r ← parseString - pop "parseStringLiteral" - return r + parseString def parseComparisonDirection : PState ComparisonDirection := do - push "parseComparisonDirection" let mut r := none if ← isParse "EQ" then r := ComparisonDirection.eq if ← isParse "NE" then r := ComparisonDirection.ne @@ -133,89 +113,73 @@ def parseComparisonDirection : PState ComparisonDirection := do if ← isParse "LE" then r := ComparisonDirection.le if ← isParse "LT" then r := ComparisonDirection.lt if let some res := r then - pop "parseComparisonDirection" return res else throw <| ← error "comparison direction" def parseCompareType : PState CompareType := do - push "parseCompareType" let mut r := none if ← isParse "FLOAT" then r := CompareType.float if ← isParse "TOTALORDER" then r := CompareType.totalOrder if ← isParse "SIGNED" then r := CompareType.signed if ← isParse "UNSIGNED" then r := CompareType.unsigned if let some res := r then - pop "parseCompareType" return res else throw <| ← error "compaare type" def parsePrecisionConfig : PState PrecisionConfig := do - push "parsePrecisionConfig" let mut r := none if ← isParse "DEFAULT" then r := PrecisionConfig.default if ← isParse "HIGHEST" then r := PrecisionConfig.highest if ← isParse "HIGH" then r := PrecisionConfig.high if let some res := r then - pop "parsePrecisionConfig" return res else throw <| ← error "precision config" def parseFftType : PState FftType := do - push "parseFftType" let mut r := none if ← isParse "FFT" then r := FftType.fft if ← isParse "IFFT" then r := FftType.ifft if ← isParse "RFFT" then r := FftType.rfft if ← isParse "IRFFT" then r := FftType.irfft if let some res := r then - pop "parseFftType" return res else throw <| ← error "FFT type" def parseChannelType : PState ChannelType := do - push "parseChannelType" let mut r := none if ← isParse "DEVICE_TO_DEVICE" then r := ChannelType.deviceToDevice if ← isParse "HOST_TO_DEVICE" then r := ChannelType.hostToDevice if let some res := r then - pop "parseChannelType" return res else throw <| ← error "channel type" def parseRngDistribution : PState RngDistribution := do - push "parseRngDistribution" let mut r := none if ← isParse "UNIFORM" then r := RngDistribution.uniform if ← isParse "NORMAL" then r := RngDistribution.normal if let some res := r then - pop "parseRngDistribution" return res else throw <| ← error "rng distribution" def parseRngAlgorithm : PState RngAlgorithm := do - push "parseRngAlgorithm" let mut r := none if ← isParse "DEFAULT" then r := RngAlgorithm.default if ← isParse "THREE_FRY" then r := RngAlgorithm.threeFry if ← isParse "PHILOX" then r := RngAlgorithm.philox if let some res := r then - pop "parseRngAlgorithm" return res else throw <| ← error "rng algorithm" def parseTransposeA : PState TransposeA := do - push "parseTransposeA" let mut r := none if ← isParse "NO_TRANSPOSE" then r := TransposeA.noTranspose if ← isParse "TRANSPOSE" then r := TransposeA.transpose if ← isParse "ADJOINT" then r := TransposeA.adjoint if let some res := r then - pop "parseTransposeA" return res else throw <| ← error "tranpose annotation" def parseEnumLiteral : PState EnumLiteral := do - push "parseEnumLiteral" parseItem "<" let mut r := none if ← isParse "comparison_direction" then r := EnumLiteral.comparisonDirection <| ← parseComparisonDirection @@ -228,7 +192,6 @@ def parseEnumLiteral : PState EnumLiteral := do if ← isParse "transpose" then r := EnumLiteral.transposeA <| ← parseTransposeA if let some res := r then parseItem ">" - pop "parseEnumLiteral" return res else throw <| ← error "enumeration" @@ -249,7 +212,6 @@ def parseArrayLiteral : PState ArrayLiteral := do throw <| ← error "array literal" def parseConvolutionMode : PState ConvolutionMode := do - push "parseConvolutionMode" let mut r := none if (← isParse "o") then r := ConvolutionMode.o else if (← isParse "f") then r := ConvolutionMode.f @@ -258,18 +220,13 @@ def parseConvolutionMode : PState ConvolutionMode := do else if (← isParse "1") then r := ConvolutionMode.one else if (← isParse "b") then r := ConvolutionMode.b else if (← isParse "2") then r := ConvolutionMode.two - pop "parseConvolutionMode" if let some res := r then return res else throw <| ← error "convolution mode" def parseConvolutionModes : PState (List ConvolutionMode) := do - push "parseConvolutionModes" - let r ← parseList "[" "]" "," parseConvolutionMode - pop "parseConvolutionModes" - return r + parseList "[" "]" "," parseConvolutionMode def parseConvolution : PState Convolution := do - push "parseConvolution" parseItem "<" let lhs ← parseConvolutionModes parseItem "x" @@ -277,7 +234,6 @@ def parseConvolution : PState Convolution := do parseItem "->" let result ← parseConvolutionModes parseItem ">" - pop "parseConvolution" return { lhs, rhs, result } end StableHLO.Parsing diff --git a/SHerLOC/Parsing/Operations.lean b/SHerLOC/Parsing/Operations.lean index c48e5c9..ee51d7c 100644 --- a/SHerLOC/Parsing/Operations.lean +++ b/SHerLOC/Parsing/Operations.lean @@ -5,43 +5,33 @@ Authors: Jean-Baptiste Tristan -/ import SHerLOC.AST1 import SHerLOC.Parsing.Parser -import SHerLOC.Parsing.Constants import SHerLOC.Parsing.Identifiers import SHerLOC.Parsing.Intermediate namespace StableHLO.Parsing def parseOpOutputs : PState (List ValueId) := do - push "parseOpOutputs" let r ← parseListAux "=" "," parseValueIdRes - pop "parseOpOutputs" return r def parseInputFuncInput : PState FuncInput := do - push "parseInputFuncInput" let id ← parseValueId parseItem ":" let typ ← parseValueType - pop "parseInputFuncInput" return { id := id , typ := typ } def parseInputFuncInputs : PState (List FuncInput) := do - push "parseInputFuncInputs" let r ← parseList "(" ")" "," parseInputFuncInput - pop "parseInputFuncInputs" return r def parseReturn : PState Operation := do - push "parseReturn" let arguments ← parseValueUseList parseItem ":" let functiontype ← parseFunctionType let parseResult := Operation.return arguments functiontype - pop "parseReturn" return parseResult def parseCall (outputs : List ValueId) : PState Operation := do - push "parseCall" parseItem "\"func.call\"" let arguments ← parseValueUseList parseItem "<{" @@ -52,7 +42,6 @@ def parseCall (outputs : List ValueId) : PState Operation := do parseItem ":" let typ ← parseFunctionType let r := Operation.call callee arguments outputs typ - pop "parseCall" return r def parseOpCode : PState OpCode := do @@ -175,7 +164,6 @@ def parseOpCode : PState OpCode := do mutual partial def parseInputFunc : PState InputFunc := do - push "parseInputFunc" parseItem "{" let mut funcInputs : List FuncInput := [] if ← is "^" then @@ -184,23 +172,17 @@ mutual parseItem ":" let body ← parseInputFuncBody parseItem "}" - pop "parseInputFunc" return InputFunc.mk funcInputs body partial def parseOpInputFuncs : PState (List InputFunc) := do - push "parseOpInputFuncs" let r ← parseList "(" ")" "," parseInputFunc - pop "parseOpInputFuncs" return r partial def parseOperationDictionaryAttributes : PState (List Attribute) := do - push "parseOperationDictionaryAttributes" let r ← parseList "<{" "}>" "," parseAttribute - pop "parseOperationDictionaryAttributes" return r partial def parseOperationBasic (op : OpCode) (opOutputs : List ValueId) : PState Operation := do - push "parseOperationBasic" let opInputValues ← parseValueUseList let mut opInputAttrs := [] if ← is "<{" then @@ -211,11 +193,9 @@ mutual parseItem ":" let functiontype ← parseFunctionType let operation := Operation.stablehlo op opInputValues opInputFuncs opInputAttrs opOutputs functiontype - pop "parseOperationBasic" return operation partial def parseOtherDialect (opOutputs : List ValueId) : PState Operation := do - push "parseOtherDialect" let name ← parseString report s!"undocumented operation: {name}" let opInputValues ← parseValueUseList @@ -228,7 +208,6 @@ mutual parseItem ":" let functiontype ← parseFunctionType let operation := Operation.other name opInputValues opInputFuncs opInputAttrs opOutputs functiontype - pop "parseOtherDialect" return operation partial def parseStableHLO (opOutputs : List ValueId) : PState Operation := do @@ -357,14 +336,11 @@ partial def parseStableHLO (opOutputs : List ValueId) : PState Operation := do | OpCode.xor => parseOperationBasic OpCode.xor opOutputs partial def parseOperation : PState Operation := do - push "parseOperation" if ← isParse "\"func.return\"" then let r ← parseReturn - pop "parseOperation" return r if ← isParse "\"stablehlo.return\"" then let r ← parseReturn - pop "parseOperation" return r let mut opOutputs := [] if ← is "%" then @@ -372,33 +348,26 @@ partial def parseStableHLO (opOutputs : List ValueId) : PState Operation := do parseItem "=" if ← is "\"func.call\"" then let r ← parseCall opOutputs - pop "parseOperation" return r if ← is "\"check." then let r ← parseOtherDialect opOutputs - pop "parseOperation" return r if ← is "\"interpreter." then let r ← parseOtherDialect opOutputs - pop "parseOperation" return r if ← is "\"chlo." then let r ← parseOtherDialect opOutputs - pop "parseOperation" return r let operation ← parseStableHLO opOutputs - pop "parseOperation" return operation partial def parseInputFuncBody : PState (List Operation) := do - push "parseInputFuncBody" let r ← parseListAuxNoSep "}" parseOperation [] - pop "parseInputFuncBody" return r end diff --git a/SHerLOC/Parsing/Parser.lean b/SHerLOC/Parsing/Parser.lean index 6b4cdf4..49cb186 100644 --- a/SHerLOC/Parsing/Parser.lean +++ b/SHerLOC/Parsing/Parser.lean @@ -263,32 +263,6 @@ def parseString : PState String := do parseItem "\"" return token -def push (parser : String) : PState Unit := do - let st ← get - let traceItem : Trace := { startLine := st.lineNumber, startColumn := st.columnNumber, parser } - set { st with trace := traceItem :: st.trace } - -def indent (n : Nat) : String := Id.run do - let mut token := "" - for _ in [:n] do - token := token.push ' ' - return token - -def pop (parser : String) : PState Unit := do - let st ← get - if let some tail := st.trace.tail? then - let head := st.trace.head! - if head.parser = parser then - let derivation : Derivation := { - startLine := head.startLine, - startColumn := head.startColumn, - endLine := st.lineNumber, - endColumn := st.columnNumber, - parser := (indent tail.length) ++ parser } - set {st with trace := tail, derivations := derivation :: st.derivations } - else panic! s!"Trace mismatch: expected {parser} but found {head}" - else panic! "More pops than pushes, some parser is missing its push" - partial def parseListOneorMoreAux (separator : String) (parse : PState T) (acc : List T) : PState (List T) := do if ← isParse separator then parseListOneorMoreAux separator parse ((← parse) :: acc) diff --git a/SHerLOC/Parsing/Types.lean b/SHerLOC/Parsing/Types.lean index b319097..2406a35 100644 --- a/SHerLOC/Parsing/Types.lean +++ b/SHerLOC/Parsing/Types.lean @@ -10,7 +10,6 @@ import SHerLOC.Parsing.Numbers namespace StableHLO.Parsing def tryParseIntegerType : PState (Option IntegerType) := do - push "tryParseIntegerType" let mut r : Option IntegerType := none if ← isChar 'i' then { if ← isParse "i32" then r := some { sign := Signedness.signed , size := IntegerSize.b32 } @@ -26,16 +25,13 @@ def tryParseIntegerType : PState (Option IntegerType) := do if ← isParse "ui4" then r := some { sign := Signedness.unsigned , size := IntegerSize.b4 } if ← isParse "ui8" then r := some { sign := Signedness.unsigned , size := IntegerSize.b8 } if ← isParse "ui16" then r := some { sign := Signedness.unsigned , size := IntegerSize.b16 } - pop "tryParseIntegerType" return r def parseIntegerType : PState IntegerType := do - push "parseIntegerType" - if let some r ← tryParseIntegerType then pop "parseIntegerType" ; return r + if let some r ← tryParseIntegerType then return r else throw <| ← error "Integer type" def tryParseFloatType : PState (Option FloatType) := do - push "tryParseFloatType" let mut r : Option FloatType := none if ← isChar 'f' then { if ← isParse "f16" then r := some FloatType.f16 @@ -51,66 +47,53 @@ def tryParseFloatType : PState (Option FloatType) := do } if ← isParse "bf16" then r := some FloatType.bf16 if ← isParse "tf32" then r := some FloatType.tf32 - pop "tryParseFloatType" return r def parseFloatType : PState FloatType := do - push "parseFloatType" - if let some r ← tryParseFloatType then pop "parseFloatType"; return r + if let some r ← tryParseFloatType then return r else throw <| ← error "Float type" def parseNumberType : PState NumberType := do - push "parseNumberType" - if let some r ← tryParseIntegerType then pop "parseNumberType"; return NumberType.integerType r - else if let some r ← tryParseFloatType then pop "parseNumberType"; return NumberType.floatType r + if let some r ← tryParseIntegerType then return NumberType.integerType r + else if let some r ← tryParseFloatType then return NumberType.floatType r else throw <| ← error "Number type" def parseComplexElementType : PState ComplexType := do - push "parseComplexElementType" - if ← isParse "f32" then pop "parseComplexElementType"; return ComplexType.f32 - else if ← isParse "f64" then pop "parseComplexElementType"; return ComplexType.f64 + if ← isParse "f32" then return ComplexType.f32 + else if ← isParse "f64" then return ComplexType.f64 else throw <| ← error "Complex element type" def parseComplexType : PState ComplexType := do - push "parseComplexType" parseItem "complex" parseItem "<" let t ← parseComplexElementType parseItem ">" - pop "parseComplexType" return t def tryParseDimensionSize : PState (Option DimensionSize) := do - push "parseDimensionSize" let mut r := none if (← isDigit) then r := some <| DimensionSize.known <| ← parseDecimal if (← isParse "?") then r := some <| DimensionSize.unknown - pop "parseDimensionSize" return r partial def parseShape : PState (List DimensionSize) := do - push "parseShape" if let some dim ← tryParseDimensionSize then parseItem "x" let dims ← parseShape - pop "parseShape" return dim :: dims else - pop "parseShape" return [] def parseTensorElementType : PState TensorElementType := do - push "parseTensorElementType" - if let some r ← tryParseIntegerType then pop "parseTensorElementType"; return TensorElementType.integerType r - if ← isParse "i1" then pop "parseTensorElementType"; return TensorElementType.booleanType - if ← is "complex" then pop "parseTensorElementType"; return TensorElementType.complexType <| ← parseComplexType - if let some r ← tryParseFloatType then pop "parseTensorElementType"; return TensorElementType.floatType r + if let some r ← tryParseIntegerType then return TensorElementType.integerType r + if ← isParse "i1" then return TensorElementType.booleanType + if ← is "complex" then return TensorElementType.complexType <| ← parseComplexType + if let some r ← tryParseFloatType then return TensorElementType.floatType r throw <| ← error "TensorElementType" def parseQuantizationParameter : PState QuantizationParameter := do - push "parseQuantizationParameter" let quantizationScale ← parseFloatLiteral let mut quantizationZeroPoint := { sign := Sign.plus , decimal := 0 } if (← isParse ":") then @@ -119,22 +102,17 @@ def parseQuantizationParameter : PState QuantizationParameter := do { quantizationScale := quantizationScale, quantizationZeroPoint := quantizationZeroPoint } - pop "parseQuantizationParameter" return parseResult def parseQuantizationParameters : PState (List QuantizationParameter) := do - push "parseQuantizationParameters" if ← is "{" then let quantizationParameters ← parseList "{" "}" "," parseQuantizationParameter - pop "parseQuantizationParameters" return quantizationParameters else let quantizationParameter ← parseQuantizationParameter - pop "parseQuantizationParameters" return [quantizationParameter] def parseQuantizedTensorElementType : PState QuantizedTensorElementType := do - push "parseQuantizedTensorElementType" parseItem "!quant.uniform" parseItem "<" let quantizationStorageType ← parseIntegerType @@ -163,92 +141,70 @@ def parseQuantizedTensorElementType : PState QuantizedTensorElementType := do { quantizationBasics := quantizationBasics quantizationParameters := quantizationParameters } - pop "parseQuantizedTensorElementType" return parseResult def parseTensorElementTypeGen : PState TensorElementTypeGen := do - push "parseTensorElementTypeGen" if ← is "!quant.uniform" then let quantizedTensorElementType ← parseQuantizedTensorElementType - pop "parseTensorElementTypeGen" return TensorElementTypeGen.quantized quantizedTensorElementType else let tensorElementType ← parseTensorElementType - pop "parseTensorElementTypeGen" return TensorElementTypeGen.classic tensorElementType def parseTensorType : PState TensorType := do - push "parseTensorType" parseItem "tensor" parseItem "<" let shape ← parseShape let tensorElementTypeGen ← parseTensorElementTypeGen parseItem ">" - pop "parseTensorType" return { shape := shape, tensorElementTypeGen := tensorElementTypeGen } def parseTokenType : PState ValueType := do - push "parseTokenType" parseItem "!stablehlo.token" - pop "parseTokenType" return ValueType.tokenType mutual partial def parseTupleType : PState ValueType := do - push "parseTupleType" parseItem "tuple" let TupleElementTypes ← parseList "<" ">" "," parseValueType - pop "parseTupleType" return ValueType.tupleType TupleElementTypes partial def parseValueType : PState ValueType := do - push "parseValueType" - if ← is "tensor" then pop "parseValueType"; return ValueType.tensorType <| ← parseTensorType + if ← is "tensor" then return ValueType.tensorType <| ← parseTensorType else if ← is "tuple" then let r ← parseTupleType - pop "parseValueType" return r else if ← is "!stablehlo.token" then let r ← parseTokenType - pop "parseValueType" return r else throw <| ← error "Value Type" end --- Temporary? Mulitple results? def parseValueTypesOutput : PState (List ValueType) := do - push "parseValueTypesOutput" let mut valueTypes : List ValueType := [] if ← is "(" then valueTypes ← parseList "(" ")" "," parseValueType else let r ← parseValueType valueTypes := [r] - pop "parseValueTypesOutput" return valueTypes def parseValueTypes : PState (List ValueType) := do - push "parseValueTypes" - let r ← parseList "(" ")" "," parseValueType - pop "parseValueTypes" - return r + parseList "(" ")" "," parseValueType + def parseFunctionType : PState FunctionType := do - push "parseFunctionTypeLong" let inputTypes ← parseValueTypes parseItem "-" parseItem ">" let outputType ← parseValueTypesOutput - pop "parseFunctionTypeLong" return { domain := inputTypes, range := outputType } def parseStringType : PState NonValueType := do - push "parseStringType" parseItem "string" - pop "parseStringType" return NonValueType.stringType def parseType : PState SType := do