Skip to content

Commit

Permalink
Enable bfloat16 inference and fix unrelated NumPy issue (#1077)
Browse files: browse the repository at this point in the history
  • Loading branch information
mjdenkowski authored Dec 20, 2022
1 parent 4dba5a3 commit a08216c
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 20 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.1.30]

### Added

- Added support for `--dtype bfloat16` to `sockeye-translate`, `sockeye-score`, and `sockeye-quantize`.

### Fixed

- Fixed compatibility issue with `numpy==1.24.0` by using `pickle` instead of `numpy` to save/load `ParallelSampleIter` data permutations.

## [3.1.29]

### Changed
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.1.29'
__version__ = '3.1.30'
18 changes: 13 additions & 5 deletions sockeye/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def add_quantize_args(params):
f'"{C.PARAMS_BEST_NAME}.{C.DTYPE_FP32}" and "{C.CONFIG_NAME}.{C.DTYPE_FP32}").')
params.add_argument('--dtype',
default=C.DTYPE_FP16,
choices=[C.DTYPE_BF16, C.DTYPE_FP16, C.DTYPE_FP32],
choices=[C.DTYPE_FP32, C.DTYPE_FP16, C.DTYPE_BF16],
help='Target data type for quantization. Default: %(default)s.')


Expand Down Expand Up @@ -771,7 +771,9 @@ def add_model_parameters(params):
help='The type of weight tying. source embeddings=src, target embeddings=trg, '
'target softmax weight matrix=softmax. Default: %(default)s.')

model_params.add_argument('--dtype', default=C.DTYPE_FP32, choices=[C.DTYPE_FP32, C.DTYPE_FP16],
model_params.add_argument('--dtype',
default=C.DTYPE_FP32,
choices=[C.DTYPE_FP32, C.DTYPE_FP16, C.DTYPE_BF16],
help="Data type.")
add_clamp_to_dtype_arg(model_params)

Expand Down Expand Up @@ -1177,7 +1179,9 @@ def add_score_cli_args(params):
help='Controls peakiness of model predictions. Values < 1.0 produce '
'peaked predictions, values > 1.0 produce smoothed distributions.')

