-
Notifications
You must be signed in to change notification settings - Fork 0
/
play_with.py
99 lines (84 loc) · 3.29 KB
/
play_with.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import torch
from torchvision import datasets, transforms
from PIL import Image, ImageFont, ImageDraw, ImageEnhance
from dataset import load_vocab
import utils
import argparse
import torch
from model import Captor
import os
import random
def create_embeds(model, image):
    """Encode a single image into the caption decoder's embedding space.

    Adds a batch dimension, runs the encoder, flattens the resulting
    feature map per sample, and projects it through the embedding layer.
    """
    batch = image[None]  # prepend batch dim: (C, H, W) -> (1, C, H, W)
    features = model.net.enc(batch)
    flat = features.view(features.size(0), -1)  # flatten all but batch
    return model.net.embeds(flat)
def beamsearch(model, device, image, vocab, return_sentence=True):
    """Greedily decode a caption for a single preprocessed image.

    NOTE: despite the name, this is greedy (argmax) decoding, not true
    beam search; the name is kept for caller compatibility.

    Args:
        model: wrapper exposing ``model.net`` with ``enc``/``embeds``/``dec``.
        device: torch device to run inference on.
        image: preprocessed image tensor (no batch dimension).
        vocab: dict mapping token id -> word string.
        return_sentence: when True, join the words into one string.

    Returns:
        The caption as a space-joined string, or as a list of words when
        ``return_sentence`` is False. The first generated token is dropped
        (presumably a start-of-sequence marker — TODO confirm).
    """
    model.net.eval()
    # Inference only: disable autograd to avoid building a graph per step.
    with torch.no_grad():
        embeds = create_embeds(model, image.to(device))
        # The original kept two counters (`l` and `count`) that were always
        # equal; a single `step` counter is sufficient.
        step = 0
        token_ids = []
        words = []
        cap_tens = None
        while True:
            predict = model.net.dec(embeds, cap_tens, [step])[:, :, -1]
            # Renamed from `id`, which shadowed the builtin.
            token_id = np.argmax(predict.cpu().data.numpy(), axis=1)[0]
            # Stop on token 1 (presumably end-of-sequence — confirm against
            # the vocab) or after 21 decoded tokens.
            if token_id == 1 or step > 20:
                break
            words.append(vocab[token_id])
            token_ids.append(token_id)
            cap_tens = torch.Tensor([token_ids]).long().to(device)
            step += 1
    if return_sentence:
        return ' '.join(words[1:])
    return words[1:]
def main(args):
    """Caption one image and optionally save the annotated copy.

    Loads the vocab and a trained Captor checkpoint, preprocesses the
    image at ``images/<args.fn>``, generates a caption, draws it on a
    black banner across the top of an 800x600 copy, shows the result,
    and saves it under ``saveims/`` if the user confirms.
    """
    preprocess = transforms.Compose([
        transforms.Resize([args.im_size] * 2),
        transforms.ToTensor(),
        # NOTE(review): the usual ImageNet RGB mean is [0.485, 0.456, 0.406];
        # these values look channel-reversed (BGR order). Confirm against how
        # the encoder was trained before "fixing".
        transforms.Normalize(mean=[0.407, 0.457, 0.485],  # subtract imagenet mean
                             std=[1, 1, 1]),
    ])
    use_cuda = torch.cuda.is_available()
    torch.manual_seed(random.randint(1, 10000))
    device = torch.device("cuda" if use_cuda else "cpu")
    words = load_vocab()                      # word -> id
    vocab = {i: w for w, i in words.items()}  # id -> word, for decoding
    model = Captor(args.lr, args.weight_decay, args.lr_decay_rate, len(words), args.embed_size)
    model.to(device)
    model.load_checkpoint(args.ckpt_path)
    im = Image.open(os.path.join('images', args.fn))
    img = preprocess(im)
    caption = beamsearch(model, device, img, vocab)
    # Draw the caption onto a fixed-size copy of the original image.
    img_ = transforms.Resize((600, 800))(im)
    draw = ImageDraw.Draw(img_)
    draw.rectangle(((0, 0), (800, 40)), fill='black')
    draw.text((20, 20), caption, (255, 255, 255), font=ImageFont.truetype('FreeMono.ttf', 20))
    img_.show()
    print('auto caption: {}'.format(caption))
    ans = input('Do you want to save image? (y/n):')
    if ans == 'y' or ans == 'Y':
        if not os.path.exists('saveims'):
            os.makedirs('saveims')
        img_.save(os.path.join('saveims', args.fn))
if __name__ == '__main__':
    # Command-line entry point. Every option is optional (nargs='?') and
    # falls back to the default listed in the table below.
    parser = argparse.ArgumentParser(description='Image Captioning')
    option_table = [
        ('--im-size', {'type': int, 'default': 299}),
        ('--embed-size', {'type': int, 'default': 512}),
        ('--seq-len', {'help': 'max sequence length', 'type': int, 'default': 100}),
        ('--lr-decay-interval', {'type': int, 'default': 2000}),
        ('--lr-decay-rate', {'type': float, 'default': 1e-5}),
        ('--epochs', {'type': int, 'default': 100}),
        ('--lr', {'type': float, 'default': 1e-3}),
        ('--weight-decay', {'type': float, 'default': 1e-2}),
        ('--ckpt-path', {'default': 'checkpoints'}),
        ('--fn', {'default': 'test.jpg'}),
    ]
    for flag, extra in option_table:
        parser.add_argument(flag, nargs='?', **extra)
    main(parser.parse_args())