import math
import torch
from torch import nn
from torch.nn import functional as F
import einops
import numpy as np
from einops import rearrange, repeat
from allennlp.nn.util import get_range_vector, get_device_of, add_positional_features
from transformer import add_positional_features_2d, image_to_tiles, upsample_tiles, tiles_to_image, _get_std_from_tensor, Transformer, TransformerEncoder
import pdb
torch.manual_seed(12)
class NavigationTransformerEncoder(TransformerEncoder):
    def __init__(self,
                 image_size: int,
                 patch_size: int,
                 language_embedder: torch.nn.Module,
                 n_layers: int,
                 n_classes: int = 2,
                 channels: int = 3,
                 n_heads: int = 8,
                 hidden_dim: int = 512,
                 ff_dim: int = 1024,
                 init_scale: int = 4,
                 dropout: float = 0.33,
                 embed_dropout: float = 0.33,
                 output_type: str = "per-pixel",
                 positional_encoding_type: str = "learned",
                 device: torch.device = "cpu",
                 locality_mask: bool = False,
                 locality_neighborhood: int = 5,
                 log_weights: bool = False):
        super(NavigationTransformerEncoder, self).__init__(image_size=image_size,
                                                           patch_size=patch_size,
                                                           language_embedder=language_embedder,
                                                           n_layers_shared=n_layers,
                                                           n_layers_split=0,
                                                           n_classes=n_classes,
                                                           channels=channels,
                                                           n_heads=n_heads,
                                                           hidden_dim=hidden_dim,
                                                           ff_dim=ff_dim,
                                                           init_scale=init_scale,
                                                           dropout=dropout,
                                                           embed_dropout=embed_dropout,
                                                           output_type=output_type,
                                                           positional_encoding_type=positional_encoding_type,
                                                           device=device)
        self.start_pos_projection = torch.nn.Linear(2, hidden_dim)
        self.locality_mask = locality_mask
        self.locality_neighborhood = locality_neighborhood

    def get_neighbors(self, patch_idx, num_patches, neighborhood=5):
        # indices of the patches in a square window around patch_idx
        # (slice ends are exclusive, so the window reaches `neighborhood` patches
        # up/left but only `neighborhood - 1` patches down/right)
        image_w = int(num_patches ** (1 / 2))
        patch_idxs = np.arange(num_patches).reshape(image_w, image_w)
        patch_row = int(patch_idx / image_w)
        patch_col = patch_idx % image_w
        neighbor_idxs = patch_idxs[max(0, patch_row - neighborhood):patch_row + neighborhood,
                                   max(0, patch_col - neighborhood):patch_col + neighborhood]
        neighbor_idxs = neighbor_idxs.reshape(-1)
        return neighbor_idxs

    def get_image_local_mask(self, num_patches, image_dim, neighborhood=5):
        # make a mask so that each image patch can only attend to patches close to it
        mask = torch.zeros((num_patches, num_patches))
        for i in range(num_patches):
            neighbors = self.get_neighbors(i, num_patches, neighborhood)
            mask[i, neighbors] = 1
        return mask.bool().to(self.device)

    def _prepare_input(self, image, language, start_pos, mask=None):
        # patchify the image
        p = self.patch_size
        image = image.permute(0, 3, 1, 2).float()
        image_input = image_to_tiles(image, p).to(self.device)
        start_pos = start_pos.to(self.device)
        batch_size, num_patches, __ = image_input.shape
        # build the full-sequence mask
        if mask is not None:
            # all image regions allowed as input
            long_mask = torch.ones((batch_size, num_patches + 1)).bool().to(self.device)
            # concat in the language mask
            long_mask = torch.cat((long_mask, mask), dim=1)
            # concat in ones for the second [SEP] and the start-position token
            pos_token_mask = torch.ones((batch_size, 2)).bool().to(self.device)
            long_mask = torch.cat((long_mask, pos_token_mask), dim=1)
        else:
            long_mask = None
        # project and positionally encode image
        model_input = self.patch_projection(image_input)
        # project and positionally encode language
        language_input = torch.cat([self.language_embedder(x).unsqueeze(0) for x in language], dim=0)
        language_input = self.language_projection(language_input)
        start_pos = self.start_pos_projection(start_pos).unsqueeze(1)
        if self.positional_encoding_type == "fixed-separate":
            # add fixed positional encodings to language and image separately
            model_input = add_positional_features(model_input)
            language_input = add_positional_features(language_input)
            # repeat [SEP] across batch
            sep_tokens_1 = repeat(self.sep_token, '() n d -> b n d', b=batch_size)
            sep_tokens_2 = repeat(self.sep_token, '() n d -> b n d', b=batch_size)
            # tack on [SEP]
            model_input = torch.cat([model_input, sep_tokens_1], dim=1)
            # tack on language after [SEP]
            model_input = torch.cat([model_input, language_input], dim=1)
            # tack on another [SEP]
            model_input = torch.cat([model_input, sep_tokens_2], dim=1)
            # tack on the start position
            model_input = torch.cat([model_input, start_pos], dim=1)
        else:
            raise AssertionError(f"invalid positional type {self.positional_encoding_type}")
        if self.locality_mask:
            image_dim = int(num_patches ** (1 / 2))
            image_mask = self.get_image_local_mask(num_patches, image_dim, neighborhood=self.locality_neighborhood)
            __, total_len, __ = model_input.shape
            large_image_mask = torch.ones(total_len, total_len)
            large_image_mask[0:image_mask.shape[0], 0:image_mask.shape[1]] = image_mask
            large_image_mask = large_image_mask.bool().to(self.device)
        else:
            # no locality restriction: no extra attention mask
            large_image_mask = None
        return model_input, long_mask, large_image_mask, num_patches

    def forward(self, batch_instance):
        language = batch_instance['command']
        image = batch_instance['input_image']
        start_pos = batch_instance['start_position'].float()
        language_mask = self.get_lang_mask(language)
        tfmr_input, mask, attn_mask, n_patches = self._prepare_input(image, language, start_pos, mask=language_mask)
        tfmr_output, __ = self.start_transformer(tfmr_input, mask=mask, attn_mask=attn_mask)
        # keep only the image-patch positions (drop the [SEP]s, language, and start position)
        next_just_image_output = tfmr_output[:, 0:n_patches, :]
        # run the final MLP head
        next_classes = self.next_mlp_head(next_just_image_output)
        # convert patch predictions back to an image
        next_image_output = tiles_to_image(next_classes, self.patch_size, self.output_type, False).unsqueeze(-1)
        return {"next_position": next_image_output}
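
if __name__ == "__main__":
    # Minimal sketch (not part of the original module): illustrates the locality
    # window that get_neighbors / get_image_local_mask use to restrict patch-to-patch
    # attention. get_neighbors never touches `self`, so it can be called unbound here
    # without building a full encoder (which would also need the parent
    # TransformerEncoder machinery and a language embedder). The grid size and
    # neighborhood below are example values, not defaults from the original code.
    num_patches = 16  # a 4x4 grid of patches
    neighbors = NavigationTransformerEncoder.get_neighbors(
        None, patch_idx=5, num_patches=num_patches, neighborhood=1)
    print(neighbors)  # -> [0 1 4 5]: the window around patch 5 (row 1, col 1)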