Skip to content

Commit

Permalink
Merge pull request #113 from 3dem/revert-hmm-cleaning
Browse files Browse the repository at this point in the history
Fix issues with extra chain breaks
  • Loading branch information
jamaliki authored Oct 24, 2024
2 parents 73e80f9 + c1d9a65 commit 5e150c9
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion model_angelo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
"""


__version__ = "1.0.12"
__version__ = "1.0.13"
1 change: 0 additions & 1 deletion model_angelo/gnn/flood_fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ def final_results_to_cif(
final_results["entropy_score"] = local_confidence_score_sigmoid(
- aa_entropy, best_value=5.0, worst_value=2.0, mid_point=3.0,
)
final_results["aa_logits"] /= temperature

torsion_angles = select_torsion_angles(
torch.from_numpy(final_results["pred_torsions"][existence_mask]), aatype=aatype
Expand Down
12 changes: 12 additions & 0 deletions model_angelo/models/common_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,15 @@ def __init__(self):
def forward(self, x, y):
s = torch.sigmoid(self.gate)
return s * x + (1 - s) * y


class Upsample(nn.Module):
def __init__(self, scale_factor, mode="nearest"):
super().__init__()
self.scale_factor = scale_factor
self.mode = mode

def forward(self, x):
return F.interpolate(
x, scale_factor=self.scale_factor, mode=self.mode
)
6 changes: 3 additions & 3 deletions model_angelo/models/multi_gpu_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def cast_dict_to_full(dictionary):

def init_model(model_definition_path: str, state_dict_path: str, device: str) -> nn.Module:
model = get_model_from_file(model_definition_path).eval()
checkpoint = torch.load(state_dict_path, map_location="cpu")
checkpoint = torch.load(state_dict_path, map_location="cpu", weights_only=False)
if "model" not in checkpoint:
model.load_state_dict(checkpoint)
else:
Expand Down Expand Up @@ -93,7 +93,7 @@ def run_inference(
inference_data = input_queue.get()
if inference_data.status != 1:
break
with torch.cuda.amp.autocast(dtype=dtype):
with torch.autocast(device_type="cuda", dtype=dtype):
output = model(**inference_data.data)
output = output.to("cpu").to(torch.float32)
output_queue.put(output)
Expand Down Expand Up @@ -155,7 +155,7 @@ def forward(self, data_list: List) -> List:
InferenceData(data=send_dict_to_device(data, device), status=1)
)
else:
with torch.cuda.amp.autocast(dtype=self.dtype), torch.no_grad():
with torch.autocast(device_type="cuda", dtype=self.dtype), torch.no_grad():
output_list.append(
self.model(**send_dict_to_device(data, device)).to("cpu").to(torch.float32)
)
Expand Down
3 changes: 2 additions & 1 deletion model_angelo/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
NormalizeStd,
RegularBlock,
ResBlock,
Upsample
)


Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(
self.main_layers.append(pooling_class(2, stride=2))
for i in range(self.num_blocks):
if i == self.num_blocks - 2 and downsample_x:
self.main_layers.append(nn.Upsample(scale_factor=2))
self.main_layers.append(Upsample(scale_factor=2))
self.main_layers.append(
ResBlock(
in_channels=self.g_width,
Expand Down
3 changes: 1 addition & 2 deletions model_angelo/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def checkpoint_load_latest(
log_dir: str, device: torch.DeviceObjType, match_model: bool = True, **kwargs
) -> int:
checkpoint_to_load, step_num = find_latest_checkpoint(log_dir)
state_dicts = torch.load(checkpoint_to_load, map_location=device)
state_dicts = torch.load(checkpoint_to_load, map_location=device, weights_only=False)
if match_model:
warnings.warn(
"In checkpoint_load_latest, match_model is set to True. "
Expand Down Expand Up @@ -571,4 +571,3 @@ def compile_if_possible(module: nn.Module) -> nn.Module:
module = torch.compile(module)
return module


3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
"mrcfile",
"pandas",
"fair-esm==1.0.3",
"pyhmmer>=0.10.1",
"pyhmmer==0.7.1",
"loguru",
"numpy==1.21.*",
],
)

0 comments on commit 5e150c9

Please sign in to comment.