Skip to content

Commit

Permalink
fix some rules engine bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
azuline committed Nov 1, 2023
1 parent 052ec37 commit d05bcc4
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 52 deletions.
2 changes: 1 addition & 1 deletion rose/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def metadata() -> None:
@click.pass_obj
def run_stored_rules(ctx: Context, yes: bool) -> None:
"""Run the metadata rules stored in the config"""
execute_stored_metadata_rules(ctx.config, confirm_yes=yes)
execute_stored_metadata_rules(ctx.config, confirm_yes=not yes)


@cli.command()
Expand Down
33 changes: 29 additions & 4 deletions rose/rule_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ class InvalidRuleSpecError(RoseError):
]


SINGLE_VALUE_TAGS: list[Tag] = [
"tracktitle",
"year",
"tracknumber",
"discnumber",
"albumtitle",
"releasetype",
]

MULTI_VALUE_TAGS: list[Tag] = [
"genre",
"label",
"artist",
]


@dataclass
class ReplaceAction:
"""
Expand Down Expand Up @@ -138,12 +154,10 @@ def parse_dict(cls, data: dict[str, Any]) -> MetadataRule:
f"Key `tags` must be a string or a list of strings: got {type(tags)}"
)
for t in tags:
if t not in ALL_TAGS and t != "*":
if t not in ALL_TAGS:
raise InvalidRuleSpecError(
f"Key `tags`'s values must be one of *, {', '.join(ALL_TAGS)}: got {t}"
f"Key `tags`'s values must be one of {', '.join(ALL_TAGS)}: got {t}"
)
if any(t == "*" for t in tags):
tags = ALL_TAGS

try:
matcher = data["matcher"]
Expand Down Expand Up @@ -228,6 +242,17 @@ def parse_dict(cls, data: dict[str, Any]) -> MetadataRule:
f"got {action_kind}"
)

# Validate that the action kind and tags are acceptable. Mainly that we are not calling
# `replaceall` and `splitall` on single-valued tags.
multi_value_action = action_kind == "replaceall" or action_kind == "spliton"
if multi_value_action:
single_valued_tags = [t for t in tags if t in SINGLE_VALUE_TAGS]
if single_valued_tags:
raise InvalidRuleSpecError(
f"Single valued tags {', '.join(single_valued_tags)} cannot be modified by "
f"multi-value action {action_kind}"
)

return cls(
tags=tags,
matcher=matcher,
Expand Down
22 changes: 4 additions & 18 deletions rose/rule_parser_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import re

from rose.rule_parser import (
ALL_TAGS,
DeleteAction,
MetadataRule,
ReplaceAction,
Expand Down Expand Up @@ -54,28 +53,15 @@ def test_rule_parser() -> None:
action=ReplaceAction(replacement="hihi"),
)

# Test all tag expansion
assert MetadataRule.parse_dict(
{
"tags": "*",
"matcher": "lala",
"action": {"kind": "replace", "replacement": "hihi"},
}
) == MetadataRule(
tags=ALL_TAGS,
matcher="lala",
action=ReplaceAction(replacement="hihi"),
)

# Test replaceall
assert MetadataRule.parse_dict(
{
"tags": "tracktitle",
"tags": "genre",
"matcher": "lala",
"action": {"kind": "replaceall", "replacement": ["hihi"]},
}
) == MetadataRule(
tags=["tracktitle"],
tags=["genre"],
matcher="lala",
action=ReplaceAllAction(replacement=["hihi"]),
)
Expand All @@ -96,12 +82,12 @@ def test_rule_parser() -> None:
# Test spliton
assert MetadataRule.parse_dict(
{
"tags": "tracktitle",
"tags": "genre",
"matcher": "lala",
"action": {"kind": "spliton", "delimiter": "."},
}
) == MetadataRule(
tags=["tracktitle"],
tags=["genre"],
matcher="lala",
action=SplitAction(delimiter="."),
)
Expand Down
14 changes: 7 additions & 7 deletions rose/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def matches_rule(x: str) -> bool:
LEFT JOIN releases_labels rl ON rg.release_id = r.id
LEFT JOIN releases_artists ra ON ra.release_id = r.id
LEFT JOIN tracks_artists ta ON ta.track_id = t.id
WHERE 1=1
WHERE false
"""
args: list[str] = []
for field in rule.tags:
Expand All @@ -123,19 +123,17 @@ def matches_rule(x: str) -> bool:
if field == "releasetype":
query += r" OR r.release_type LIKE ? ESCAPE '\'"
args.append(matchsql)
# For genres, labels, and artists, because SQLite lacks arrays, we create a string like
# `\\ val1 \\ val2 \\` and match on `\\ {matcher} \\`.
if field == "genre":
query += r" OR rg.genre LIKE ? ESCAPE '\'"
args.append(rf" \\ {matchsql} \\ ")
args.append(matchsql)
if field == "label":
query += r" OR rl.label LIKE ? ESCAPE '\'"
args.append(rf" \\ {matchsql} \\ ")
args.append(matchsql)
if field == "artist":
query += r" OR ra.artist LIKE ? ESCAPE '\'"
args.append(rf" \\ {matchsql} \\ ")
args.append(matchsql)
query += r" OR ta.artist LIKE ? ESCAPE '\'"
args.append(rf" \\ {matchsql} \\ ")
args.append(matchsql)
query += " ORDER BY t.source_path"
logger.debug(f"Constructed matching query {query} with args {args}")
# And then execute the SQL query. Note that we don't pull the tag values here. This query is
Expand All @@ -144,6 +142,8 @@ def matches_rule(x: str) -> bool:
with connect(c) as conn:
track_paths = [Path(row["source_path"]).resolve() for row in conn.execute(query, args)]
logger.debug(f"Matched {len(track_paths)} tracks from the read cache")
if not track_paths:
return

