Skip to content

Commit

Permalink
cuda dep
Browse files Browse the repository at this point in the history
  • Loading branch information
Jorgedavyd committed Nov 16, 2024
1 parent 8146053 commit 6e078a9
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions corkit/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,32 @@ def inverse(x, forward_eq, init_shape):
return forward, inverse


def cross_model_reconstruction():
return torch.jit.load(os.path.join(DEFAULT_SAVE_DIR, "models/cross.pt"))
def load_model(path: str):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
match device:
case 'cuda':
return torch.jit.load(path)
case 'cpu':
return torch.jit.load(path, map_location = torch.device('cpu'))

def cross_model_reconstruction():
path: str = os.path.join(DEFAULT_SAVE_DIR, "models/cross.pt")
return load_model(path)

def fourier_model_reconstruction():
return torch.jit.load(os.path.join(DEFAULT_SAVE_DIR, "models/fourier.pt"))

path: str = os.path.join(DEFAULT_SAVE_DIR, "models/fourier.pt")
return load_model(path)

def normal_model_reconstruction():
return torch.jit.load(os.path.join(DEFAULT_SAVE_DIR, "models/partial_conv.pt"))

path: str = os.path.join(DEFAULT_SAVE_DIR, "models/partial_conv.pt")
return load_model(path)

def dl_image(model, img, bkg, forward_transform, inverse_transform):
init_shape = img.shape
x, forward_eq, mask = forward_transform(img.astype(np.float32), bkg)

if len(np.where(mask == 0.)[0]) > 32*32:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
x, _ = model(x.view(1, 1, 1024,1024).to(device).float(), mask.view(1, 1, 1024,1024).to(device).float())
x = interpolate(x, size = init_shape)
x = inverse_transform(x, forward_eq, init_shape)
Expand Down

0 comments on commit 6e078a9

Please sign in to comment.