Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Dynamic UNet with ResNet18 backend comaptibility in inference and updated libraries to support latest pytorch and fastai #13

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions Image Colorization Tutorial/infer_resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
## Update: 11th Jun, 2024
## this file is supposed to give you a general idea on how to
## use the pre-trained model for colorizing B&W images. This
## file still needs development.

import PIL
import torch
from matplotlib import pyplot as plt
from torchvision import transforms

from models import MainModel, build_res_unet
from utils import lab_to_rgb

if __name__ == '__main__':
model = MainModel()
# You first need to download the final_model_weights.pt file from my drive
# using the command: gdown --id 1lR6DcS4m5InSbZ5y59zkH2mHt_4RQ2KV
# You need to train the Dynamic UNet generator and save the weights in a file resnet_model_weights.pt
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net_G = build_res_unet(n_input=1, n_output=2, size=256)
net_G.load_state_dict(
torch.load(
"resnet_model_weights.pt",
map_location=device
)
)
model = MainModel(net_G=net_G)
model.load_state_dict(
torch.load(
"final_model_weights.pt",
map_location=device
)
)
path = "Path to your black and white image"
img = PIL.Image.open(path)
img = img.resize((256, 256))
# to make it between -1 and 1
img = transforms.ToTensor()(img)[:1] * 2. - 1.
model.eval()
with torch.no_grad():
preds = model.net_G(img.unsqueeze(0).to(device))
colorized = lab_to_rgb(img.unsqueeze(0), preds.cpu())[0]
plt.imshow(colorized)
13 changes: 12 additions & 1 deletion Image Colorization Tutorial/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import torch
from torch import nn, optim
from loss import GANLoss
from fastai.vision.learner import create_body
from torchvision.models import resnet18
from fastai.vision.models.unet import DynamicUnet


class UnetBlock(nn.Module):
Expand Down Expand Up @@ -171,4 +174,12 @@ def optimize(self):
self.set_requires_grad(self.net_D, False)
self.opt_G.zero_grad()
self.backward_G()
self.opt_G.step()
self.opt_G.step()


# U-Net with ResNet18 backend
def build_res_unet(n_input=1, n_output=2, size=256):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
body = create_body(resnet18(), pretrained=True, n_in=n_input, cut=-2)
net_G = DynamicUnet(body, n_output, (size, size)).to(device)
return net_G