Skip to content

Commit

Permalink
Merge pull request #29 from online-judge-tools/feature/improve-sample…
Browse files Browse the repository at this point in the history
…-guessing

Update simple_patterns.py to recognize output formats which depend variables in input formats
  • Loading branch information
kmyk authored Jun 19, 2020
2 parents 572f2e9 + 477c5f1 commit e03df1e
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 41 deletions.
37 changes: 23 additions & 14 deletions onlinejudge_template/analyzer/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def run(resources: AnalyzerResources) -> AnalyzerResult:
if resources.html is not None:
topcoder_class_definition = onlinejudge_template.analyzer.topcoder.parse_topcoder_class_definition(resources.html, url=resources.url)

# parse the format tree for input
input_format: Optional[FormatNode] = None
if resources.input_format_string is not None:
try:
Expand All @@ -65,20 +66,7 @@ def run(resources: AnalyzerResources) -> AnalyzerResult:
input_samples = [case.input for case in resources.sample_cases]
input_format = onlinejudge_template.analyzer.simple_patterns.guess_format_with_pattern_matching(instances=input_samples)

output_format: Optional[FormatNode] = None
if resources.output_format_string is not None:
try:
output_format = onlinejudge_template.analyzer.parser.run(resources.output_format_string)
except AnalyzerError as e:
logger.error('output analyzer failed: %s', e)
except NotImplementedError as e:
logger.error('output analyzer failed: %s', e)
elif topcoder_class_definition is not None:
output_format = onlinejudge_template.analyzer.topcoder.convert_topcoder_class_definition_to_output_format(topcoder_class_definition)
elif resources.sample_cases:
output_samples = [case.output for case in resources.sample_cases]
output_format = onlinejudge_template.analyzer.simple_patterns.guess_format_with_pattern_matching(instances=output_samples)

# list the variables for input
input_variables: Optional[Dict[str, VarDecl]] = None
if resources.input_format_string is None and topcoder_class_definition is not None:
input_variables = onlinejudge_template.analyzer.topcoder.convert_topcoder_class_definition_to_input_variables(topcoder_class_definition)
Expand All @@ -97,6 +85,25 @@ def run(resources: AnalyzerResources) -> AnalyzerResult:
except AnalyzerError as e:
logger.error('input analyzer failed: %s', e)

# parse the format tree for output
output_format: Optional[FormatNode] = None
if resources.output_format_string is not None:
try:
output_format = onlinejudge_template.analyzer.parser.run(resources.output_format_string)
except AnalyzerError as e:
logger.error('output analyzer failed: %s', e)
except NotImplementedError as e:
logger.error('output analyzer failed: %s', e)
elif topcoder_class_definition is not None:
output_format = onlinejudge_template.analyzer.topcoder.convert_topcoder_class_definition_to_output_format(topcoder_class_definition)
elif resources.sample_cases:
if input_format is not None and input_variables is not None:
output_format = onlinejudge_template.analyzer.simple_patterns.guess_output_format_with_pattern_matching_using_input_format(instances=resources.sample_cases, input_format=input_format, input_variables=input_variables)
else:
output_samples = [case.output for case in resources.sample_cases]
output_format = onlinejudge_template.analyzer.simple_patterns.guess_format_with_pattern_matching(instances=output_samples)

# list the variables for output
output_variables: Optional[Dict[str, VarDecl]] = None
if resources.output_format_string is None and topcoder_class_definition is not None:
output_variables = onlinejudge_template.analyzer.topcoder.convert_topcoder_class_definition_to_output_variables(topcoder_class_definition)
Expand All @@ -115,10 +122,12 @@ def run(resources: AnalyzerResources) -> AnalyzerResult:
except AnalyzerError as e:
logger.error('output analyzer failed: %s', e)

# list constants
constants: Dict[str, ConstantDecl] = {}
if resources.html is not None or resources.sample_cases:
constants.update(onlinejudge_template.analyzer.constants.list_constants(html=resources.html, sample_cases=resources.sample_cases))

