diff --git a/app/Main.hs b/app/Main.hs index 4491172..ba14c7f 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -41,6 +41,5 @@ runPeter sourceCode = do case result of Left err -> putStrLn $ "Parse error: " ++ show err Right program -> do - -- putStrLn "Parsed program:" -- print program interpret program diff --git a/examples/structs.mmm b/examples/structs.mmm new file mode 100644 index 0000000..ae22f1a --- /dev/null +++ b/examples/structs.mmm @@ -0,0 +1,16 @@ +struct T { + int x; + int y; +} + +T n; +n.x = 2; +println(str(n.x)) + +int i = 3; + +T n1; +n1.y = i; + +println(str(n1.x)) +println(str(n1.y)) diff --git a/peter.cabal b/peter.cabal index bab14ab..af03a49 100644 --- a/peter.cabal +++ b/peter.cabal @@ -44,6 +44,7 @@ library Parser.Program Parser.Space Parser.Statement + Parser.Struct Parser.Type Parser.Variable other-modules: diff --git a/src/AST.hs b/src/AST.hs index d9e36fa..f225477 100644 --- a/src/AST.hs +++ b/src/AST.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE GADTs #-} + module AST (module AST) where type Name = String @@ -5,7 +7,7 @@ type Name = String data Operator = Plus | Minus | Multiply | Divide | Modulus | And | Or | Not | Eq | Neq | Lt | Gt | Le | Ge deriving (Show, Eq) -data Literal = IntLiteral Int | FloatLiteral Float | BoolLiteral Bool | UnitLiteral | StringLiteral String +data Literal = IntLiteral Int | FloatLiteral Float | BoolLiteral Bool | UnitLiteral | StringLiteral String | UndefinedLiteral deriving (Show, Eq) data Atomic = LiteralAtomic Literal | VariableAtomic Name | FunctionCallAtomic Name [Expression] @@ -25,7 +27,7 @@ data Assignment = Assignment Name Expression type Comment = String -data Type = IntType | FloatType | BoolType | UnitType | CustomType Name | StringType +data Type = IntType | FloatType | BoolType | UnitType | CustomType Name | StringType | UndefinedType deriving (Show, Eq) data Control @@ -33,19 +35,25 @@ data Control | WhileControl Expression [Statement] deriving (Show, Eq) +data Struct = Struct Name [VariableDeclaration] + deriving (Show, Eq) + data Statement - = VariableStatement Variable + = VariableDefinitionStatement Variable | AssignmentStatement Assignment | FunctionDefinitionStatement Function | ExpressionStatement Expression | ReturnStatement Expression | ControlStatement Control + | StructStatement Struct + | VariableDeclarationStatement VariableDeclaration deriving (Show, Eq) data Function = Function Name [VariableDeclaration] Type [Statement] deriving (Show, Eq) -data Program = Program [Statement] +data Program where + Program :: [Statement] -> Program deriving (Show, Eq) data BuiltInFuction = Print | Input diff --git a/src/Interpreter/Interpreter.hs b/src/Interpreter/Interpreter.hs index 1aeada3..21db03a 100644 --- a/src/Interpreter/Interpreter.hs +++ b/src/Interpreter/Interpreter.hs @@ -1,7 +1,7 @@ module Interpreter.Interpreter (module Interpreter.Interpreter) where import AST -import Control.Monad (foldM) +import Control.Monad (foldM, foldM_) import qualified Data.Functor import Data.Map.Strict as Map import Interpreter.BuiltIn @@ -16,25 +16,30 @@ interpret (Program statements) = do isValid <- validate (Program statements) if isValid then do - -- putStrLn "Valid program" let correctedStatments = ensureEntryPoint statements - let functionMap = getFunctionMap correctedStatments - -- print correctedStatments - _ <- foldM interpretStatement (InterpretState (ProgramState empty functionMap) Nothing) correctedStatments - -- putStrLn $ "End state: " ++ show endState - return () - else do - putStrLn "Invalid program" + functionMap = getFunctionMap correctedStatments + customTypeMap = getCustomTypeMap correctedStatments + foldM_ interpretStatement (InterpretState (ProgramState empty functionMap customTypeMap) Nothing) correctedStatments + else error "Invalid program" interpretStatement :: InterpretState -> Statement -> IO InterpretState -interpretStatement (InterpretState state _) (VariableStatement (Variable (VariableDeclaration name _) expression)) = do +interpretStatement (InterpretState state _) (VariableDefinitionStatement (Variable (VariableDeclaration name t) expression)) = do (ScopeResult innerVars ret) <- interpretExpression state expression let newState = updateOuterStateV state innerVars - return (InterpretState (updateState newState name ret) Nothing) + newState' = addStructMembersToState newState name t + return (InterpretState (updateState newState' name ret) Nothing) +interpretStatement (InterpretState state _) (VariableDeclarationStatement (VariableDeclaration name t)) = do + let newState = updateState state name (Just UndefinedValue) + newState' = addStructMembersToState newState name t + return (InterpretState newState' Nothing) interpretStatement (InterpretState state _) (AssignmentStatement (Assignment name expression)) = do - (ScopeResult innerVars ret) <- interpretExpression state expression - let newState = updateOuterStateV state innerVars - return (InterpretState (updateState newState name ret) Nothing) + case Map.lookup name (variables state) of + Just UndefinedValue -> error "Can't copy structs" -- TODO: deep copy structs + _ -> do + (ScopeResult innerVars ret) <- interpretExpression state expression + let newState = updateOuterStateV state innerVars + -- TODO: deep copy structs + return (InterpretState (updateState newState name ret) Nothing) interpretStatement (InterpretState state _) (ExpressionStatement expression) = do _ <- interpretExpression state expression return (InterpretState state Nothing) @@ -46,23 +51,8 @@ interpretStatement (InterpretState state _) (ReturnStatement expression) = do return (InterpretState newState ret) interpretStatement (InterpretState state _) (ControlStatement control) = do interpretControl state control - -updateState :: ProgramState -> Name -> Maybe Value -> ProgramState -updateState (ProgramState vars funs) name value = do - case value of - Just v -> ProgramState (Map.insert name v vars) funs - Nothing -> ProgramState vars funs - --- Update variable in outer scope -updateOuterState :: ProgramState -> ProgramState -> ProgramState -updateOuterState (ProgramState outerVars funs) (ProgramState innerVars _) = - ProgramState (Map.unionWithKey (\_ inner _outer -> inner) innerVars outerVars) funs - -updateOuterStateV :: ProgramState -> Map Name Value -> ProgramState -updateOuterStateV (ProgramState outerVars funs) innerVars = - ProgramState - (Map.unionWithKey (\_ inner _outer -> inner) innerVars outerVars) - funs +interpretStatement (InterpretState state _) (StructStatement _) = do + return (InterpretState state Nothing) interpretExpression :: ProgramState -> Expression -> IO ScopeResult interpretExpression state (AtomicExpression atomic) = do @@ -74,15 +64,15 @@ interpretExpression state (OperationExpression left operator right) = do return (ScopeResult (variables state) (Just value)) interpretAtomic :: ProgramState -> Atomic -> IO ScopeResult -interpretAtomic (ProgramState vars _) (LiteralAtomic literal) = do +interpretAtomic (ProgramState vars _ _) (LiteralAtomic literal) = do ret <- interpretLiteral literal return (ScopeResult vars (Just ret)) -interpretAtomic (ProgramState vars _) (VariableAtomic name) = do +interpretAtomic (ProgramState vars _ _) (VariableAtomic name) = do let varValue = Map.lookup name vars return $ case varValue of Just value -> ScopeResult vars (Just value) Nothing -> error $ "Variable not found: " ++ name -interpretAtomic (ProgramState vars funs) (FunctionCallAtomic name args) = do +interpretAtomic (ProgramState vars funs t) (FunctionCallAtomic name args) = do let isBuiltIn = Map.lookup name getAllBuiltIns case isBuiltIn of Just (BuiltIn _ _ fn) -> do @@ -94,23 +84,23 @@ interpretAtomic (ProgramState vars funs) (FunctionCallAtomic name args) = do case fun of Just (FunctionDefinitionStatement (Function _ argDef _ body)) -> do params <- mapExpressionToParam argDef args - let fnScope = ProgramState (Map.union params vars) funs + let fnScope = ProgramState (Map.union params vars) funs t (ScopeResult innerVars ret) <- returnSkipWrapper (InterpretState fnScope Nothing) body True - let (ProgramState newVars _) = updateOuterStateV (ProgramState vars funs) innerVars + let (ProgramState newVars _ _) = updateOuterStateV (ProgramState vars funs t) innerVars return (ScopeResult newVars ret) _ -> error $ "Function not found: " ++ name where getArgValues :: [Expression] -> IO [Value] getArgValues exprs = mapM - (interpretExpression (ProgramState vars funs)) + (interpretExpression (ProgramState vars funs t)) exprs Data.Functor.<&> Prelude.map (\(ScopeResult _ (Just v)) -> v) mapExpressionToParam :: [VariableDeclaration] -> [Expression] -> IO (Map Name Value) mapExpressionToParam [] [] = pure Map.empty mapExpressionToParam (VariableDeclaration n _ : rest) (expression : restExp) = do - (ScopeResult _ (Just val)) <- interpretExpression (ProgramState vars funs) expression + (ScopeResult _ (Just val)) <- interpretExpression (ProgramState vars funs t) expression restMap <- mapExpressionToParam rest restExp return (Map.insert n val restMap) mapExpressionToParam _ _ = error "Invalid number of arguments" @@ -127,32 +117,32 @@ returnSkipWrapper state [] inFunction = else return (ScopeResult (variables (programState state)) Nothing) interpretControl :: ProgramState -> Control -> IO InterpretState -interpretControl (ProgramState vars funs) (IfControl test body elseBody) = do - (BoolValue testValue) <- isTestValue (ProgramState vars funs) test +interpretControl (ProgramState vars funs t) (IfControl test body elseBody) = do + (BoolValue testValue) <- isTestValue (ProgramState vars funs t) test if testValue then do - (ScopeResult innerVars ret) <- returnSkipWrapper (InterpretState (ProgramState vars funs) Nothing) body False - return $ InterpretState (updateOuterStateV (ProgramState vars funs) innerVars) ret + (ScopeResult innerVars ret) <- returnSkipWrapper (InterpretState (ProgramState vars funs t) Nothing) body False + return $ InterpretState (updateOuterStateV (ProgramState vars funs t) innerVars) ret else do case elseBody of Just elseStatements -> do -- TODO: extract cancellable statements function - (ScopeResult innerVars ret) <- returnSkipWrapper (InterpretState (ProgramState vars funs) Nothing) elseStatements False - return $ InterpretState (updateOuterStateV (ProgramState vars funs) innerVars) ret - Nothing -> return $ InterpretState (ProgramState vars funs) Nothing -interpretControl (ProgramState vars funs) (WhileControl test body) = do - (BoolValue testValue) <- isTestValue (ProgramState vars funs) test + (ScopeResult innerVars ret) <- returnSkipWrapper (InterpretState (ProgramState vars funs t) Nothing) elseStatements False + return $ InterpretState (updateOuterStateV (ProgramState vars funs t) innerVars) ret + Nothing -> return $ InterpretState (ProgramState vars funs t) Nothing +interpretControl (ProgramState vars funs t) (WhileControl test body) = do + (BoolValue testValue) <- isTestValue (ProgramState vars funs t) test if testValue then do - (InterpretState innerVars ret) <- foldM interpretStatement (InterpretState (ProgramState vars funs) Nothing) body + (InterpretState innerVars ret) <- foldM interpretStatement (InterpretState (ProgramState vars funs t) Nothing) body case ret of - Just value -> return $ InterpretState (updateOuterState (ProgramState vars funs) innerVars) (Just value) - Nothing -> interpretControl (updateOuterState (ProgramState vars funs) innerVars) (WhileControl test body) - else return $ InterpretState (ProgramState vars funs) Nothing + Just value -> return $ InterpretState (updateOuterState (ProgramState vars funs t) innerVars) (Just value) + Nothing -> interpretControl (updateOuterState (ProgramState vars funs t) innerVars) (WhileControl test body) + else return $ InterpretState (ProgramState vars funs t) Nothing isTestValue :: ProgramState -> Expression -> IO Value -isTestValue (ProgramState vars funs) test = do - (ScopeResult _ (Just testValue)) <- interpretExpression (ProgramState vars funs) test +isTestValue s test = do + (ScopeResult _ (Just testValue)) <- interpretExpression s test if not (isBoolValue testValue) then do error "Control statement test must be a boolean value." else return testValue diff --git a/src/Interpreter/Literal.hs b/src/Interpreter/Literal.hs index 8428e81..4b028a3 100644 --- a/src/Interpreter/Literal.hs +++ b/src/Interpreter/Literal.hs @@ -14,3 +14,5 @@ interpretLiteral UnitLiteral = do return UnitValue interpretLiteral (StringLiteral value) = do return $ StringValue value +interpretLiteral UndefinedLiteral = do + return UndefinedValue diff --git a/src/Interpreter/Manipulator.hs b/src/Interpreter/Manipulator.hs index 536dadb..8882e1d 100644 --- a/src/Interpreter/Manipulator.hs +++ b/src/Interpreter/Manipulator.hs @@ -21,6 +21,16 @@ getFunctionMap inStatments = isFunctionDefinition _ = False getFunctionName (FunctionDefinitionStatement (Function name _ _ _)) = name +getCustomTypeMap :: [Statement] -> Map Name Struct +getCustomTypeMap inStatments = + let rawMap = Map.fromList $ Prelude.map (\item -> (getStructName item, getStruct item)) (Prelude.filter isStructDefinition inStatments) + in rawMap + where + isStructDefinition (StructStatement _) = True + isStructDefinition _ = False + getStructName (StructStatement (Struct name _)) = name + getStruct (StructStatement s) = s + ensureVoidFunctionReturn :: Map Name Statement -> Map Name Statement ensureVoidFunctionReturn = Map.mapWithKey ensureVoidReturn where diff --git a/src/Interpreter/ProgramState.hs b/src/Interpreter/ProgramState.hs index 66a6ed2..cd8b2a1 100644 --- a/src/Interpreter/ProgramState.hs +++ b/src/Interpreter/ProgramState.hs @@ -5,11 +5,16 @@ module Interpreter.ProgramState (module Interpreter.ProgramState) where import AST import Data.Map.Strict as Map -data Value = IntValue Int | FloatValue Float | BoolValue Bool | UnitValue | StringValue String | InterpreterErrorValue String +data Value = IntValue Int | FloatValue Float | BoolValue Bool | UnitValue | StringValue String | InterpreterErrorValue String | UndefinedValue deriving (Show, Eq) data ProgramState where - ProgramState :: {variables :: Map Name Value, functions :: Map Name Statement} -> ProgramState + ProgramState :: + { variables :: Map Name Value, + functions :: Map Name Statement, + types :: Map Name Struct + } -> + ProgramState deriving (Show, Eq) data InterpretState where @@ -18,3 +23,37 @@ data InterpretState where data ScopeResult = ScopeResult (Map Name Value) (Maybe Value) deriving (Show, Eq) + +updateState :: ProgramState -> Name -> Maybe Value -> ProgramState +updateState (ProgramState vars funs t) name value = do + case value of + Just v -> ProgramState (Map.insert name v vars) funs t + Nothing -> ProgramState vars funs t + +-- Update variable in outer scope +updateOuterState :: ProgramState -> ProgramState -> ProgramState +updateOuterState (ProgramState outerVars funs t) (ProgramState innerVars _ _) = + ProgramState (Map.unionWithKey (\_ inner _outer -> inner) innerVars outerVars) funs t + +updateOuterStateV :: ProgramState -> Map Name Value -> ProgramState +updateOuterStateV (ProgramState outerVars funs t) innerVars = + ProgramState + (Map.unionWithKey (\_ inner _outer -> inner) innerVars outerVars) + funs + t + +addStructMembersToState :: ProgramState -> Name -> Type -> ProgramState +addStructMembersToState s varName typeName = do + case typeName of + CustomType structName -> do + let struct = getStruct s structName + case struct of + Just struc -> addStructMembersToState' s varName struc + Nothing -> s + _ -> s + where + getStruct :: ProgramState -> Name -> Maybe Struct + getStruct (ProgramState _ _ t) name = Map.lookup name t + addStructMembersToState' :: ProgramState -> Name -> Struct -> ProgramState + addStructMembersToState' (ProgramState vars funs t) baseName (Struct name members) = + ProgramState vars funs (Map.insert (baseName ++ "." ++ name) (Struct name members) t) diff --git a/src/Parser/Assignment.hs b/src/Parser/Assignment.hs index cd21014..186b19f 100644 --- a/src/Parser/Assignment.hs +++ b/src/Parser/Assignment.hs @@ -10,7 +10,7 @@ import Text.Parsec.String parseAssignment :: Parser Assignment parseAssignment = do - var <- parseName + var <- parseExistingVariableName _ <- spaces' _ <- char '=' _ <- spaces' diff --git a/src/Parser/Expression.hs b/src/Parser/Expression.hs index 7057158..d0c4b0b 100644 --- a/src/Parser/Expression.hs +++ b/src/Parser/Expression.hs @@ -55,4 +55,4 @@ parseAtomic :: Parser Atomic parseAtomic = LiteralAtomic <$> try parseLiteral <|> try parseFunctionCallAtomic - <|> VariableAtomic <$> try parseName + <|> VariableAtomic <$> try parseExistingVariableName diff --git a/src/Parser/Literal.hs b/src/Parser/Literal.hs index 9017e52..753fbc0 100644 --- a/src/Parser/Literal.hs +++ b/src/Parser/Literal.hs @@ -11,6 +11,7 @@ parseLiteral = <|> try parseFloatLiteral <|> try parseIntLiteral <|> try praseStringLiteral + <|> try (char '?' >> return UndefinedLiteral) parseIntLiteral :: Parser Literal parseIntLiteral = do diff --git a/src/Parser/Name.hs b/src/Parser/Name.hs index ea374a9..02a6fe4 100644 --- a/src/Parser/Name.hs +++ b/src/Parser/Name.hs @@ -1,6 +1,7 @@ module Parser.Name (module Parser.Name) where import AST +import Data.List (intercalate) import Text.Parsec import Text.Parsec.String @@ -11,3 +12,14 @@ parseName = do return (fistChar : rest) where startChar = letter <|> char '_' + +parseMemberName :: Parser Name +parseMemberName = do + firstHalf <- parseName + _ <- char '.' + rest <- parseName `sepBy1` string "." + return (firstHalf ++ "." ++ intercalate "." rest) + +parseExistingVariableName :: Parser Name +parseExistingVariableName = + try parseMemberName <|> try parseName diff --git a/src/Parser/Statement.hs b/src/Parser/Statement.hs index 87c4465..cfdd0ce 100644 --- a/src/Parser/Statement.hs +++ b/src/Parser/Statement.hs @@ -7,6 +7,7 @@ import Parser.EndOfLine import Parser.Expression import Parser.Name import Parser.Space +import Parser.Struct import Parser.Type import Parser.Variable import Text.Parsec @@ -15,9 +16,11 @@ import Text.Parsec.String parseStatement :: Parser Statement parseStatement = (ControlStatement <$> try (spaces' *> try parseControl)) + <|> (StructStatement <$> try (spaces' *> try parseStruct)) <|> try parseReturnStatement <|> (FunctionDefinitionStatement <$> try (spaces' *> try parseFunction)) - <|> (VariableStatement <$> try (spaces' *> try parseVariable) <* endOfStatement) + <|> (VariableDefinitionStatement <$> try (spaces' *> try parseVariable) <* endOfStatement) + <|> (VariableDeclarationStatement <$> try (spaces' *> try parseVariableDeclaration) <* endOfStatement) <|> (AssignmentStatement <$> try (spaces' *> try parseAssignment) <* endOfStatement) <|> (ExpressionStatement <$> try (spaces' *> try parseExpression) <* endOfStatement) diff --git a/src/Parser/Struct.hs b/src/Parser/Struct.hs new file mode 100644 index 0000000..e838dd2 --- /dev/null +++ b/src/Parser/Struct.hs @@ -0,0 +1,21 @@ +module Parser.Struct (module Parser.Struct) where + +import AST +import Parser.Name +import Parser.Space +import Parser.Variable +import Text.Parsec +import Text.Parsec.String + +parseStruct :: Parser Struct +parseStruct = do + _ <- string "struct" + _ <- spaces1' + name <- parseName + _ <- spaces' + _ <- char '{' + _ <- spaces' + fields <- parseVariableDeclaration `sepEndBy` (spaces' *> char ';' <* spaces') + _ <- spaces' + _ <- char '}' + return $ Struct name fields diff --git a/src/Parser/Type.hs b/src/Parser/Type.hs index 1562b47..f0e3629 100644 --- a/src/Parser/Type.hs +++ b/src/Parser/Type.hs @@ -7,9 +7,9 @@ import Text.Parsec.String parseType :: Parser Type parseType = - (string "void" >> return UnitType) - <|> (string "int" >> return IntType) - <|> (string "float" >> return FloatType) - <|> (string "bool" >> return BoolType) - <|> (string "str" >> return StringType) + try (string "void" >> return UnitType) + <|> try (string "int" >> return IntType) + <|> try (string "float" >> return FloatType) + <|> try (string "bool" >> return BoolType) + <|> try (string "str" >> return StringType) <|> (CustomType <$> parseName) diff --git a/test/E2E/Interpreter/examples/check_examples.sh b/test/E2E/Interpreter/examples/check_examples.sh index 034c212..0bfaeac 100755 --- a/test/E2E/Interpreter/examples/check_examples.sh +++ b/test/E2E/Interpreter/examples/check_examples.sh @@ -9,6 +9,7 @@ file_paths=( "main_hello_world.mmm:Hello, World!" "short_hello_world.mmm:Hello, World!" "print.mmm:Hello, World!1" + "structs.mmm:23" ) # Function to run program over file and check stdout diff --git a/test/E2E/Interpreter/examples/structs.mmm b/test/E2E/Interpreter/examples/structs.mmm new file mode 100644 index 0000000..816d96b --- /dev/null +++ b/test/E2E/Interpreter/examples/structs.mmm @@ -0,0 +1,15 @@ +struct T { + int x; + int y; +} + +T n; +n.x = 2; +print(str(n.x)) + +int i = 3; + +T n1; +n1.y = i; + +print(str(n1.y)) diff --git a/test/Unit/Parser/Assignment.hs b/test/Unit/Parser/Assignment.hs index d540463..b2862d5 100644 --- a/test/Unit/Parser/Assignment.hs +++ b/test/Unit/Parser/Assignment.hs @@ -32,3 +32,11 @@ testSimple = TestCase $ do "Variable + Number Assignment" (Assignment "k" (OperationExpression (AtomicExpression (VariableAtomic "k")) Plus (AtomicExpression (LiteralAtomic (IntLiteral 1))))) (fromRight emptyTestAssignment (parse parseAssignment "" "k = k + 1")) + assertEqual + "x = 2;" + (Assignment "x" (AtomicExpression (LiteralAtomic (IntLiteral 2)))) + (fromRight emptyTestAssignment (parse parseAssignment "" "x = 2")) + assertEqual + "x.y = 2;" + (Assignment "x.y" (AtomicExpression (LiteralAtomic (IntLiteral 2)))) + (fromRight emptyTestAssignment (parse parseAssignment "" "x.y = 2")) diff --git a/test/Unit/Parser/Program.hs b/test/Unit/Parser/Program.hs index d34f6b1..7e6bff9 100644 --- a/test/Unit/Parser/Program.hs +++ b/test/Unit/Parser/Program.hs @@ -23,7 +23,7 @@ testSimple = TestCase $ do assertEqual "int k = 1;" ( Program - [ VariableStatement + [ VariableDefinitionStatement ( Variable (VariableDeclaration "k" IntType) ( AtomicExpression (LiteralAtomic (IntLiteral 1)) @@ -35,13 +35,13 @@ testSimple = TestCase $ do assertEqual "int k = 1; int j = 2;" ( Program - [ VariableStatement + [ VariableDefinitionStatement ( Variable (VariableDeclaration "k" IntType) ( AtomicExpression (LiteralAtomic (IntLiteral 1)) ) ), - VariableStatement + VariableDefinitionStatement ( Variable (VariableDeclaration "j" IntType) ( AtomicExpression (LiteralAtomic (IntLiteral 2)) @@ -65,7 +65,7 @@ testSimple = TestCase $ do assertEqual "int k = 1; j = 2;" ( Program - [ VariableStatement + [ VariableDefinitionStatement ( Variable (VariableDeclaration "k" IntType) ( AtomicExpression (LiteralAtomic (IntLiteral 1)) diff --git a/test/Unit/Parser/Statement.hs b/test/Unit/Parser/Statement.hs index 2d5bd84..f9eed1c 100644 --- a/test/Unit/Parser/Statement.hs +++ b/test/Unit/Parser/Statement.hs @@ -3,6 +3,7 @@ module Unit.Parser.Statement (allTests) where import AST import Data.Either (fromRight, isRight) import Parser.Statement +import Parser.Struct import Test.HUnit import Text.Parsec (parse) @@ -12,12 +13,13 @@ allTests = TestLabel "functions" testFunctions, TestLabel "return" testReturn, TestLabel "if" testIf, - TestLabel "while" testWhile + TestLabel "while" testWhile, + TestLabel "struct" testStruct ] emptyTestStatement :: Statement emptyTestStatement = - VariableStatement + VariableDefinitionStatement ( Variable (VariableDeclaration "test" IntType) (AtomicExpression (LiteralAtomic (IntLiteral 0))) @@ -31,7 +33,7 @@ testSimple = TestCase $ do (isRight (parse parseStatement "" "")) assertEqual "int i = 1;" - (VariableStatement (Variable (VariableDeclaration "i" IntType) (AtomicExpression (LiteralAtomic (IntLiteral 1))))) + (VariableDefinitionStatement (Variable (VariableDeclaration "i" IntType) (AtomicExpression (LiteralAtomic (IntLiteral 1))))) (fromRight emptyTestStatement (parse parseStatement "" "int i = 1;")) assertEqual "k = 2;" @@ -74,7 +76,7 @@ testFunctions = TestCase $ do "main" [] UnitType - [ VariableStatement (Variable (VariableDeclaration "i" IntType) (AtomicExpression (LiteralAtomic (IntLiteral 1)))), + [ VariableDefinitionStatement (Variable (VariableDeclaration "i" IntType) (AtomicExpression (LiteralAtomic (IntLiteral 1)))), AssignmentStatement (Assignment "i" (AtomicExpression (LiteralAtomic (IntLiteral 2)))) ] ) @@ -148,3 +150,33 @@ testWhile = TestCase $ do ) ) (either (const emptyTestStatement) ControlStatement (parse parseControl "" "while true { return 0; }")) + +testStruct :: Test +testStruct = TestCase $ do + assertEqual + "struct Name {}" + (Struct "Name" []) + (fromRight (Struct "DEFAULT" []) (parse parseStruct "" "struct Name {}")) + assertEqual + "struct Name {}" + (either (const emptyTestStatement) StructStatement (parse parseStruct "" "struct Name {}")) + (fromRight emptyTestStatement (parse parseStatement "" "struct Name {}")) + assertEqual + "struct Name {\ + \ int x;\ + \ str n;\ + \ Name next;\ + \}" + (Struct "Name" [VariableDeclaration "x" IntType, VariableDeclaration "n" StringType, VariableDeclaration "next" (CustomType "Name")]) + ( fromRight + (Struct "DEFAULT" []) + ( parse + parseStruct + "" + "struct Name {\ + \ int x;\ + \ str n;\ + \ Name next;\ + \}" + ) + )