# Factor out the logic for executing an action on a single-value tag and a multi-value tag.
def execute_single_action(value: str | None) -> str | None:
Expand Down
57 changes: 35 additions & 22 deletions rose/rules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,35 +107,48 @@ def test_rules_execution_match_superstrict(config: Config, source_dir: Path) ->
assert af.title == "lalala"


def test_all_fields_match(config: Config, source_dir: Path) -> None:
@pytest.mark.parametrize(
"tag",
[
"year",
"tracktitle",
"tracknumber",
"discnumber",
"albumtitle",
"genre",
"label",
"artist",
],
)
def test_all_non_enum_fields_match(config: Config, source_dir: Path, tag: str) -> None:
# Test most fields.
rule = MetadataRule(
tags=[
"year",
"tracktitle",
"tracknumber",
"discnumber",
"albumtitle",
"genre",
"label",
"artist",
],
tags=[tag], # type: ignore
matcher="", # Empty string matches everything.
action=ReplaceAction(replacement="8"),
)
execute_metadata_rule(config, rule, False)
af = AudioTags.from_file(source_dir / "Test Release 1" / "01.m4a")
assert af.title == "8"
assert af.year == 8
assert af.track_number == "8"
assert af.disc_number == "8"
assert af.album == "8"
assert af.genre == ["8", "8"]
assert af.label == ["8"]
assert af.album_artists.main == ["8"]
assert af.artists.main == ["8"]

# And then test release type separately.
if tag == "tracktitle":
assert af.title == "8"
if tag == "year":
assert af.year == 8
if tag == "tracknumber":
assert af.track_number == "8"
if tag == "discnumber":
assert af.disc_number == "8"
if tag == "albumtitle":
assert af.album == "8"
if tag == "genre":
assert af.genre == ["8", "8"]
if tag == "label":
assert af.label == ["8"]
if tag == "artist":
assert af.album_artists.main == ["8"]
assert af.artists.main == ["8"]


def test_releasetype_matches(config: Config, source_dir: Path) -> None:
rule = MetadataRule(
tags=["releasetype"],
matcher="", # Empty string matches everything.
Expand Down

0 comments on commit d05bcc4

Please sign in to comment.