From 34cae092a3e1a9116bfe983652f36b176e381e00 Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Sat, 20 Jun 2020 01:58:04 +0900 Subject: [PATCH 1/3] Improve renaming variables for name conflicts - stop using random strings - start using proper implementation see https://github.com/online-judge-tools/template-generator/issues/8#issuecomment-629743109 --- onlinejudge_template/analyzer/combined.py | 34 ++++++++------ .../analyzer/simple_patterns.py | 46 +++++++++++++------ 2 files changed, 53 insertions(+), 27 deletions(-) diff --git a/onlinejudge_template/analyzer/combined.py b/onlinejudge_template/analyzer/combined.py index 8392527..8b0241b 100644 --- a/onlinejudge_template/analyzer/combined.py +++ b/onlinejudge_template/analyzer/combined.py @@ -51,6 +51,7 @@ def run(resources: AnalyzerResources) -> AnalyzerResult: if resources.html is not None: topcoder_class_definition = onlinejudge_template.analyzer.topcoder.parse_topcoder_class_definition(resources.html, url=resources.url) + # parse the format tree for input input_format: Optional[FormatNode] = None if resources.input_format_string is not None: try: @@ -65,20 +66,7 @@ def run(resources: AnalyzerResources) -> AnalyzerResult: input_samples = [case.input for case in resources.sample_cases] input_format = onlinejudge_template.analyzer.simple_patterns.guess_format_with_pattern_matching(instances=input_samples) - output_format: Optional[FormatNode] = None - if resources.output_format_string is not None: - try: - output_format = onlinejudge_template.analyzer.parser.run(resources.output_format_string) - except AnalyzerError as e: - logger.error('output analyzer failed: %s', e) - except NotImplementedError as e: - logger.error('output analyzer failed: %s', e) - elif topcoder_class_definition is not None: - output_format = onlinejudge_template.analyzer.topcoder.convert_topcoder_class_definition_to_output_format(topcoder_class_definition) - elif resources.sample_cases: - output_samples = [case.output for case in resources.sample_cases] - output_format = onlinejudge_template.analyzer.simple_patterns.guess_format_with_pattern_matching(instances=output_samples) - + # list the variables for input input_variables: Optional[Dict[str, VarDecl]] = None if resources.input_format_string is None and topcoder_class_definition is not None: input_variables = onlinejudge_template.analyzer.topcoder.convert_topcoder_class_definition_to_input_variables(topcoder_class_definition) @@ -97,6 +85,22 @@ def run(resources: AnalyzerResources) -> AnalyzerResult: except AnalyzerError as e: logger.error('input analyzer failed: %s', e) + # parse the format tree for output + output_format: Optional[FormatNode] = None + if resources.output_format_string is not None: + try: + output_format = onlinejudge_template.analyzer.parser.run(resources.output_format_string) + except AnalyzerError as e: + logger.error('output analyzer failed: %s', e) + except NotImplementedError as e: + logger.error('output analyzer failed: %s', e) + elif topcoder_class_definition is not None: + output_format = onlinejudge_template.analyzer.topcoder.convert_topcoder_class_definition_to_output_format(topcoder_class_definition) + elif resources.sample_cases: + output_samples = [case.output for case in resources.sample_cases] + output_format = onlinejudge_template.analyzer.simple_patterns.guess_format_with_pattern_matching(instances=output_samples, env=input_variables) + + # list the variables for output output_variables: Optional[Dict[str, VarDecl]] = None if resources.output_format_string is None and topcoder_class_definition is not None: output_variables = onlinejudge_template.analyzer.topcoder.convert_topcoder_class_definition_to_output_variables(topcoder_class_definition) @@ -115,10 +119,12 @@ def run(resources: AnalyzerResources) -> AnalyzerResult: except AnalyzerError as e: logger.error('output analyzer failed: %s', e) + # list constants constants: Dict[str, ConstantDecl] = {} if resources.html is not None or resources.sample_cases: constants.update(onlinejudge_template.analyzer.constants.list_constants(html=resources.html, sample_cases=resources.sample_cases)) + # simplify the output format output_type: Optional[OutputType] = None if output_format is not None and output_variables is not None: output_type = onlinejudge_template.analyzer.output_types.analyze_output_type(output_format=output_format, output_variables=output_variables, constants=constants) diff --git a/onlinejudge_template/analyzer/simple_patterns.py b/onlinejudge_template/analyzer/simple_patterns.py index b2b2d97..3a39c2f 100644 --- a/onlinejudge_template/analyzer/simple_patterns.py +++ b/onlinejudge_template/analyzer/simple_patterns.py @@ -26,9 +26,8 @@ """ import functools -import random +import itertools import re -import string from logging import getLogger from typing import * @@ -229,6 +228,9 @@ def _make_tree_pattern_dfs(node: FormatNode) -> Tuple[FormatNode, bool]: def _make_tree_patterns(patterns: List[FormatNode]) -> List[FormatNode]: + """_make_tree_patterns detects patterns which have the variable `n` and arrays with lentgh `n`, and replaces the length of arrays with `n - 1`. + """ + tree_patterns = [] for pattern in patterns: pattern, replaced = _make_tree_pattern_dfs(pattern) @@ -239,6 +241,9 @@ def _make_tree_patterns(patterns: List[FormatNode]) -> List[FormatNode]: @functools.lru_cache(maxsize=None) def list_all_patterns() -> List[Tuple[FormatNode, Dict[str, VarDecl]]]: + """list_all_patterns lists all pre-defined petterns. + """ + patterns: List[FormatNode] = [ *_simple_patterns, *_vertical_simple_patterns, @@ -258,17 +263,24 @@ def list_all_patterns() -> List[Tuple[FormatNode, Dict[str, VarDecl]]]: return results -def _randomize_variables_dfs(node: FormatNode, *, mapping: Dict[str, str]) -> FormatNode: +def _rename_variables_if_conflicts_dfs(node: FormatNode, *, mapping: Dict[str, str], env: Dict[str, VarDecl]) -> FormatNode: def rename(s: str) -> str: for a, b in mapping.items(): s = re.sub(r'\b' + re.escape(a) + r'\b', b, s) return s if isinstance(node, ItemNode): - assert node.name not in mapping - mapping[node.name] = random.choice(string.ascii_lowercase) + random.choice(string.ascii_lowercase) + random.choice(string.ascii_lowercase) - indices = list(map(rename, node.indices)) - return ItemNode(name=mapping[node.name], indices=indices) + assert node.name not in mapping # because there are only such patterns + if node.name not in env: + return node + else: + for i in itertools.count(1): + new_name = node.name + str(i) + if new_name not in env: + mapping[node.name] = new_name + break + indices = list(map(rename, node.indices)) + return ItemNode(name=mapping[node.name], indices=indices) elif isinstance(node, NewlineNode): return node @@ -276,22 +288,30 @@ def rename(s: str) -> str: elif isinstance(node, SequenceNode): items: List[FormatNode] = [] for item in node.items: - items.append(_randomize_variables_dfs(item, mapping=mapping)) + items.append(_rename_variables_if_conflicts_dfs(item, mapping=mapping, env=env)) return SequenceNode(items=items) elif isinstance(node, LoopNode): - body = _randomize_variables_dfs(node.body, mapping=mapping) + body = _rename_variables_if_conflicts_dfs(node.body, mapping=mapping, env=env) return LoopNode(name=node.name, size=rename(node.size), body=body) else: assert False -def randomize_variables(node: FormatNode) -> FormatNode: - return _randomize_variables_dfs(node, mapping={}) +def rename_variables_if_conflicts(node: FormatNode, *, env: Dict[str, VarDecl]) -> FormatNode: + return _rename_variables_if_conflicts_dfs(node, mapping={}, env=env) + + +def guess_format_with_pattern_matching(*, instances: List[bytes], env: Optional[Dict[str, VarDecl]] = None) -> Optional[FormatNode]: + """guess_format_with_pattern_matching guesses a format tree from the strings which match with the format tree, i.e. sample cases. + :param instances: are sample cases. + :param env: is the dict which contains variables already defined. + """ -def guess_format_with_pattern_matching(*, instances: List[bytes]) -> Optional[FormatNode]: + if env is None: + env = {} found: List[FormatNode] = [] for pattern, variables in list_all_patterns(): try: @@ -301,6 +321,6 @@ def guess_format_with_pattern_matching(*, instances: List[bytes]) -> Optional[Fo except FormatMatchError: pass if len(found) == 1: - return randomize_variables(found[0]) + return rename_variables_if_conflicts(found[0], env=env) else: return None From 8a5d3f8f2a4bf378e9939fdcd1ba2ada68348c26 Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Sat, 20 Jun 2020 03:20:00 +0900 Subject: [PATCH 2/3] Fix bugs of onlinejudge_template/analyzer/output_types.py --- onlinejudge_template/analyzer/output_types.py | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/onlinejudge_template/analyzer/output_types.py b/onlinejudge_template/analyzer/output_types.py index 534d132..c1022d7 100644 --- a/onlinejudge_template/analyzer/output_types.py +++ b/onlinejudge_template/analyzer/output_types.py @@ -38,13 +38,12 @@ def analyze_output_type(*, output_format: FormatNode, output_variables: Dict[str item1 = node.items[1] if isinstance(item0, ItemNode) and isinstance(item1, NewlineNode): type = decls[item0.name].type - name = 'ans' # item0.name may be randomized if type is not None: if 'YES' in constants and 'NO' in constants and type == VarType.String: - return YesNoOutputType(name=name, yes='YES', no='NO') + return YesNoOutputType(name='ans', yes='YES', no='NO') if 'FIRST' in constants and 'SECOND' in constants and type == VarType.String: - return YesNoOutputType(name=name, yes='FIRST', no='SECOND') - return OneOutputType(name=name, type=type) + return YesNoOutputType(name='ans', yes='FIRST', no='SECOND') + return OneOutputType(name='ans', type=type) # pattern: # x y @@ -87,11 +86,10 @@ def analyze_output_type(*, output_format: FormatNode, output_variables: Dict[str if isinstance(item0, ItemNode) and isinstance(item1, NewlineNode) and isinstance(item3, NewlineNode): if isinstance(item2, LoopNode) and isinstance(item2.body, ItemNode) and item2.size == item0.name and item2.body.indices == [item2.name]: type = decls[item2.body.name].type - name = 'ans' # item2.body.name may be randomized - subscripted_name = _get_variable(decl=decls[name], indices=item2.body.indices, decls=decls) + subscripted_name = _get_variable(decl=decls[item2.body.name], indices=item2.body.indices, decls=decls) counter_name = item2.name if type is not None: - return VectorOutputType(name=name, type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=True, print_newline_after_item=False) + return VectorOutputType(name='ans', type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=True, print_newline_after_item=False) # pattern: # n a_1 ... a_n @@ -102,11 +100,10 @@ def analyze_output_type(*, output_format: FormatNode, output_variables: Dict[str if isinstance(item0, ItemNode) and isinstance(item2, NewlineNode): if isinstance(item1, LoopNode) and isinstance(item1.body, ItemNode) and item1.size == item0.name and item1.body.indices == [item1.name]: type = decls[item1.body.name].type - name = 'ans' # item1.body.name may be randomized - subscripted_name = _get_variable(decl=decls[name], indices=item1.body.indices, decls=decls) + subscripted_name = _get_variable(decl=decls[item1.body.name], indices=item1.body.indices, decls=decls) counter_name = item1.name if type is not None: - return VectorOutputType(name=name, type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=False, print_newline_after_item=False) + return VectorOutputType(name='ans', type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=False, print_newline_after_item=False) # pattern: # n @@ -123,10 +120,9 @@ def analyze_output_type(*, output_format: FormatNode, output_variables: Dict[str item4 = node.items[1] if isinstance(item3, ItemNode) and isinstance(item4, NewlineNode) and item2.size == item0.name and item3.indices == [item0.name]: type = decls[item3.name].type - name = 'ans' # item3.name may be randomized - subscripted_name = _get_variable(decl=decls[name], indices=item3.indices, decls=decls) + subscripted_name = _get_variable(decl=decls[item3.name], indices=item3.indices, decls=decls) counter_name = item2.name if type is not None: - return VectorOutputType(name=name, type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=True, print_newline_after_item=True) + return VectorOutputType(name='ans', type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=True, print_newline_after_item=True) return None From 477c5f18bb181ecd98f1e00bb55af980f89fe409 Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Sat, 20 Jun 2020 03:20:31 +0900 Subject: [PATCH 3/3] Update simple_patterns.py to recognize output formats which depend variables in input formats --- onlinejudge_template/analyzer/combined.py | 7 +- onlinejudge_template/analyzer/match.py | 13 ++- .../analyzer/simple_patterns.py | 88 +++++++++++++++++-- tests/analyzer_combined.py | 36 ++++++++ 4 files changed, 135 insertions(+), 9 deletions(-) create mode 100644 tests/analyzer_combined.py diff --git a/onlinejudge_template/analyzer/combined.py b/onlinejudge_template/analyzer/combined.py index 8b0241b..fa32268 100644 --- a/onlinejudge_template/analyzer/combined.py +++ b/onlinejudge_template/analyzer/combined.py @@ -97,8 +97,11 @@ def run(resources: AnalyzerResources) -> AnalyzerResult: elif topcoder_class_definition is not None: output_format = onlinejudge_template.analyzer.topcoder.convert_topcoder_class_definition_to_output_format(topcoder_class_definition) elif resources.sample_cases: - output_samples = [case.output for case in resources.sample_cases] - output_format = onlinejudge_template.analyzer.simple_patterns.guess_format_with_pattern_matching(instances=output_samples, env=input_variables) + if input_format is not None and input_variables is not None: + output_format = onlinejudge_template.analyzer.simple_patterns.guess_output_format_with_pattern_matching_using_input_format(instances=resources.sample_cases, input_format=input_format, input_variables=input_variables) + else: + output_samples = [case.output for case in resources.sample_cases] + output_format = onlinejudge_template.analyzer.simple_patterns.guess_format_with_pattern_matching(instances=output_samples) # list the variables for output output_variables: Optional[Dict[str, VarDecl]] = None diff --git a/onlinejudge_template/analyzer/match.py b/onlinejudge_template/analyzer/match.py index e33f38f..67b4738 100644 --- a/onlinejudge_template/analyzer/match.py +++ b/onlinejudge_template/analyzer/match.py @@ -125,14 +125,23 @@ def _match_format_dfs(node: FormatNode, tokens: List[str], *, variables: Dict[st assert False -def match_format(node: FormatNode, data: str, *, variables: Dict[str, VarDecl]) -> Dict[str, Dict[Tuple[int, ...], Union[int, float, str]]]: +def match_format( + node: FormatNode, + data: str, + *, + variables: Dict[str, VarDecl], + values: Optional[Dict[str, Dict[Tuple[int, ...], Union[int, float, str]]]] = None, +) -> Dict[str, Dict[Tuple[int, ...], Union[int, float, str]]]: """ :raises FormatMatchError: + :param values: is an optional argument to specify pre-defined variables. """ # prepare buffer - values: Dict[str, Dict[Tuple[int, ...], Union[int, float, str]]] = {} + if values is None: + values = {} for name in variables.keys(): + assert name not in values values[name] = {} # tokenize input diff --git a/onlinejudge_template/analyzer/simple_patterns.py b/onlinejudge_template/analyzer/simple_patterns.py index 3a39c2f..d4cdb6f 100644 --- a/onlinejudge_template/analyzer/simple_patterns.py +++ b/onlinejudge_template/analyzer/simple_patterns.py @@ -263,6 +263,25 @@ def list_all_patterns() -> List[Tuple[FormatNode, Dict[str, VarDecl]]]: return results +def list_output_patterns_depending_input_variable(n: str) -> List[FormatNode]: + """list_output_patterns_depending_input_variable lists output patterns which depend input patterns. + + :param n: is the name of the variable which represents the length of array. + """ + + assert n not in ('ans', 'i') + vector_pattern = SequenceNode(items=[ + LoopNode(name='i', size=n, body=ItemNode(name='ans', indices=['i'])), + NewlineNode(), + ]) + vertical_vector_pattern = LoopNode(name='i', size=n, body=SequenceNode(items=[ + ItemNode(name='ans', indices=['i']), + NewlineNode(), + ])) + all_patterns = [vector_pattern, vertical_vector_pattern] + return all_patterns + + def _rename_variables_if_conflicts_dfs(node: FormatNode, *, mapping: Dict[str, str], env: Dict[str, VarDecl]) -> FormatNode: def rename(s: str) -> str: for a, b in mapping.items(): @@ -303,24 +322,83 @@ def rename_variables_if_conflicts(node: FormatNode, *, env: Dict[str, VarDecl]) return _rename_variables_if_conflicts_dfs(node, mapping={}, env=env) -def guess_format_with_pattern_matching(*, instances: List[bytes], env: Optional[Dict[str, VarDecl]] = None) -> Optional[FormatNode]: +def guess_format_with_pattern_matching(*, instances: List[bytes]) -> Optional[FormatNode]: """guess_format_with_pattern_matching guesses a format tree from the strings which match with the format tree, i.e. sample cases. :param instances: are sample cases. - :param env: is the dict which contains variables already defined. """ - if env is None: - env = {} found: List[FormatNode] = [] + + # patterns without variables in the input format for pattern, variables in list_all_patterns(): + pattern = rename_variables_if_conflicts(pattern, env={}) try: for data in instances: match_format(pattern, data.decode(), variables=variables) + except FormatMatchError: + pass + else: + logger.debug('simple pattern found: %s', pattern) found.append(pattern) + + if len(found) == 1: + return found[0] + else: + return None + + +def guess_output_format_with_pattern_matching_using_input_format(*, instances: List[SampleCase], input_format: FormatNode, input_variables: Dict[str, VarDecl]) -> Optional[FormatNode]: + """guess_output_format_with_pattern_matching_using_input_format + + :param instances: are sample cases. + :param input_format: + :param input_variables: + """ + + found: List[FormatNode] = [] + + # patterns without variables in the input format + for pattern, variables in list_all_patterns(): + try: + for data in instances: + match_format(pattern, data.output.decode(), variables=variables) except FormatMatchError: pass + else: + pattern = rename_variables_if_conflicts(pattern, env=input_variables) + logger.debug('simple output pattern found without input variables: %s', pattern) + found.append(pattern) + + # patterns with variables in the input format + for name in ('n', 'N', 'm', 'M', 't', 'T'): + if name in input_variables and input_variables[name].type in (VarType.IndexInt, VarType.ValueInt): + env = dict(input_variables) + env.pop(name) + for pattern in list_output_patterns_depending_input_variable(name): + + # prepare pattern + pattern = rename_variables_if_conflicts(pattern, env=env) + try: + variables = onlinejudge_template.analyzer.variables.list_declared_variables(pattern) + except onlinejudge_template.analyzer.variables.DeclaredVariablesError: + assert False + assert name not in variables + + # try matching + try: + for data in instances: + input_values = match_format(input_format, data.input.decode(), variables=input_variables) + values = {name: input_values[name]} # hide variables other than the `name` + match_format(pattern, data.output.decode(), variables=variables, values=values) + except FormatMatchError as e: + logger.exception(e) + pass + else: + logger.debug('simple output pattern found with input variables: %s', pattern) + found.append(pattern) + if len(found) == 1: - return rename_variables_if_conflicts(found[0], env=env) + return found[0] else: return None diff --git a/tests/analyzer_combined.py b/tests/analyzer_combined.py new file mode 100644 index 0000000..56ad5d2 --- /dev/null +++ b/tests/analyzer_combined.py @@ -0,0 +1,36 @@ +import unittest + +import onlinejudge_template.analyzer.combined as analyzer +from onlinejudge_template.types import * + + +class TestAnalyzerCombined(unittest.TestCase): + """TestAnalyzerCombinedCodeforces is a class for integration tests about analyzers (without network access). + """ + def test_output_format_depending_input_format(self) -> None: + resources = AnalyzerResources( + url='https://atcoder.jp/contests/arc093/tasks/arc093_a', + html=b'...skipped...', + input_format_string='N\r\nA_1 A_2 ... A_N\r\n', + output_format_string=None, + sample_cases=[ + SampleCase(input=b'3\n3 5 -1\n', output=b'12\n8\n10\n'), + SampleCase(input=b'5\n1 1 1 2 0\n', output=b'4\n4\n4\n2\n4\n'), + SampleCase(input=b'6\n-679 -2409 -3258 3095 -3291 -4462\n', output=b'21630\n21630\n19932\n8924\n21630\n19288\n'), + ], + ) + + input_format = SequenceNode(items=[ + ItemNode(indices=[], name='N'), + NewlineNode(), + LoopNode(body=ItemNode(indices=['i + 1'], name='A'), name='i', size='N'), + NewlineNode(), + ]) + output_format = LoopNode(body=SequenceNode(items=[ + ItemNode(indices=['i'], name='ans'), + NewlineNode(), + ]), name='i', size='N') + + analyzed = analyzer.run(resources) + self.assertEqual(str(analyzed.input_format), str(input_format)) + self.assertEqual(str(analyzed.output_format), str(output_format))