Skip to content

Commit

Permalink
improved parsing of dot_general
Browse files Browse the repository at this point in the history
  • Loading branch information
jtristan committed Sep 13, 2024
1 parent 5d66573 commit 36da408
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 33 deletions.
22 changes: 11 additions & 11 deletions SHerLOC/AST1.lean
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,6 @@ inductive ArrayLiteral where
| array1 (literal : List BooleanLiteral)
deriving Repr, Inhabited, Nonempty

inductive StableHLORecordFieldValue where
| one (literal : Nat)
| many (literal : List Nat)
deriving Repr, Inhabited, Nonempty

structure StableHLORecordField where
name : String
value : StableHLORecordFieldValue
deriving Repr, Inhabited, Nonempty

inductive ConvolutionMode where
| i
| o
Expand Down Expand Up @@ -265,6 +255,17 @@ inductive SType where

mutual

inductive StableHLORecordFieldValue where
| one (literal : Nat)
| many (literal : List Nat)
| type (literal : FloatType)
| bool (literal : Bool)
deriving Repr, Inhabited, Nonempty

inductive StableHLORecordField where
| mk (name : String) (value : StableHLORecordFieldValue)
deriving Repr, Inhabited, Nonempty

inductive Literal where
| enum (literal : EnumLiteral)
| element (literal : ElementLiteral)
Expand All @@ -276,7 +277,6 @@ mutual
| list (literal : List Literal)
| dictionary (literal : List Attribute)
| array (literal : ArrayLiteral)

deriving Repr, Inhabited, Nonempty

inductive Constant where
Expand Down
33 changes: 31 additions & 2 deletions SHerLOC/Parsing/Intermediate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,35 @@ import SHerLOC.Parsing.Constants

namespace StableHLO.Parsing

def parseStableHLORecordFieldValue : PState (StableHLORecordFieldValue) := do
if (← is "[") then
let value ← parseDecimals
return StableHLORecordFieldValue.many value
else if (← isDigit) then
let value ← parseDecimal
return StableHLORecordFieldValue.one value
else if (← isParse "true") then
return StableHLORecordFieldValue.bool true
else if (← isParse "false") then
return StableHLORecordFieldValue.bool false
else
let type ← parseFloatType
return StableHLORecordFieldValue.type type

def parseStableHLORecordField : PState (StableHLORecordField) := do
push "parseStableHLORecordField"
let name ← parseId
parseItem "="
let value ← parseStableHLORecordFieldValue
pop "parseStableHLORecordField"
return StableHLORecordField.mk name value

def parseRecord : PState (List StableHLORecordField) := do
push "parseRecord"
let r ← parseList "<" ">" "," parseStableHLORecordField
pop "parseRecord"
return r

mutual

partial def parseLiteral : PState Literal := do
Expand All @@ -19,7 +48,7 @@ mutual
return Literal.element <| ElementLiteral.floatLiteral <| ← parseFloatLiteral
if ← isChar 'd' then
return Literal.tensor <| ← parseTensorLiteral
if (← isChar 't') || (← isChar 'f') then
if (← is "tr") || (← is "fa") then
return Literal.element <| ElementLiteral.booleanLiteral <| ← parseBooleanLiteral
if (← isChar '(') then
return Literal.element <| ElementLiteral.complexLiteral <| ← parseComplexLiteral
Expand Down Expand Up @@ -48,7 +77,7 @@ mutual
if ← isChar '@' then
return Literal.func <| ← parseFuncId

throw <| ← error "literal"
throw <| (← error "literal")

partial def parseConstant : PState Constant := do
push "parseConstant"
Expand Down
20 changes: 0 additions & 20 deletions SHerLOC/Parsing/Numbers.lean
Original file line number Diff line number Diff line change
Expand Up @@ -248,26 +248,6 @@ def parseArrayLiteral : PState ArrayLiteral := do
return ArrayLiteral.array1 r
throw <| ← error "array literal"

def parseStableHLORecordFieldValue : PState (StableHLORecordFieldValue) := do
if (← is "[") then
let value ← parseDecimals
return StableHLORecordFieldValue.many value
else
let value ← parseDecimal
return StableHLORecordFieldValue.one value

def parseStableHLORecordField : PState (StableHLORecordField) := do
let name ← parseId
parseItem "="
let value ← parseStableHLORecordFieldValue
return { name, value}

def parseRecord : PState (List StableHLORecordField) := do
push "parseRecord"
let r ← parseList "<" ">" "," parseStableHLORecordField
pop "parseRecord"
return r

def parseConvolutionMode : PState ConvolutionMode := do
push "parseConvolutionMode"
let mut r := none
Expand Down
1 change: 1 addition & 0 deletions SHerLOC/Parsing/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def tryParseFloatType : PState (Option FloatType) := do
if ← isParse "f8E5M2" then r := some FloatType.f8E5M2
}
if ← isParse "bf16" then r := some FloatType.bf16
if ← isParse "tf32" then r := some FloatType.tf32
pop "tryParseFloatType"
return r

Expand Down

0 comments on commit 36da408

Please sign in to comment.