Skip to content

Commit

Permalink
move config to strings
Browse files Browse the repository at this point in the history
  • Loading branch information
jgreener64 committed Aug 23, 2024
1 parent b600df0 commit 0be7999
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 80 deletions.
89 changes: 86 additions & 3 deletions progres/chainsaw/get_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import hashlib
import json
import logging
import os
import sys
Expand All @@ -17,10 +18,92 @@
from progres.chainsaw.src.domain_assignment.util import convert_domain_dict_strings
from progres.chainsaw.src.factories import pairwise_predictor
from progres.chainsaw.src.models.results import PredictionResult
from .src.utils import common as common_utils

LOG = logging.getLogger(__name__)

config_str = """
{
"experiment_name": "C_mul65_do30",
"experiment_group": "cath_new",
"lr": 0.0002,
"weight_decay": 0.001,
"val_freq": 1,
"epochs": 15,
"lr_scheduler": {
"type": "exponential",
"gamma": 0.9
},
"accumulation_steps": 16,
"data": {
"splits_file": "splits_new_cath_featurized.json",
"validation_splits": [
"validation",
"test"
],
"crop_size": null,
"crop_type": null,
"batch_size": 1,
"feature_dir": "../features/new_cath/2d_features",
"label_dir": "../features/new_cath/pairwise",
"chains_csv": null,
"evaluate_test": false,
"eval_casp_10_plus": false,
"remove_missing_residues": false,
"using_alphafold_features": false,
"recycling": false,
"add_padding_mask": false,
"training_exclusion_json": null,
"multi_proportion": 0.65,
"train_ids": "splits_new_cath_featurized.json",
"exclude_test_topology": false,
"cluster_sampling_training": true,
"dist_transform": "unidoc_exponent",
"distance_denominator": 10,
"merizo_train_data": false,
"redundancy_level": "S60_comb"
},
"learner": {
"uncertainty_model": false,
"save_every_epoch": true,
"model": {
"type": "trrosetta",
"kwargs": {
"filters": 32,
"kernel": 3,
"num_layers": 31,
"in_channels": 5,
"dropout": 0.3,
"symmetrise_output": true
}
},
"assignment": {
"type": "sparse_lowrank",
"kwargs": {
"N_iters": 3,
"K_init": 4,
"linker_threshold": 30
}
},
"max_recycles": 0,
"save_val_best": true,
"x_has_padding_mask": false
},
"num_trainable_params": 577889
}
"""

feature_config_str = """
{
"description": "alpha distance only, keep the start and end boundaries in one channel but use -1",
"alpha_distances": true,
"beta_distances": false,
"ss_bounds": true,
"negative_ss_end": true,
"separate_channel_ss_start_end": false,
"same_channel_boundaries_and_ss": false
}
"""

def setup_logging(loglevel):
# log all messages to stderr so results can be sent to stdout
logging.basicConfig(level=loglevel,
Expand All @@ -35,8 +118,8 @@ def load_model(*,
min_domain_length: int = 30,
post_process_domains: bool = True,
device: str = "cpu"):
config = common_utils.load_json(os.path.join(model_dir, "config.json"))
feature_config = common_utils.load_json(os.path.join(model_dir, "feature_config.json"))
config = json.loads(config_str)
feature_config = json.loads(feature_config_str)
config["learner"]["remove_disordered_domain_threshold"] = remove_disordered_domain_threshold
config["learner"]["post_process_domains"] = post_process_domains
config["learner"]["min_ss_components"] = min_ss_components
Expand Down
68 changes: 0 additions & 68 deletions progres/chainsaw/model_v3/config.json

This file was deleted.

9 changes: 0 additions & 9 deletions progres/chainsaw/model_v3/feature_config.json

This file was deleted.

0 comments on commit 0be7999

Please sign in to comment.