cost = tf.contrib.seq2seq.sequence_loss(training_logits, targets, masks)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [5632] vs. [6400]
My target has shape (256, 25), but the resulting training_logits has shape (256, 22, 358), where 358 is the vocabulary size.
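The two numbers in the error message match these shapes: 256 × 22 = 5632 and 256 × 25 = 6400. The NumPy sketch below reproduces that arithmetic. It assumes that `sequence_loss` flattens the batch and time dimensions before combining logits, targets, and masks; the array names are illustrative only.

```python
import numpy as np

batch_size, vocab_size = 256, 358
logit_steps, target_steps = 22, 25  # time steps reported in the issue

# Stand-ins for training_logits and targets (contents do not matter here).
logits = np.zeros((batch_size, logit_steps, vocab_size))
targets = np.zeros((batch_size, target_steps), dtype=np.int64)

# Assumption: the loss collapses [batch, time] into a single axis.
flat_logits = logits.reshape(-1, vocab_size)
flat_targets = targets.reshape(-1)

print(flat_logits.shape[0], flat_targets.shape[0])
```

If that assumption holds, the mismatch means the decoder produced 22 time steps while the targets were padded to 25.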
I changed it as follows, and now it works:

    import numpy as np

    def pad_batch_sentence(batch, max_length, pad_id):
        # max_length = max([len(sentence) for sentence in batch])
        return [sentence + [pad_id] * (max_length - len(sentence)) for sentence in batch]

    def get_batches(sources, targets, batch_size):
        for batch_i in range(0, len(sources) // batch_size):
            start_i = batch_i * batch_size

            # Slice the right amount for the batch
            sources_batch = sources[start_i:start_i + batch_size]
            targets_batch = targets[start_i:start_i + batch_size]
            pad_idx = source_vocab_to_int.get("<PAD>")
            sources_batch_pad = np.array(pad_batch_sentence(sources_batch, max_source_sentence_length, pad_idx))
            targets_batch_pad = np.array(pad_batch_sentence(targets_batch, max_target_sentence_length, pad_idx))

            # Need the lengths for the _lengths parameters
            # The lengths should not be computed on the padded batch, because they are all 25
            targets_lengths = []
            for target in targets_batch_pad:
                targets_lengths.append(len(target))

            source_lengths = []
            for source in sources_batch_pad:
                source_lengths.append(len(source))

            yield sources_batch_pad, targets_batch_pad, source_lengths, targets_lengths

But this way the source_lengths passed in are all (20, 20, 20, ...) and the targets_lengths are all (25, 25, 25, ...).
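One possible way to pass the true lengths is to record them before padding and to pad each batch only to its own longest sentence. The sketch below is not the repository's code: `pad_batch`, `get_batches_with_true_lengths`, and the explicit `pad_id` argument are hypothetical names introduced for illustration.

```python
import numpy as np

def pad_batch(batch, pad_id):
    """Pad every sentence to the longest sentence in this batch."""
    max_length = max(len(sentence) for sentence in batch)
    return [sentence + [pad_id] * (max_length - len(sentence)) for sentence in batch]

def get_batches_with_true_lengths(sources, targets, batch_size, pad_id):
    """Yield padded batches together with the unpadded sentence lengths."""
    for batch_i in range(len(sources) // batch_size):
        start_i = batch_i * batch_size
        sources_batch = sources[start_i:start_i + batch_size]
        targets_batch = targets[start_i:start_i + batch_size]

        # Record lengths BEFORE padding so they reflect the real sentences.
        source_lengths = [len(source) for source in sources_batch]
        targets_lengths = [len(target) for target in targets_batch]

        yield (np.array(pad_batch(sources_batch, pad_id)),
               np.array(pad_batch(targets_batch, pad_id)),
               source_lengths,
               targets_lengths)

# Toy data; pad_id=0 is an assumed <PAD> index.
toy_sources = [[1, 2], [3], [4, 5, 6], [7]]
toy_targets = [[1], [2, 3, 4], [5, 6], [7]]
batches = list(get_batches_with_true_lengths(toy_sources, toy_targets, 2, pad_id=0))
```

Padding to the batch maximum keeps the padded target width equal to the largest entry in `targets_lengths`. Assuming the training decoder unrolls to that largest length, the logits and targets would then agree in the time dimension.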
I also think there is a problem here: this way every source length is the maximum length after padding.
After changing it, I get an error...