From 612de54a6a58bb57e728b2b6e8ef997ffe6a703f Mon Sep 17 00:00:00 2001 From: Niko Strijbol Date: Thu, 13 Jun 2024 15:24:44 +0200 Subject: [PATCH] Pass source code location to oracles --- tested/dsl/schema-strict.json | 14 +++ tested/dsl/schema.json | 14 +++ tested/dsl/translate_parser.py | 12 ++- tested/oracles/common.py | 14 ++- tested/oracles/programmed.py | 24 ++++- tested/testsuite.py | 95 +++++++++++++------ tests/exercises/lotto/evaluation/evaluator.py | 31 ++++++ .../one-programmed-analysis-correct.yaml | 10 ++ ...-programmed-analysis-without-language.yaml | 9 ++ .../one-programmed-analysis-wrong.yaml | 10 ++ tests/test_oracles_builtin.py | 16 +++- tests/test_oracles_programmed.py | 53 +++++++++++ 12 files changed, 258 insertions(+), 44 deletions(-) create mode 100644 tests/exercises/lotto/evaluation/one-programmed-analysis-correct.yaml create mode 100644 tests/exercises/lotto/evaluation/one-programmed-analysis-without-language.yaml create mode 100644 tests/exercises/lotto/evaluation/one-programmed-analysis-wrong.yaml diff --git a/tested/dsl/schema-strict.json b/tested/dsl/schema-strict.json index 6808880a..cc9d7baa 100644 --- a/tested/dsl/schema-strict.json +++ b/tested/dsl/schema-strict.json @@ -530,6 +530,13 @@ "items" : { "$ref" : "#/definitions/yamlValueOrPythonExpression" } + }, + "languages": { + "type" : "array", + "description" : "Which programming languages are supported by this oracle.", + "items" : { + "$ref" : "#/definitions/programmingLanguage" + } } } } @@ -585,6 +592,13 @@ "items" : { "$ref" : "#/definitions/yamlValueOrPythonExpression" } + }, + "languages": { + "type" : "array", + "description" : "Which programming languages are supported by this oracle.", + "items" : { + "$ref" : "#/definitions/programmingLanguage" + } } } } diff --git a/tested/dsl/schema.json b/tested/dsl/schema.json index 1a4bc381..e83bfaa4 100644 --- a/tested/dsl/schema.json +++ b/tested/dsl/schema.json @@ -530,6 +530,13 @@ "items" : { "$ref" : "#/definitions/yamlValueOrPythonExpression" } + }, + "languages": { + "type" : "array", + "description" : "Which programming languages are supported by this oracle.", + "items" : { + "$ref" : "#/definitions/programmingLanguage" + } } } } @@ -585,6 +592,13 @@ "items" : { "$ref" : "#/definitions/yamlValueOrPythonExpression" } + }, + "languages": { + "type" : "array", + "description" : "Which programming languages are supported by this oracle.", + "items" : { + "$ref" : "#/definitions/programmingLanguage" + } } } } diff --git a/tested/dsl/translate_parser.py b/tested/dsl/translate_parser.py index 51e69b69..e504a94e 100644 --- a/tested/dsl/translate_parser.py +++ b/tested/dsl/translate_parser.py @@ -91,6 +91,12 @@ class ReturnOracle(dict): ) +def _convert_language_dictionary( + original: dict[str, str] +) -> dict[SupportedLanguage, str]: + return {SupportedLanguage(k): v for k, v in original.items()} + + def _ensure_trailing_newline(text: str) -> str: if text and not text.endswith("\n"): return text + "\n" @@ -374,11 +380,13 @@ def _convert_custom_check_oracle(stream: dict) -> CustomCheckOracle: cv = _convert_yaml_value(v) assert isinstance(cv, Value) converted_args.append(cv) + languages = stream.get("languages") return CustomCheckOracle( function=EvaluationFunction( file=stream["file"], name=stream.get("name", "evaluate") ), arguments=converted_args, + languages=set(languages) if languages else None, ) @@ -503,7 +511,9 @@ def _convert_testcase(testcase: YamlDict, context: DslContext) -> Testcase: message = exception.get("message") assert isinstance(message, str) assert isinstance(exception["types"], dict) - types = cast(dict[str, str], exception["types"]) + types = _convert_language_dictionary( + cast(dict[str, str], exception["types"]) + ) output.exception = ExceptionOutputChannel( exception=ExpectedException(message=message, types=types) ) diff --git a/tested/oracles/common.py b/tested/oracles/common.py index 2f54e44d..c8f67afd 100644 --- a/tested/oracles/common.py +++ b/tested/oracles/common.py @@ -37,16 +37,22 @@ def evaluate_text(configs, channel, actual): from tested.languages.utils import convert_stacktrace_to_clickable_feedback from tested.parsing import fallback_field, get_converter from tested.serialisation import Value -from tested.testsuite import ExceptionOutputChannel, NormalOutputChannel, OutputChannel +from tested.testsuite import ( + ExceptionOutputChannel, + NormalOutputChannel, + OutputChannel, + SupportedLanguage, +) @define class OracleContext: expected: Value actual: Value - execution_directory: str - evaluation_directory: str - programming_language: str + execution_directory: Path + evaluation_directory: Path + submission_path: Path | None + programming_language: SupportedLanguage natural_language: str diff --git a/tested/oracles/programmed.py b/tested/oracles/programmed.py index 938a7c2d..5b76a660 100644 --- a/tested/oracles/programmed.py +++ b/tested/oracles/programmed.py @@ -32,12 +32,18 @@ @define class ConvertedOracleContext: + """ + This is the oracle context that is passed to the actual function. + It should thus remain backwards compatible. + """ + expected: Any actual: Any execution_directory: str evaluation_directory: str programming_language: str natural_language: str + submission_path: str | None @staticmethod def from_context( @@ -46,10 +52,15 @@ def from_context( return ConvertedOracleContext( expected=eval(generate_statement(bundle, context.expected)), actual=eval(generate_statement(bundle, context.actual)), - execution_directory=context.execution_directory, - evaluation_directory=context.evaluation_directory, + execution_directory=str(context.execution_directory.absolute()), + evaluation_directory=str(context.evaluation_directory.absolute()), programming_language=context.programming_language, natural_language=context.natural_language, + submission_path=( + str(context.submission_path.absolute()) + if context.submission_path + else None + ), ) @@ -237,10 +248,13 @@ def evaluate( context = OracleContext( expected=expected, actual=actual, - execution_directory=str(config.context_dir.absolute()), - evaluation_directory=str(config.bundle.config.resources.absolute()), - programming_language=str(config.bundle.config.programming_language), + execution_directory=config.context_dir, + evaluation_directory=config.bundle.config.resources, + programming_language=config.bundle.config.programming_language, natural_language=config.bundle.config.natural_language, + submission_path=( + config.bundle.config.source if channel.oracle.languages else None + ), ) result = _evaluate_programmed(config.bundle, channel.oracle, context) diff --git a/tested/testsuite.py b/tested/testsuite.py index 27564969..7ab44bf9 100644 --- a/tested/testsuite.py +++ b/tested/testsuite.py @@ -11,7 +11,7 @@ from enum import StrEnum, auto, unique from os import path from pathlib import Path -from typing import Any, Literal, Union +from typing import Any, Literal, Protocol, TypeGuard, Union from attrs import define, field @@ -155,17 +155,28 @@ class CustomCheckOracle: independent. The oracle is run through the judge infrastructure to translate values between different programming languages. - Although most programming languages are supported, we recommend using Python, - as TESTed can then apply specific optimisations, meaning it will be faster than - other languages. + Some examples of what is possible with this oracle are sequence alignment + checking, or evaluating non-deterministic return values. Another example is + rendering the value into another representation, such as providing a message + with the rendered SVG image. - Some examples of intended use of this oracle are sequence alignment checking, - or evaluating non-deterministic return values. + While nominally language agnostic, the oracle supports a "languages" property, + which allows limiting for which programming languages the oracle is supported. + If provided, the oracle will be provided with the location of the source code, + enabling language-specific static analysis without having to use language- + specific oracles. """ function: EvaluationFunction arguments: list[Value] = field(factory=list) type: Literal["programmed", "custom_check"] = "custom_check" + languages: set[SupportedLanguage] | None = field(default=None) + + @languages.validator # type: ignore + def validate_languages(self, _, value): + """There should be at least one evaluator.""" + if value and not len(value): + raise ValueError("At least one language is required.") @fallback_field(get_converter(), {"evaluators": "functions"}) @@ -310,7 +321,7 @@ class ExpectedException(WithFeatures): # - Specify nothing, in which case the type is not checked. # - Specify a dictionary mapping programming names to exception names. # These exception names should already be in the right convention. - types: dict[str, str] | None = None + types: dict[SupportedLanguage, str] | None = None def __attrs_post_init__(self): if self.message is None and self.types is None: @@ -319,12 +330,12 @@ def __attrs_post_init__(self): def get_used_features(self) -> FeatureSet: return FeatureSet({Construct.EXCEPTIONS}, types=set(), nested_types=set()) - def get_type(self, language: str) -> str | None: + def get_type(self, language: SupportedLanguage) -> str | None: if not self.types: return None return self.types.get(language) - def readable(self, language: str) -> str: + def readable(self, language: SupportedLanguage) -> str: type_ = self.get_type(language) if self.message and type_: return f"{type_}: {self.message}" @@ -404,6 +415,29 @@ def get_used_features(self) -> FeatureSet: ExitOutput = ExitCodeOutputChannel | IgnoredChannel | EmptyChannel +def _get_text_channel_languages( + output: TextOutput | ValueOutput | ExceptionOutput | FileOutput, +) -> set[SupportedLanguage] | None: + + class _OracleChannel(Protocol): + oracle: Any + + def _is_oracle_channel(value: Any) -> TypeGuard[_OracleChannel]: + return hasattr(value, "oracle") + + if not _is_oracle_channel(output): + return None + + # For language-specific oracles, it is easy. + if isinstance(output.oracle, LanguageSpecificOracle): + return set(output.oracle.functions.keys()) + + if isinstance(output.oracle, CustomCheckOracle): + return output.oracle.languages + + return None + + @define class Output(WithFeatures): """The output channels for a testcase.""" @@ -432,29 +466,28 @@ def get_specific_languages(self) -> set[SupportedLanguage] | None: :return: None if no language-specific stuff is used, a set of supported languages otherwise. """ - languages = None + # Check generic oracles. + individual_languages = [ + _get_text_channel_languages(self.stdout), + _get_text_channel_languages(self.stderr), + _get_text_channel_languages(self.file), + _get_text_channel_languages(self.exception), + _get_text_channel_languages(self.result), + ] + + # Handle special cases if isinstance(self.exception, ExceptionOutputChannel): - if isinstance(self.exception.oracle, LanguageSpecificOracle): - languages = { - SupportedLanguage(x) for x in self.exception.oracle.functions.keys() - } - elif ( - self.exception.exception is not None and self.exception.exception.types - ): - languages = { - SupportedLanguage(x) for x in self.exception.exception.types.keys() - } - if isinstance(self.result, ValueOutputChannel): - if isinstance(self.result.oracle, LanguageSpecificOracle): - langs = { - SupportedLanguage(x) for x in self.result.oracle.functions.keys() - } - if languages is not None: - languages &= langs - else: - languages = langs - - return languages + if self.exception.exception is not None and self.exception.exception.types: + individual_languages.append(set(self.exception.exception.types.keys())) + + # Remove all None elements and merge the rest. + without_none = [x for x in individual_languages if x is not None] + + # If we do not have anything, bail now. + if len(without_none) == 0: + return None + + return set().union(*without_none) @define diff --git a/tests/exercises/lotto/evaluation/evaluator.py b/tests/exercises/lotto/evaluation/evaluator.py index 335c7ac4..440a3cb5 100644 --- a/tests/exercises/lotto/evaluation/evaluator.py +++ b/tests/exercises/lotto/evaluation/evaluator.py @@ -1,4 +1,5 @@ import re +import ast # noinspection PyUnresolvedReferences from evaluation_utils import EvaluationResult, Message @@ -48,3 +49,33 @@ def evaluate(context, count, maximum): if valid: expected = actual return EvaluationResult(valid, expected, actual, messages) + + +def check_for_node(search, context, count, maximum): + assert context.programming_language == "python", "This exercise only supports Python" + # Check if the submission uses a while loop. + with open(context.submission_path, "r") as submission_file: + submission = submission_file.read() + + # This has no error handling, so it is not ready for production. + nodes = ast.walk(ast.parse(submission)) + has_while = any(isinstance(node, search) for node in nodes) + messages = [] + if not has_while: + messages.append("Your code does not use a while loop, which is mandatory.") + eval_result = evaluate(context, count, maximum) + + return EvaluationResult( + eval_result.result and has_while, + eval_result.readable_expected, + eval_result.readable_actual, + eval_result.messages + messages + ) + + +def check_for_while(context, count, maximum): + return check_for_node(ast.While, context, count, maximum) + + +def check_for_for(context, count, maximum): + return check_for_node(ast.For, context, count, maximum) diff --git a/tests/exercises/lotto/evaluation/one-programmed-analysis-correct.yaml b/tests/exercises/lotto/evaluation/one-programmed-analysis-correct.yaml new file mode 100644 index 00000000..384e73ab --- /dev/null +++ b/tests/exercises/lotto/evaluation/one-programmed-analysis-correct.yaml @@ -0,0 +1,10 @@ +- tab: "Feedback" + testcases: + - expression: "loterij(18, 172)" + return: !oracle + oracle: "custom_check" + file: "evaluator.py" + name: "check_for_while" + value: "7 - 37 - 48 - 54 - 70 - 78 - 81 - 90 - 102 - 103 - 113 - 119 - 120 - 137 - 140 - 154 - 157 - 159" + arguments: [18, 172] + languages: ["python"] diff --git a/tests/exercises/lotto/evaluation/one-programmed-analysis-without-language.yaml b/tests/exercises/lotto/evaluation/one-programmed-analysis-without-language.yaml new file mode 100644 index 00000000..af4afb1f --- /dev/null +++ b/tests/exercises/lotto/evaluation/one-programmed-analysis-without-language.yaml @@ -0,0 +1,9 @@ +- tab: "Feedback" + testcases: + - expression: "loterij(18, 172)" + return: !oracle + oracle: "custom_check" + file: "evaluator.py" + name: "check_for_while" + value: "7 - 37 - 48 - 54 - 70 - 78 - 81 - 90 - 102 - 103 - 113 - 119 - 120 - 137 - 140 - 154 - 157 - 159" + arguments: [18, 172] diff --git a/tests/exercises/lotto/evaluation/one-programmed-analysis-wrong.yaml b/tests/exercises/lotto/evaluation/one-programmed-analysis-wrong.yaml new file mode 100644 index 00000000..738ef5bb --- /dev/null +++ b/tests/exercises/lotto/evaluation/one-programmed-analysis-wrong.yaml @@ -0,0 +1,10 @@ +- tab: "Feedback" + testcases: + - expression: "loterij(18, 172)" + return: !oracle + oracle: "custom_check" + file: "evaluator.py" + name: "check_for_for" + value: "7 - 37 - 48 - 54 - 70 - 78 - 81 - 90 - 102 - 103 - 113 - 119 - 120 - 137 - 140 - 154 - 157 - 159" + arguments: [18, 172] + languages: ["python"] diff --git a/tests/test_oracles_builtin.py b/tests/test_oracles_builtin.py index e071a7a3..1db3688b 100644 --- a/tests/test_oracles_builtin.py +++ b/tests/test_oracles_builtin.py @@ -26,6 +26,7 @@ ExpectedException, FileOutputChannel, Suite, + SupportedLanguage, TextOutputChannel, ValueOutputChannel, ) @@ -317,7 +318,10 @@ def test_exception_oracle_correct_message_wrong_type( channel = ExceptionOutputChannel( exception=ExpectedException( message="Test error", - types={"python": "PiefError", "javascript": "PafError"}, + types={ + SupportedLanguage.PYTHON: "PiefError", + SupportedLanguage.JAVASCRIPT: "PafError", + }, ) ) actual_value = get_converter().dumps( @@ -345,7 +349,10 @@ def test_exception_oracle_wrong_message_correct_type( channel = ExceptionOutputChannel( exception=ExpectedException( message="Test error", - types={"python": "PiefError", "javascript": "PafError"}, + types={ + SupportedLanguage.PYTHON: "PiefError", + SupportedLanguage.JAVASCRIPT: "PafError", + }, ) ) @@ -376,7 +383,10 @@ def test_exception_oracle_correct_type_and_message( channel = ExceptionOutputChannel( exception=ExpectedException( message="Test error", - types={"python": "PiefError", "javascript": "PafError"}, + types={ + SupportedLanguage.PYTHON: "PiefError", + SupportedLanguage.JAVASCRIPT: "PafError", + }, ) ) diff --git a/tests/test_oracles_programmed.py b/tests/test_oracles_programmed.py index df70c8ee..166a721c 100644 --- a/tests/test_oracles_programmed.py +++ b/tests/test_oracles_programmed.py @@ -130,3 +130,56 @@ def test_custom_check_function_lotto_wrong( assert len(updates.find_all("start-testcase")) == 1 assert updates.find_status_enum() == ["wrong"] assert len(updates.find_all("append-message")) == 1 + + +def test_custom_check_function_static_analysis_correct( + tmp_path: Path, pytestconfig: pytest.Config +): + conf = configuration( + pytestconfig, + "lotto", + "python", + tmp_path, + "one-programmed-analysis-correct.yaml", + "correct", + ) + result = execute_config(conf) + updates = assert_valid_output(result, pytestconfig) + assert len(updates.find_all("start-testcase")) == 1 + assert updates.find_status_enum() == ["correct"] + assert len(updates.find_all("append-message")) == 0 + + +def test_custom_check_function_static_analysis_wrong( + tmp_path: Path, pytestconfig: pytest.Config +): + conf = configuration( + pytestconfig, + "lotto", + "python", + tmp_path, + "one-programmed-analysis-wrong.yaml", + "correct", + ) + result = execute_config(conf) + updates = assert_valid_output(result, pytestconfig) + assert len(updates.find_all("start-testcase")) == 1 + assert updates.find_status_enum() == ["wrong"] + assert len(updates.find_all("append-message")) == 1 + + +def test_custom_check_function_static_analysis_missing_language( + tmp_path: Path, pytestconfig: pytest.Config +): + conf = configuration( + pytestconfig, + "lotto", + "python", + tmp_path, + "one-programmed-analysis-without-language.yaml", + "correct", + ) + result = execute_config(conf) + updates = assert_valid_output(result, pytestconfig) + assert len(updates.find_all("start-testcase")) == 1 + assert updates.find_status_enum() == ["internal error"]