diff --git a/CHANGELOG.md b/CHANGELOG.md index 6076b631..dd54b8e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Fixed - Correctly refer to input peak files by their full file path. +- Specifying custom residues to retrain Casanovo is now possible. ## [3.3.0] - 2023-04-04 diff --git a/casanovo/config.py b/casanovo/config.py index 4dc93c26..0dfdaf67 100644 --- a/casanovo/config.py +++ b/casanovo/config.py @@ -50,6 +50,7 @@ class Config: dropout=float, dim_intensity=int, max_length=int, + residues=dict, n_log=int, tb_summarywriter=str, warmup_iters=int, diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 8282e367..89b37e42 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -1,6 +1,4 @@ """Test configuration loading""" -import pytest - from casanovo.config import Config @@ -17,11 +15,23 @@ def test_override(tmp_path): """Test overriding the default""" yml = tmp_path / "test.yml" with yml.open("w+") as f_out: - f_out.write("random_seed: 42\ntop_match: 3") + f_out.write( + """random_seed: 42 +top_match: 3 +residues: + W: 1 + O: 2 + U: 3 + T: 4 +""" + ) config = Config(yml) assert config.random_seed == 42 assert config["random_seed"] == 42 assert not config.no_gpu assert config.top_match == 3 + assert len(config.residues) == 4 + for i, residue in enumerate("WOUT", 1): + assert config["residues"][residue] == i assert config.file == str(yml)