From b9a6ebe31b5ab1f28c18b4a6ed973f317dd74bbb Mon Sep 17 00:00:00 2001
From: alanakbik
Date: Fri, 24 Jan 2020 14:45:04 +0100
Subject: [PATCH] GH-1377: fix torch 1.4.0 error

---
 flair/embeddings.py | 60 +++++++++++++++++++++++++--------------------
 1 file changed, 34 insertions(+), 26 deletions(-)

diff --git a/flair/embeddings.py b/flair/embeddings.py
index 289fca690c..c7fddd26d3 100644
--- a/flair/embeddings.py
+++ b/flair/embeddings.py
@@ -2930,32 +2930,6 @@ def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
         lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
         longest_token_sequence_in_batch: int = max(lengths)
 
-        # initialize zero-padded word embeddings tensor
-        # sentence_tensor = torch.zeros(
-        #     [
-        #         len(sentences),
-        #         longest_token_sequence_in_batch,
-        #         self.embeddings.embedding_length,
-        #     ],
-        #     dtype=torch.float,
-        #     device=flair.device,
-        # )
-        #
-        # for s_id, sentence in enumerate(sentences):
-        #     # fill values with word embeddings
-        #     all_embs = list()
-        #
-        #     for index_token, token in enumerate(sentence):
-        #         embs = token.get_each_embedding()
-        #         if not all_embs:
-        #             all_embs = [list() for _ in range(len(embs))]
-        #         for index_emb, emb in enumerate(embs):
-        #             all_embs[index_emb].append(emb)
-        #
-        #     concat_word_emb = [torch.stack(embs) for embs in all_embs]
-        #     concat_sentence_emb = torch.cat(concat_word_emb, dim=1)
-        #     sentence_tensor[s_id][: len(sentence)] = concat_sentence_emb
-
         pre_allocated_zero_tensor = torch.zeros(
             self.embeddings.embedding_length * longest_token_sequence_in_batch,
             dtype=torch.float,
@@ -3023,6 +2997,40 @@ def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
             sentence = sentences[sentence_no]
             sentence.set_embedding(self.name, embedding)
 
+    def _apply(self, fn):
+        major, minor, build, *_ = (int(info)
+                                   for info in torch.__version__.split('.'))
+
+        # fixed RNN change format for torch 1.4.0
+        if major >= 1 and minor >= 4:
+            for child_module in self.children():
+                if isinstance(child_module, torch.nn.RNNBase):
+                    _flat_weights_names = []
+                    num_direction = None
+
+                    if child_module.__dict__["bidirectional"]:
+                        num_direction = 2
+                    else:
+                        num_direction = 1
+                    for layer in range(child_module.__dict__["num_layers"]):
+                        for direction in range(num_direction):
+                            suffix = "_reverse" if direction == 1 else ""
+                            param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"]
+                            if child_module.__dict__["bias"]:
+                                param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"]
+                            param_names = [
+                                x.format(layer, suffix) for x in param_names
+                            ]
+                            _flat_weights_names.extend(param_names)
+
+                    setattr(child_module, "_flat_weights_names",
+                            _flat_weights_names)
+
+                child_module._apply(fn)
+
+        else:
+            super()._apply(fn)
+
 
 @deprecated(
     version="0.4",
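
The new _apply override addresses a change in torch 1.4.0, where torch.nn.RNNBase started keeping a _flat_weights_names list and RNNBase._apply() began iterating over it: models pickled with an older torch lack that attribute on their LSTM child, so moving such a model to a device raises an AttributeError. The sketch below is not part of the patch; it is a minimal illustration, assuming torch >= 1.4 and a plain LSTM (default projection settings), that rebuilds the same name list as the override and compares it against the list torch itself constructs. The helper name expected_flat_weights_names is hypothetical.

# Illustration only: reconstruct per-layer RNN parameter names in the same
# layer/direction order as the patched _apply(), then compare with the
# _flat_weights_names that torch >= 1.4 builds for a freshly created LSTM.
# Assumes a plain LSTM without projections; the helper name is hypothetical.
import torch


def expected_flat_weights_names(rnn: torch.nn.RNNBase) -> list:
    names = []
    num_directions = 2 if rnn.bidirectional else 1
    for layer in range(rnn.num_layers):
        for direction in range(num_directions):
            suffix = "_reverse" if direction == 1 else ""
            param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"]
            if rnn.bias:
                param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"]
            names.extend(name.format(layer, suffix) for name in param_names)
    return names


lstm = torch.nn.LSTM(input_size=100, hidden_size=256, num_layers=2, bidirectional=True)
assert expected_flat_weights_names(lstm) == list(lstm._flat_weights_names)

If the two lists agree, setting the reconstructed list on a deserialized pre-1.4 LSTM, as the override does, should be enough for RNNBase._apply() to rebuild its flat weights when the model is moved with .to(flair.device).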