Skip to content

Commit

Permalink
Merge pull request #140 from natefoo/lint-warn-cores-no-mem
Browse files Browse the repository at this point in the history
Add lint warning for tools with `cores` and no `mem`
  • Loading branch information
nuwang authored Nov 18, 2024
2 parents dda0acb + 78f9fb2 commit d835738
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ tools:
rank: |
helpers.weighted_random_sampling(candidate_destinations)
toolshed.g2.bx.psu.edu/repos/iuc/mothur_shhh_seqs/mothur_shhh_seqs/.*:
# This is a comment
inherits: wig_to_bigWig
cores: 2
mem: 20
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ tools:
wig_to_bigWig:
mem: 10
toolshed.g2.bx.psu.edu/repos/iuc/mothur_shhh_seqs/mothur_shhh_seqs/.*:
# This is a comment
cores: 2
mem: 20
inherits: wig_to_bigWig
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/linter/linter-invalid-regex.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ tools:
cores: 2
params:
native_spec: "--mem {mem} --cores {cores} --gpus {gpus}"
bwa[0-9]++:
bwa[0-9]^++:
gpus: 2

destinations:
Expand Down
29 changes: 29 additions & 0 deletions tests/fixtures/linter/linter-warnings.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
global:
default_inherits: default

tools:
default:
abstract: true
cores: 2
mem: 4
params:
native_spec: "--mem {mem} --cores {cores} --gpus {gpus}"
mem-no-cores-1:
mem: 16
cores-no-mem-1:
cores: 8
cores-no-mem-2:
# noqa: T102
cores: 8
cores-no-mem-3:
# noqa
cores: 8

destinations:
local:
runner: local
max_accepted_cores: 4
max_accepted_mem: 16
scheduling:
prefer:
- general
17 changes: 17 additions & 0 deletions tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,23 @@ def test_lint_destination_defines_cores_instead_of_accepted_cores(self):
"working_dest" not in output,
f"Did not expect destination: `working_dest` to be in the output, but found: {output}")

def test_lint_warnings(self):
    """Check that T102 is reported and can be silenced per tool or globally.

    The fixture declares one tool that should trigger T102 and two tools
    whose T102 warning is silenced by a ``# noqa`` comment (one with an
    explicit code, one without). A second run passes ``--ignore=T102`` and
    expects no T102 warning at all.
    """
    fixture_path = os.path.join(os.path.dirname(__file__), 'fixtures/linter/linter-warnings.yml')

    # First run: no --ignore flag, so only the noqa comments may suppress.
    output = self.call_shell_command("tpv", "-vv", "lint", fixture_path)
    self.assertTrue(
        "T102: The tool named: cores-no-mem-1 sets `cores`" in output,
        f"Expected T102 warning for cores-no-mem-1 but output was: {output}")
    for silenced_tool in ("cores-no-mem-2", "cores-no-mem-3"):
        self.assertFalse(
            f"T102: The tool named: {silenced_tool} sets `cores`" in output,
            f"T102 warning for {silenced_tool} should be suppressed by noqa but output was: {output}")

    # Second run: the command-line flag must silence T102 for every tool.
    output = self.call_shell_command("tpv", "-vv", "lint", "--ignore=T102", fixture_path)
    self.assertFalse(
        "T102: The tool named:" in output,
        f"T102 warnings should be suppressed by --ignore but output was: {output}")

def test_warn_if_default_inherits_not_marked_abstract(self):
tpv_config = os.path.join(os.path.dirname(__file__),
'fixtures/linter/linter-default-inherits-marked-abstract.yml')
Expand Down
26 changes: 19 additions & 7 deletions tpv/commands/formatter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations
import logging

from ruamel.yaml.comments import CommentedMap, CommentedSeq