params.add_argument('--dtype', default=None, choices=[None, C.DTYPE_FP32, C.DTYPE_FP16, C.DTYPE_INT8],
params.add_argument('--dtype',
default=None,
choices=[None, C.DTYPE_FP32, C.DTYPE_FP16, C.DTYPE_BF16, C.DTYPE_INT8],
help="Data type. Default: infers from saved model.")

add_logging_args(params)
Expand Down Expand Up @@ -1210,7 +1214,9 @@ def add_state_generation_args(params):
params.add_argument("--output-dir", "-o", default=None,
help="The path to the directory that stores the decoder states.")

params.add_argument('--dtype', default=None, choices=[None, C.DTYPE_FP32, C.DTYPE_FP16, C.DTYPE_INT8],
params.add_argument('--dtype',
default=None,
choices=[None, C.DTYPE_FP32, C.DTYPE_FP16, C.DTYPE_INT8],
help="Data type. Default: infers from saved model.")

add_logging_args(params)
Expand Down Expand Up @@ -1381,7 +1387,9 @@ def add_inference_args(params):
add_length_penalty_args(decode_params)
add_brevity_penalty_args(decode_params)

decode_params.add_argument('--dtype', default=None, choices=[None, C.DTYPE_FP32, C.DTYPE_FP16, C.DTYPE_INT8],
decode_params.add_argument('--dtype',
default=None,
choices=[None, C.DTYPE_FP32, C.DTYPE_FP16, C.DTYPE_BF16, C.DTYPE_INT8],
help="Data type. Default: infers from saved model.")
add_clamp_to_dtype_arg(decode_params)

Expand Down
10 changes: 4 additions & 6 deletions sockeye/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1865,8 +1865,8 @@ def save_state(self, fname: str):
with open(fname, "wb") as fp:
pickle.dump(self.batch_indices, fp)
pickle.dump(self.curr_batch_index, fp)
np.save(fp, [a.numpy() for a in self.inverse_data_permutations], allow_pickle=True)
np.save(fp, [a.numpy() for a in self.data_permutations], allow_pickle=True)
pickle.dump(self.inverse_data_permutations, fp)
pickle.dump(self.data_permutations, fp)

def load_state(self, fname: str):
"""
Expand All @@ -1881,10 +1881,8 @@ def load_state(self, fname: str):
with open(fname, "rb") as fp:
self.batch_indices = pickle.load(fp)
self.curr_batch_index = pickle.load(fp)
inverse_data_permutations = [torch.from_numpy(a).long() for a in
np.load(fp, allow_pickle=True)] # pylint: disable=unexpected-keyword-arg
data_permutations = [torch.from_numpy(a).long() for a in
np.load(fp, allow_pickle=True)] # pylint: disable=unexpected-keyword-arg
inverse_data_permutations = pickle.load(fp)
data_permutations = pickle.load(fp)

# Right after loading the iterator state, next() should be called
self.curr_batch_index -= 1
Expand Down
6 changes: 5 additions & 1 deletion sockeye/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,11 @@ def _get_best_translations(self, result: SearchResult) -> List[Translation]:
"""
best_hyp_indices = result.best_hyp_indices.cpu().numpy()
best_word_indices = result.best_word_indices.cpu().numpy()
accumulated_scores = result.accumulated_scores.cpu().numpy()
result_accumulated_scores_cpu = result.accumulated_scores.cpu()
if self.dtype == pt.bfloat16:
# NumPy does not currently support bfloat16. Use float32 instead.
result_accumulated_scores_cpu = result_accumulated_scores_cpu.to(dtype=pt.float32)
accumulated_scores = result_accumulated_scores_cpu.numpy()
lengths = result.lengths.cpu().numpy()
estimated_reference_lengths = None
if result.estimated_reference_lengths is not None:
Expand Down
7 changes: 5 additions & 2 deletions sockeye/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,11 @@ def score_batch(self, batch: data_io.Batch):
logger.debug("Tracing batch_scorer")
self.traced_batch_scorer = pt.jit.trace(self.batch_scorer, scorer_inputs, strict=False)
scores = self.traced_batch_scorer(*scorer_inputs) # (batch, num_target_factors)

return scores.cpu().numpy()
scores_cpu = scores.cpu()
if self.model.dtype == pt.bfloat16:
# NumPy does not currently support bfloat16. Use float32 instead.
scores_cpu = scores_cpu.to(dtype=pt.float32)
return scores_cpu.numpy()

@pt.inference_mode(True)
def score(self, score_iter: data_io.BaseParallelSampleIter, output_handler: OutputHandler):
Expand Down
4 changes: 2 additions & 2 deletions sockeye/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def check_arg_compatibility(args: argparse.Namespace):
check_condition(not (args.amp and args.apex_amp), 'Use either --amp (safer) or --apex-amp (faster).')

if args.dtype != C.DTYPE_FP32:
logger.warning('Specifying a non-float32 dtype to sockeye.train has no effect. Use --amp or --apex-amp for '
'mixed precision training.')
logger.warning('Specifying a non-float32 dtype to sockeye.train has no effect. For 16-bit or mixed precision '
'training, use one of the following: --amp --apex-amp --deepspeed-fp16 --deepspeed-bf16')

if args.local_rank is not None:
check_condition(not args.amp and not args.apex_amp, 'DeepSpeed mode does not support --amp or --apex-amp. '
Expand Down
6 changes: 4 additions & 2 deletions test/integration/test_seq_copy_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# permissions and limitations under the License.
import logging
import os
import platform
import sys
from tempfile import TemporaryDirectory
from typing import List
Expand Down Expand Up @@ -132,7 +133,8 @@
" --checkpoint-interval 2 --optimizer adam --initial-learning-rate 0.01 --clamp-to-dtype",
"--beam-size 2 --clamp-to-dtype",
False, 0, 0),
# Basic transformer, training only the decoder
# Basic transformer, training only the decoder. Inference uses bfloat16 when
# running on Linux.
("--encoder transformer --decoder {decoder}"
" --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8"
" --transformer-feed-forward-num-hidden 16"
Expand All @@ -141,7 +143,7 @@
" --batch-size 2 --max-updates 2 --batch-type sentence --decode-and-evaluate 2"
" --checkpoint-interval 2 --optimizer adam --initial-learning-rate 0.01"
" --fixed-param-strategy " + C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_DECODER,
"--beam-size 2",
"--beam-size 2" + (" --dtype bfloat16" if platform.system() == "Linux" else ""),
False, 0, 0),
]

Expand Down
5 changes: 4 additions & 1 deletion test/unit/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
def mock_translator(batch_size: int = 1,
beam_size: int = 5,
nbest_size: int = 1,
num_source_factors: int = 1):
num_source_factors: int = 1,
dtype: pt.dtype = pt.float32):
"""
Creates a fake translator object but with real values for things that we need.
This lets us avoid a messy call to the constructor.
Expand All @@ -58,12 +59,14 @@ def mock_translator(batch_size: int = 1,
def mock_model():
t_mock = Mock(sockeye.model.SockeyeModel)
t_mock.num_source_factors = num_source_factors
t_mock.dtype = dtype
return t_mock

translator.batch_size = batch_size
translator.beam_size = beam_size
translator.nbest_size = nbest_size
translator.models = [mock_model()]
translator.dtype = translator.models[0].dtype
translator.zeros_array = pt.zeros(beam_size, dtype=pt.int)
translator.inf_array = pt.full((batch_size * beam_size,), fill_value=np.inf, dtype=pt.float32)
translator.inf_array = translator.inf_array[:beam_size]
Expand Down

0 comments on commit a08216c

Please sign in to comment.