Skip to content

Commit

Permalink
GH-1377: fix torch 1.4.0 error
Browse files Browse the repository at this point in the history
  • Loading branch information
alanakbik committed Jan 24, 2020
1 parent c01a4da commit b9a6ebe
Showing 1 changed file with 34 additions and 26 deletions.
60 changes: 34 additions & 26 deletions flair/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2930,32 +2930,6 @@ def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
longest_token_sequence_in_batch: int = max(lengths)

# initialize zero-padded word embeddings tensor
# sentence_tensor = torch.zeros(
# [
# len(sentences),
# longest_token_sequence_in_batch,
# self.embeddings.embedding_length,
# ],
# dtype=torch.float,
# device=flair.device,
# )
#
# for s_id, sentence in enumerate(sentences):
# # fill values with word embeddings
# all_embs = list()
#
# for index_token, token in enumerate(sentence):
# embs = token.get_each_embedding()
# if not all_embs:
# all_embs = [list() for _ in range(len(embs))]
# for index_emb, emb in enumerate(embs):
# all_embs[index_emb].append(emb)
#
# concat_word_emb = [torch.stack(embs) for embs in all_embs]
# concat_sentence_emb = torch.cat(concat_word_emb, dim=1)
# sentence_tensor[s_id][: len(sentence)] = concat_sentence_emb

pre_allocated_zero_tensor = torch.zeros(
self.embeddings.embedding_length * longest_token_sequence_in_batch,
dtype=torch.float,
Expand Down Expand Up @@ -3023,6 +2997,40 @@ def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
sentence = sentences[sentence_no]
sentence.set_embedding(self.name, embedding)

def _apply(self, fn):
major, minor, build, *_ = (int(info)
for info in torch.__version__.split('.'))

# fixed RNN change format for torch 1.4.0
if major >= 1 and minor >= 4:
for child_module in self.children():
if isinstance(child_module, torch.nn.RNNBase):
_flat_weights_names = []
num_direction = None

if child_module.__dict__["bidirectional"]:
num_direction = 2
else:
num_direction = 1
for layer in range(child_module.__dict__["num_layers"]):
for direction in range(num_direction):
suffix = "_reverse" if direction == 1 else ""
param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"]
if child_module.__dict__["bias"]:
param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"]
param_names = [
x.format(layer, suffix) for x in param_names
]
_flat_weights_names.extend(param_names)

setattr(child_module, "_flat_weights_names",
_flat_weights_names)

child_module._apply(fn)

else:
super()._apply(fn)


@deprecated(
version="0.4",
Expand Down

0 comments on commit b9a6ebe

Please sign in to comment.