diff --git a/model_angelo/__init__.py b/model_angelo/__init__.py index 851972e..f485114 100644 --- a/model_angelo/__init__.py +++ b/model_angelo/__init__.py @@ -5,4 +5,4 @@ """ -__version__ = "1.0.12" +__version__ = "1.0.13" diff --git a/model_angelo/gnn/flood_fill.py b/model_angelo/gnn/flood_fill.py index a7a23a0..1832d96 100644 --- a/model_angelo/gnn/flood_fill.py +++ b/model_angelo/gnn/flood_fill.py @@ -194,7 +194,6 @@ def final_results_to_cif( final_results["entropy_score"] = local_confidence_score_sigmoid( - aa_entropy, best_value=5.0, worst_value=2.0, mid_point=3.0, ) - final_results["aa_logits"] /= temperature torsion_angles = select_torsion_angles( torch.from_numpy(final_results["pred_torsions"][existence_mask]), aatype=aatype diff --git a/model_angelo/models/common_modules.py b/model_angelo/models/common_modules.py index 8bde60e..ade3e38 100644 --- a/model_angelo/models/common_modules.py +++ b/model_angelo/models/common_modules.py @@ -213,3 +213,15 @@ def __init__(self): def forward(self, x, y): s = torch.sigmoid(self.gate) return s * x + (1 - s) * y + + +class Upsample(nn.Module): + def __init__(self, scale_factor, mode="nearest"): + super().__init__() + self.scale_factor = scale_factor + self.mode = mode + + def forward(self, x): + return F.interpolate( + x, scale_factor=self.scale_factor, mode=self.mode + ) diff --git a/model_angelo/models/multi_gpu_wrapper.py b/model_angelo/models/multi_gpu_wrapper.py index 4f66e4c..6b18b5c 100644 --- a/model_angelo/models/multi_gpu_wrapper.py +++ b/model_angelo/models/multi_gpu_wrapper.py @@ -60,7 +60,7 @@ def cast_dict_to_full(dictionary): def init_model(model_definition_path: str, state_dict_path: str, device: str) -> nn.Module: model = get_model_from_file(model_definition_path).eval() - checkpoint = torch.load(state_dict_path, map_location="cpu") + checkpoint = torch.load(state_dict_path, map_location="cpu", weights_only=False) if "model" not in checkpoint: model.load_state_dict(checkpoint) else: @@ -93,7 +93,7 @@ def run_inference( inference_data = input_queue.get() if inference_data.status != 1: break - with torch.cuda.amp.autocast(dtype=dtype): + with torch.autocast(device_type="cuda", dtype=dtype): output = model(**inference_data.data) output = output.to("cpu").to(torch.float32) output_queue.put(output) @@ -155,7 +155,7 @@ def forward(self, data_list: List) -> List: InferenceData(data=send_dict_to_device(data, device), status=1) ) else: - with torch.cuda.amp.autocast(dtype=self.dtype), torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=self.dtype), torch.no_grad(): output_list.append( self.model(**send_dict_to_device(data, device)).to("cpu").to(torch.float32) ) diff --git a/model_angelo/models/resnet.py b/model_angelo/models/resnet.py index 939b9b3..95d2654 100644 --- a/model_angelo/models/resnet.py +++ b/model_angelo/models/resnet.py @@ -9,6 +9,7 @@ NormalizeStd, RegularBlock, ResBlock, + Upsample ) @@ -62,7 +63,7 @@ def __init__( self.main_layers.append(pooling_class(2, stride=2)) for i in range(self.num_blocks): if i == self.num_blocks - 2 and downsample_x: - self.main_layers.append(nn.Upsample(scale_factor=2)) + self.main_layers.append(Upsample(scale_factor=2)) self.main_layers.append( ResBlock( in_channels=self.g_width, diff --git a/model_angelo/utils/torch_utils.py b/model_angelo/utils/torch_utils.py index d337763..cbea59a 100644 --- a/model_angelo/utils/torch_utils.py +++ b/model_angelo/utils/torch_utils.py @@ -76,7 +76,7 @@ def checkpoint_load_latest( log_dir: str, device: torch.DeviceObjType, match_model: bool = True, **kwargs ) -> int: checkpoint_to_load, step_num = find_latest_checkpoint(log_dir) - state_dicts = torch.load(checkpoint_to_load, map_location=device) + state_dicts = torch.load(checkpoint_to_load, map_location=device, weights_only=False) if match_model: warnings.warn( "In checkpoint_load_latest, match_model is set to True. " @@ -571,4 +571,3 @@ def compile_if_possible(module: nn.Module) -> nn.Module: module = torch.compile(module) return module - diff --git a/setup.py b/setup.py index 7e4dfc0..94b9b3a 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,8 @@ "mrcfile", "pandas", "fair-esm==1.0.3", - "pyhmmer>=0.10.1", + "pyhmmer==0.7.1", "loguru", + "numpy==1.21.*", ], )