Skip to content

Commit

Permalink
Also trace SockeyeModel components when inference_only == False (includes CheckpointDecoder) (#1032)
Browse files Browse the repository at this point in the history

* Trace checkpoint decoder

- Remove inference_only checks for model tracing
- Checkpoint decoder always runs in eval mode

* Version and changelog

* Grammar

* Whitespace

* Rename variable
  • Loading branch information
mjdenkowski authored Mar 23, 2022
1 parent 21b35e8 commit b24b2c1
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 39 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.1.7]

### Changed

- SockeyeModel components are now traced regardless of whether `inference_only` is set, including for the CheckpointDecoder during training.

## [3.1.6]

### Changed
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.1.6'
__version__ = '3.1.7'
20 changes: 18 additions & 2 deletions sockeye/checkpoint_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def __init__(self,
self.bucket_width_source = bucket_width_source
self.length_penalty_alpha = length_penalty_alpha
self.length_penalty_beta = length_penalty_beta
# TODO(mdenkows): Trace encoder/decoder even though inference_only=False
self.model = model

with ExitStack() as exit_stack:
Expand Down Expand Up @@ -149,6 +148,12 @@ def decode_and_evaluate(self, output_name: Optional[str] = None) -> Dict[str, fl
"""

# 1. Translate

# Store original mode and set to eval mode in case the model is not yet
# traced.
original_mode = self.model.training
self.model.eval()

trans_wall_time = 0.0
translations = [] # type: List[List[str]]
with ExitStack() as exit_stack:
Expand All @@ -172,7 +177,11 @@ def decode_and_evaluate(self, output_name: Optional[str] = None) -> Dict[str, fl
avg_time = trans_wall_time / len(self.targets_sentences[0])
translations = list(zip(*translations)) # type: ignore

# Restore original model mode
self.model.train(original_mode)

# 2. Evaluate

metrics = {C.BLEU: evaluate.raw_corpus_bleu(hypotheses=translations[0],
references=self.targets_sentences[0],
offset=0.01),
Expand Down Expand Up @@ -202,9 +211,16 @@ def decode_and_evaluate(self, output_name: Optional[str] = None) -> Dict[str, fl
return metrics

def warmup(self):
    """
    Translate a single sentence to warm up the model. Set the model to eval
    mode for tracing, translate the sentence, then set the model back to its
    original mode.
    """
    # Remember the caller's train/eval state; tracing happens in eval mode.
    original_mode = self.model.training
    self.model.eval()
    one_sentence = [inference.make_input_from_multiple_strings(0, self.inputs_sentences[0])]
    # Translating once triggers jit tracing of the model components.
    _ = self.translator.translate(one_sentence)
    # Restore whatever mode the model was in before warmup.
    self.model.train(original_mode)


def parallel_subsample(parallel_sequences: List[List[Any]], sample_size: int, seed: int) -> List[Any]:
Expand Down
60 changes: 24 additions & 36 deletions sockeye/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ def __init__(self,
vocab_size=self.config.vocab_target_size,
weight=output_weight)
if self.inference_only:
# Running this layer scripted with a newly initialized model can
# cause an overflow error.
self.output_layer = pt.jit.script(self.output_layer)

self.factor_output_layers = pt.nn.ModuleList()
Expand Down Expand Up @@ -167,19 +169,14 @@ def encode(self, inputs: pt.Tensor, valid_length: Optional[pt.Tensor] = None) ->
:param valid_length: Optional Tensor of sequence lengths within this batch. Shape: (batch_size,)
:return: Encoder outputs, encoded output lengths
"""

if self.inference_only:
if self.traced_embedding_source is None:
logger.debug("Tracing embedding_source")
self.traced_embedding_source = pt.jit.trace(self.embedding_source, inputs)
source_embed = self.traced_embedding_source(inputs)
if self.traced_encoder is None:
logger.debug("Tracing encoder")
self.traced_encoder = pt.jit.trace(self.encoder, (source_embed, valid_length))
source_encoded, source_encoded_length = self.traced_encoder(source_embed, valid_length)
else:
source_embed = self.embedding_source(inputs)
source_encoded, source_encoded_length = self.encoder(source_embed, valid_length)
if self.traced_embedding_source is None:
logger.debug("Tracing embedding_source")
self.traced_embedding_source = pt.jit.trace(self.embedding_source, inputs)
source_embed = self.traced_embedding_source(inputs)
if self.traced_encoder is None:
logger.debug("Tracing encoder")
self.traced_encoder = pt.jit.trace(self.encoder, (source_embed, valid_length))
source_encoded, source_encoded_length = self.traced_encoder(source_embed, valid_length)
return source_encoded, source_encoded_length

def encode_and_initialize(self, inputs: pt.Tensor, valid_length: Optional[pt.Tensor] = None,
Expand Down Expand Up @@ -238,29 +235,20 @@ def decode_step(self,
:return: logits, list of new model states, other target factor logits.
"""
if self.inference_only:
decode_step_inputs = [step_input, states]
if vocab_slice_ids is not None:
decode_step_inputs.append(vocab_slice_ids)
if self.traced_decode_step is None:
logger.debug("Tracing decode step")
decode_step_module = _DecodeStep(self.embedding_target,
self.decoder,
self.output_layer,
self.factor_output_layers)
self.traced_decode_step = pt.jit.trace(decode_step_module, decode_step_inputs)
# the traced module returns a flat list of tensors
decode_step_outputs = self.traced_decode_step(*decode_step_inputs)
step_output, *target_factor_outputs = decode_step_outputs[:self.num_target_factors]
new_states = decode_step_outputs[self.num_target_factors:]
else:
target_embed = self.embedding_target(step_input.unsqueeze(1))
decoder_out, new_states = self.decoder(target_embed, states)
decoder_out = decoder_out.squeeze(1)
# step_output: (batch_size, target_vocab_size or vocab_slice_ids)
step_output = self.output_layer(decoder_out, vocab_slice_ids)
target_factor_outputs = [fol(decoder_out) for fol in self.factor_output_layers]

decode_step_inputs = [step_input, states]
if vocab_slice_ids is not None:
decode_step_inputs.append(vocab_slice_ids)
if self.traced_decode_step is None:
logger.debug("Tracing decode step")
decode_step_module = _DecodeStep(self.embedding_target,
self.decoder,
self.output_layer,
self.factor_output_layers)
self.traced_decode_step = pt.jit.trace(decode_step_module, decode_step_inputs)
# the traced module returns a flat list of tensors
decode_step_outputs = self.traced_decode_step(*decode_step_inputs)
step_output, *target_factor_outputs = decode_step_outputs[:self.num_target_factors]
new_states = decode_step_outputs[self.num_target_factors:]
return step_output, new_states, target_factor_outputs

def forward(self, source, source_length, target, target_length): # pylint: disable=arguments-differ
Expand Down

0 comments on commit b24b2c1

Please sign in to comment.