# simplify the output format
output_type: Optional[OutputType] = None
if output_format is not None and output_variables is not None:
output_type = onlinejudge_template.analyzer.output_types.analyze_output_type(output_format=output_format, output_variables=output_variables, constants=constants)
Expand Down
13 changes: 11 additions & 2 deletions onlinejudge_template/analyzer/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,23 @@ def _match_format_dfs(node: FormatNode, tokens: List[str], *, variables: Dict[st
assert False


def match_format(node: FormatNode, data: str, *, variables: Dict[str, VarDecl]) -> Dict[str, Dict[Tuple[int, ...], Union[int, float, str]]]:
def match_format(
node: FormatNode,
data: str,
*,
variables: Dict[str, VarDecl],
values: Optional[Dict[str, Dict[Tuple[int, ...], Union[int, float, str]]]] = None,
) -> Dict[str, Dict[Tuple[int, ...], Union[int, float, str]]]:
"""
:raises FormatMatchError:
:param values: is an optional argument to specify pre-defined variables.
"""

# prepare buffer
values: Dict[str, Dict[Tuple[int, ...], Union[int, float, str]]] = {}
if values is None:
values = {}
for name in variables.keys():
assert name not in values
values[name] = {}

# tokenize input
Expand Down
22 changes: 9 additions & 13 deletions onlinejudge_template/analyzer/output_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,12 @@ def analyze_output_type(*, output_format: FormatNode, output_variables: Dict[str
item1 = node.items[1]
if isinstance(item0, ItemNode) and isinstance(item1, NewlineNode):
type = decls[item0.name].type
name = 'ans' # item0.name may be randomized
if type is not None:
if 'YES' in constants and 'NO' in constants and type == VarType.String:
return YesNoOutputType(name=name, yes='YES', no='NO')
return YesNoOutputType(name='ans', yes='YES', no='NO')
if 'FIRST' in constants and 'SECOND' in constants and type == VarType.String:
return YesNoOutputType(name=name, yes='FIRST', no='SECOND')
return OneOutputType(name=name, type=type)
return YesNoOutputType(name='ans', yes='FIRST', no='SECOND')
return OneOutputType(name='ans', type=type)

# pattern:
# x y
Expand Down Expand Up @@ -87,11 +86,10 @@ def analyze_output_type(*, output_format: FormatNode, output_variables: Dict[str
if isinstance(item0, ItemNode) and isinstance(item1, NewlineNode) and isinstance(item3, NewlineNode):
if isinstance(item2, LoopNode) and isinstance(item2.body, ItemNode) and item2.size == item0.name and item2.body.indices == [item2.name]:
type = decls[item2.body.name].type
name = 'ans' # item2.body.name may be randomized
subscripted_name = _get_variable(decl=decls[name], indices=item2.body.indices, decls=decls)
subscripted_name = _get_variable(decl=decls[item2.body.name], indices=item2.body.indices, decls=decls)
counter_name = item2.name
if type is not None:
return VectorOutputType(name=name, type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=True, print_newline_after_item=False)
return VectorOutputType(name='ans', type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=True, print_newline_after_item=False)

# pattern:
# n a_1 ... a_n
Expand All @@ -102,11 +100,10 @@ def analyze_output_type(*, output_format: FormatNode, output_variables: Dict[str
if isinstance(item0, ItemNode) and isinstance(item2, NewlineNode):
if isinstance(item1, LoopNode) and isinstance(item1.body, ItemNode) and item1.size == item0.name and item1.body.indices == [item1.name]:
type = decls[item1.body.name].type
name = 'ans' # item1.body.name may be randomized
subscripted_name = _get_variable(decl=decls[name], indices=item1.body.indices, decls=decls)
subscripted_name = _get_variable(decl=decls[item1.body.name], indices=item1.body.indices, decls=decls)
counter_name = item1.name
if type is not None:
return VectorOutputType(name=name, type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=False, print_newline_after_item=False)
return VectorOutputType(name='ans', type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=False, print_newline_after_item=False)

# pattern:
# n
Expand All @@ -123,10 +120,9 @@ def analyze_output_type(*, output_format: FormatNode, output_variables: Dict[str
item4 = node.items[1]
if isinstance(item3, ItemNode) and isinstance(item4, NewlineNode) and item2.size == item0.name and item3.indices == [item0.name]:
type = decls[item3.name].type
name = 'ans' # item3.name may be randomized
subscripted_name = _get_variable(decl=decls[name], indices=item3.indices, decls=decls)
subscripted_name = _get_variable(decl=decls[item3.name], indices=item3.indices, decls=decls)
counter_name = item2.name
if type is not None:
return VectorOutputType(name=name, type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=True, print_newline_after_item=True)
return VectorOutputType(name='ans', type=type, subscripted_name=subscripted_name, counter_name=counter_name, print_size=True, print_newline_after_size=True, print_newline_after_item=True)

return None
122 changes: 110 additions & 12 deletions onlinejudge_template/analyzer/simple_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
"""

import functools
import random
import itertools
import re
import string
from logging import getLogger
from typing import *

Expand Down Expand Up @@ -229,6 +228,9 @@ def _make_tree_pattern_dfs(node: FormatNode) -> Tuple[FormatNode, bool]:


def _make_tree_patterns(patterns: List[FormatNode]) -> List[FormatNode]:
"""_make_tree_patterns detects patterns which have the variable `n` and arrays with lentgh `n`, and replaces the length of arrays with `n - 1`.
"""

tree_patterns = []
for pattern in patterns:
pattern, replaced = _make_tree_pattern_dfs(pattern)
Expand All @@ -239,6 +241,9 @@ def _make_tree_patterns(patterns: List[FormatNode]) -> List[FormatNode]:

@functools.lru_cache(maxsize=None)
def list_all_patterns() -> List[Tuple[FormatNode, Dict[str, VarDecl]]]:
"""list_all_patterns lists all pre-defined petterns.
"""

patterns: List[FormatNode] = [
*_simple_patterns,
*_vertical_simple_patterns,
Expand All @@ -258,49 +263,142 @@ def list_all_patterns() -> List[Tuple[FormatNode, Dict[str, VarDecl]]]:
return results


def _randomize_variables_dfs(node: FormatNode, *, mapping: Dict[str, str]) -> FormatNode:
def list_output_patterns_depending_input_variable(n: str) -> List[FormatNode]:
"""list_output_patterns_depending_input_variable lists output patterns which depend input patterns.
:param n: is the name of the variable which represents the length of array.
"""

assert n not in ('ans', 'i')
vector_pattern = SequenceNode(items=[
LoopNode(name='i', size=n, body=ItemNode(name='ans', indices=['i'])),
NewlineNode(),
])
vertical_vector_pattern = LoopNode(name='i', size=n, body=SequenceNode(items=[
ItemNode(name='ans', indices=['i']),
NewlineNode(),
]))
all_patterns = [vector_pattern, vertical_vector_pattern]
return all_patterns


def _rename_variables_if_conflicts_dfs(node: FormatNode, *, mapping: Dict[str, str], env: Dict[str, VarDecl]) -> FormatNode:
def rename(s: str) -> str:
for a, b in mapping.items():
s = re.sub(r'\b' + re.escape(a) + r'\b', b, s)
return s

if isinstance(node, ItemNode):
assert node.name not in mapping
mapping[node.name] = random.choice(string.ascii_lowercase) + random.choice(string.ascii_lowercase) + random.choice(string.ascii_lowercase)
indices = list(map(rename, node.indices))
return ItemNode(name=mapping[node.name], indices=indices)
assert node.name not in mapping # because there are only such patterns
if node.name not in env:
return node
else:
for i in itertools.count(1):
new_name = node.name + str(i)
if new_name not in env:
mapping[node.name] = new_name
break
indices = list(map(rename, node.indices))
return ItemNode(name=mapping[node.name], indices=indices)

elif isinstance(node, NewlineNode):
return node

elif isinstance(node, SequenceNode):
items: List[FormatNode] = []
for item in node.items:
items.append(_randomize_variables_dfs(item, mapping=mapping))
items.append(_rename_variables_if_conflicts_dfs(item, mapping=mapping, env=env))
return SequenceNode(items=items)

elif isinstance(node, LoopNode):
body = _randomize_variables_dfs(node.body, mapping=mapping)
body = _rename_variables_if_conflicts_dfs(node.body, mapping=mapping, env=env)
return LoopNode(name=node.name, size=rename(node.size), body=body)

else:
assert False


def randomize_variables(node: FormatNode) -> FormatNode:
return _randomize_variables_dfs(node, mapping={})
def rename_variables_if_conflicts(node: FormatNode, *, env: Dict[str, VarDecl]) -> FormatNode:
return _rename_variables_if_conflicts_dfs(node, mapping={}, env=env)


def guess_format_with_pattern_matching(*, instances: List[bytes]) -> Optional[FormatNode]:
"""guess_format_with_pattern_matching guesses a format tree from the strings which match with the format tree, i.e. sample cases.
:param instances: are sample cases.
"""

found: List[FormatNode] = []

# patterns without variables in the input format
for pattern, variables in list_all_patterns():
pattern = rename_variables_if_conflicts(pattern, env={})
try:
for data in instances:
match_format(pattern, data.decode(), variables=variables)
except FormatMatchError:
pass
else:
logger.debug('simple pattern found: %s', pattern)
found.append(pattern)

if len(found) == 1:
return found[0]
else:
return None


def guess_output_format_with_pattern_matching_using_input_format(*, instances: List[SampleCase], input_format: FormatNode, input_variables: Dict[str, VarDecl]) -> Optional[FormatNode]:
"""guess_output_format_with_pattern_matching_using_input_format
:param instances: are sample cases.
:param input_format:
:param input_variables:
"""

found: List[FormatNode] = []

# patterns without variables in the input format
for pattern, variables in list_all_patterns():
try:
for data in instances:
match_format(pattern, data.output.decode(), variables=variables)
except FormatMatchError:
pass
else:
pattern = rename_variables_if_conflicts(pattern, env=input_variables)
logger.debug('simple output pattern found without input variables: %s', pattern)
found.append(pattern)

# patterns with variables in the input format
for name in ('n', 'N', 'm', 'M', 't', 'T'):
if name in input_variables and input_variables[name].type in (VarType.IndexInt, VarType.ValueInt):
env = dict(input_variables)
env.pop(name)
for pattern in list_output_patterns_depending_input_variable(name):

# prepare pattern
pattern = rename_variables_if_conflicts(pattern, env=env)
try:
variables = onlinejudge_template.analyzer.variables.list_declared_variables(pattern)
except onlinejudge_template.analyzer.variables.DeclaredVariablesError:
assert False
assert name not in variables

# try matching
try:
for data in instances:
input_values = match_format(input_format, data.input.decode(), variables=input_variables)
values = {name: input_values[name]} # hide variables other than the `name`
match_format(pattern, data.output.decode(), variables=variables, values=values)
except FormatMatchError as e:
logger.exception(e)
pass
else:
logger.debug('simple output pattern found with input variables: %s', pattern)
found.append(pattern)

if len(found) == 1:
return randomize_variables(found[0])
return found[0]
else:
return None
36 changes: 36 additions & 0 deletions tests/analyzer_combined.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import unittest

import onlinejudge_template.analyzer.combined as analyzer
from onlinejudge_template.types import *


class TestAnalyzerCombined(unittest.TestCase):
"""TestAnalyzerCombinedCodeforces is a class for integration tests about analyzers (without network access).
"""
def test_output_format_depending_input_format(self) -> None:
resources = AnalyzerResources(
url='https://atcoder.jp/contests/arc093/tasks/arc093_a',
html=b'...skipped...',
input_format_string='N\r\nA_1 A_2 ... A_N\r\n',
output_format_string=None,
sample_cases=[
SampleCase(input=b'3\n3 5 -1\n', output=b'12\n8\n10\n'),
SampleCase(input=b'5\n1 1 1 2 0\n', output=b'4\n4\n4\n2\n4\n'),
SampleCase(input=b'6\n-679 -2409 -3258 3095 -3291 -4462\n', output=b'21630\n21630\n19932\n8924\n21630\n19288\n'),
],
)

input_format = SequenceNode(items=[
ItemNode(indices=[], name='N'),
NewlineNode(),
LoopNode(body=ItemNode(indices=['i + 1'], name='A'), name='i', size='N'),
NewlineNode(),
])
output_format = LoopNode(body=SequenceNode(items=[
ItemNode(indices=['i'], name='ans'),
NewlineNode(),
]), name='i', size='N')

analyzed = analyzer.run(resources)
self.assertEqual(str(analyzed.input_format), str(input_format))
self.assertEqual(str(analyzed.output_format), str(output_format))

0 comments on commit e03df1e

Please sign in to comment.