from tpv.core import util

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -64,14 +66,24 @@ def multi_level_dict_sorter(dict_to_sort, sort_order):
"""
if not sort_order:
return dict_to_sort
if isinstance(dict_to_sort, dict):
if isinstance(dict_to_sort, CommentedMap):
sorted_keys = sorted(dict_to_sort or [], key=TPVConfigFormatter.generic_key_sorter(sort_order.keys()))
return {key: TPVConfigFormatter.multi_level_dict_sorter(dict_to_sort.get(key),
sort_order.get(key, {}) or sort_order.get('*', {}))
for key in sorted_keys}
elif isinstance(dict_to_sort, list):
return [TPVConfigFormatter.multi_level_dict_sorter(item, sort_order.get('*', []))
for item in dict_to_sort]
rval = CommentedMap()
for key in sorted_keys:
sorted_value = TPVConfigFormatter.multi_level_dict_sorter(
dict_to_sort.get(key),
sort_order.get(key, {}) or sort_order.get('*', {})
)
rval[key] = sorted_value
rval.ca.items.update(dict_to_sort.ca.items)
return rval
elif isinstance(dict_to_sort, CommentedSeq):
rval = CommentedSeq()
for item in dict_to_sort:
sorted_item = TPVConfigFormatter.multi_level_dict_sorter(item, sort_order.get('*', []))
rval.append(sorted_item)
rval.ca.items.update(dict_to_sort.ca.items)
return rval
else:
return dict_to_sort

Expand Down
53 changes: 39 additions & 14 deletions tpv/commands/linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,58 @@
log = logging.getLogger(__name__)


# Warning codes:
# T101: default inheritance not marked abstract
# T102: entity specifies cores without memory


class TPVLintError(Exception):
    """Raised when a TPV config cannot be loaded for linting or lint errors are found."""


class TPVConfigLinter(object):

def __init__(self, url_or_path):
def __init__(self, url_or_path, ignore):
self.url_or_path = url_or_path
self.ignore = ignore or []
self.warnings = []
self.errors = []
self.loader = None

def lint(self):
def load_config(self):
try:
loader = TPVConfigLoader.from_url_or_path(self.url_or_path)
self.loader = TPVConfigLoader.from_url_or_path(self.url_or_path)
except Exception as e:
log.error(f"Linting failed due to syntax errors in yaml file: {e}")
raise TPVLintError("Linting failed due to syntax errors in yaml file: ") from e
default_inherits = loader.global_settings.get('default_inherits')
for tool_regex, tool in loader.tools.items():

def add_warning(self, entity, code, message):
    """Record a lint warning unless it has been silenced.

    The warning is dropped when ``code`` is listed in ``self.ignore`` or when
    the loader reports a matching noqa marker for ``entity``. Otherwise the
    ``(code, message)`` pair is appended to ``self.warnings``.
    """
    # Globally ignored codes are filtered first; the loader is not consulted.
    if code in self.ignore:
        return
    if self.loader.check_noqa(entity, code):
        return
    self.warnings.append((code, message))

def lint(self):
if self.loader is None:
self.load_config()
default_inherits = self.loader.global_settings.get('default_inherits')
for tool_regex, tool in self.loader.tools.items():
try:
re.compile(tool_regex)
except re.error:
self.errors.append(f"Failed to compile regex: {tool_regex}")
if default_inherits == tool.id:
self.warnings.append(
if default_inherits == tool.id and not tool.abstract:
self.add_warning(
tool,
"T101",
f"The tool named: {default_inherits} is marked globally as the tool to inherit from "
"by default. You may want to mark it as abstract if it is not an actual tool and it "
"will be excluded from scheduling decisions.")
for destination in loader.destinations.values():
if tool.cores and not tool.mem:
self.add_warning(
tool,
"T102",
f"The tool named: {tool_regex} sets `cores` but not `mem`. This can lead to "
"unexpected memory usage since memory is typically a multiplier of cores.")
for destination in self.loader.destinations.values():
if not destination.runner and not destination.abstract:
self.errors.append(f"Destination '{destination.id}' does not define the runner parameter. "
"The runner parameter is mandatory.")
Expand All @@ -46,19 +69,21 @@ def lint(self):
f"max_accepted_cores/mem/gpus. This is probably an error. If you're migrating from an older "
f"version of TPV, the destination properties for cores/mem/gpus have been superseded by the "
f"max_accepted_cores/mem/gpus property. Simply renaming them will give you the same functionality.")
if default_inherits == destination.id:
self.warnings.append(
if default_inherits == destination.id and not destination.abstract:
self.add_warning(
destination,
"T101",
f"The destination named: {default_inherits} is marked globally as the destination to inherit from "
"by default. You may want to mark it as abstract if it is not meant to be dispatched to, and it "
"will be excluded from scheduling decisions.")
if self.warnings:
for w in self.warnings:
log.warning(w)
for code, message in self.warnings:
log.warning(f"{code}: {message}")
if self.errors:
for e in self.errors:
log.error(e)
raise TPVLintError(f"The following errors occurred during linting: {self.errors}")

@staticmethod
def from_url_or_path(url_or_path: str):
return TPVConfigLinter(url_or_path)
def from_url_or_path(url_or_path: str, ignore=None):
return TPVConfigLinter(url_or_path, ignore=ignore)
13 changes: 10 additions & 3 deletions tpv/commands/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ def repr_none(dumper: RoundTripRepresenter, data):

def tpv_lint_config_file(args):
try:
tpv_linter = TPVConfigLinter.from_url_or_path(args.config)
ignore = []
if args.ignore is not None:
ignore = [x.strip() for x in args.ignore.split(",")]
tpv_linter = TPVConfigLinter.from_url_or_path(args.config, ignore)
tpv_linter.lint()
log.info("lint successful.")
return 0
Expand Down Expand Up @@ -74,6 +77,9 @@ def create_parser():
'lint',
help='loads a TPV configuration file and checks it for syntax errors',
description="The linter will check yaml syntax and compile python code blocks")
lint_parser.add_argument(
'--ignore', type=str,
help="Comma-separated list of lint error and warning codes to ignore")
lint_parser.add_argument(
'config', type=str,
help="Path to the TPV config file to lint. Can be a local path or http url.")
Expand Down Expand Up @@ -120,14 +126,15 @@ def configure_logging(verbosity_count):
# or basicConfig persists
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
level = max(4 - verbosity_count, 1) * 10
# set global logging level
logging.basicConfig(
stream=sys.stdout,
level=logging.DEBUG if verbosity_count > 3 else logging.ERROR,
level=level,
format='%(levelname)-5s: %(name)s: %(message)s')
# Set client log level
if verbosity_count:
log.setLevel(max(4 - verbosity_count, 1) * 10)
log.setLevel(level)
else:
log.setLevel(logging.INFO)

Expand Down
55 changes: 50 additions & 5 deletions tpv/core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import ast
import functools
import logging
import re

from . import helpers
from . import util
Expand All @@ -10,6 +11,9 @@
log = logging.getLogger(__name__)


# Matches a `# noqa` marker, optionally followed by a colon and a
# comma-separated list of warning codes (captured in group 1). The colon is
# optional so that a bare `# noqa` (suppress everything) is recognised too.
NOQA_RE = re.compile(r"#\s*noqa:?\s*([A-Z0-9, ]+)?")


class InvalidParentException(Exception):
    """Signals an invalid parent entity.

    NOTE(review): presumably raised while resolving inheritance between
    entities; confirm against the raise sites.
    """

Expand All @@ -19,6 +23,7 @@ class TPVConfigLoader(object):
def __init__(self, tpv_config: dict):
self.compile_code_block = functools.lru_cache(maxsize=None)(self.__compile_code_block)
self.global_settings = tpv_config.get('global', {})
self.noqa = {'tools': {}, 'users': {}, 'roles': {}, 'destinations': {}}
entities = self.load_entities(tpv_config)
self.tools = entities.get('tools')
self.users = entities.get('users')
Expand Down Expand Up @@ -69,7 +74,28 @@ def recompute_inheritance(self, entities: dict[str, Entity]):
for key, entity in entities.items():
entities[key] = self.process_inheritance(entities, entity)

def validate_entities(self, entity_class: type, entity_list: dict) -> dict:
def get_noqa_codes(self, entity_comments: list) -> (bool, set[str] | None):
comments = []
if entity_comments and len(entity_comments) == 4 and entity_comments[3]:
comments.extend([x.value.strip() for x in entity_comments[3]])

for comment in comments:
match = re.match(r"#\s*noqa:?\s*([A-Z0-9, ]+)?", comment)
if match:
codes = match.group(1)
# Return a set of codes or None if `# noqa` with no codes
return (True, set(code.strip() for code in codes.split(',')) if codes else None)

return (False, None)

def store_noqa_codes(self, entity_list: dict, entity_id: str, noqa_dict: dict):
    """Record the noqa marker found on ``entity_id`` into ``noqa_dict``.

    Mappings without a ``ca`` attribute are skipped, since there is nowhere
    to read comments from. When a marker is found, ``noqa_dict[entity_id]``
    is set to the codes returned by ``get_noqa_codes`` (None for a bare
    ``# noqa``).
    """
    if not hasattr(entity_list, "ca"):
        return
    attached_comments = entity_list.ca.items.get(entity_id)
    has_marker, codes = self.get_noqa_codes(attached_comments)
    if has_marker:
        noqa_dict[entity_id] = codes

def validate_entities(self, entity_class: type, entity_list: dict, noqa_dict: dict) -> dict:
# This code relies on dict ordering guarantees provided since python 3.6
validated = {}
for entity_id, entity_dict in entity_list.items():
Expand All @@ -81,15 +107,19 @@ def validate_entities(self, entity_class: type, entity_list: dict) -> dict:
except Exception:
log.exception(f"Could not load entity of type: {entity_class} with data: {entity_dict}")
raise
self.store_noqa_codes(entity_list, entity_id, noqa_dict)
self.recompute_inheritance(validated)
return validated

def load_entities(self, tpv_config: dict) -> dict:
validated = {
'tools': self.validate_entities(Tool, tpv_config.get('tools', {})),
'users': self.validate_entities(User, tpv_config.get('users', {})),
'roles': self.validate_entities(Role, tpv_config.get('roles', {})),
'destinations': self.validate_entities(Destination, tpv_config.get('destinations', {}))
'tools': self.validate_entities(Tool, tpv_config.get('tools', {}), self.noqa['tools']),
'users': self.validate_entities(User, tpv_config.get('users', {}), self.noqa['users']),
'roles': self.validate_entities(Role, tpv_config.get('roles', {}), self.noqa['roles']),
'destinations': self.validate_entities(
Destination,
tpv_config.get('destinations', {}),
self.noqa['destinations'])
}
return validated

Expand Down Expand Up @@ -118,6 +148,21 @@ def merge_loader(self, loader: TPVConfigLoader):
self.inherit_existing_entities(self.roles, loader.roles)
self.inherit_existing_entities(self.destinations, loader.destinations)

def check_noqa(self, entity: Entity, code: str) -> bool:
    """Tell whether warning ``code`` is silenced by a noqa marker on ``entity``.

    The entity's exact type selects which section of ``self.noqa`` to
    consult. A stored value of None means every code is silenced for that
    entity; otherwise ``code`` must be among the stored codes.

    :raises RuntimeError: if ``entity`` is not exactly a Tool, User, Role or
        Destination.
    """
    sections = (
        (Tool, 'tools'),
        (User, 'users'),
        (Role, 'roles'),
        (Destination, 'destinations'),
    )
    for entity_type, section in sections:
        # Exact type match on purpose: subclasses are not accepted.
        if type(entity) is entity_type:
            silenced = self.noqa[section]
            break
    else:
        raise RuntimeError(f"Unknown entity type: {entity}")

    if entity.id not in silenced:
        return False
    codes = silenced[entity.id]
    return codes is None or code in codes

@staticmethod
def from_url_or_path(url_or_path: str):
tpv_config = util.load_yaml_from_url_or_path(url_or_path)
Expand Down
2 changes: 1 addition & 1 deletion tpv/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def load_yaml_from_url_or_path(url_or_path: str):
yaml = ruamel.yaml.YAML(typ='safe')
yaml = ruamel.yaml.YAML(typ="rt")
if os.path.isfile(url_or_path):
with open(url_or_path, 'r') as f:
return yaml.load(f)
Expand Down

0 comments on commit d835738

Please sign in to comment.