Skip to content

Commit

Permalink
Fix synchronization issues in RNG checkpointing utils (#5273)
Browse files Browse the repository at this point in the history
This PR fixes two synchronization bugs. Both occur when restoring from a checkpoint in random GPU operators, and both are caused by a lack of synchronization after host-to-device (H2D) and device-to-device (D2D) copies.

---------

Signed-off-by: Szymon Karpiński <skarpinski@nvidia.com>
  • Loading branch information
szkarpinski authored Jan 16, 2024
1 parent 4764966 commit b98cba1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
11 changes: 7 additions & 4 deletions dali/operators/random/rng_checkpointing_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class RngCheckpointUtils<GPUBackend, curand_states> {
// Records the RNG state in the operator checkpoint.
// Sets the checkpoint's access order to `order`, then stores the result of
// rng.copy(order) as the checkpoint state. This function itself makes no
// explicit synchronization call; see the note below.
static void SaveState(OpCheckpoint &cpt, AccessOrder order, const curand_states &rng) {
cpt.SetOrder(order);
cpt.MutableCheckpointState() = rng.copy(order);
// The pipeline will perform host synchronization before serializing the checkpoints.
// TODO(skarpinski) Move synchronization out from pipeline's GetCheckpoint.
}

static void RestoreState(const OpCheckpoint &cpt, curand_states &rng) {
Expand All @@ -85,10 +87,11 @@ class RngCheckpointUtils<GPUBackend, curand_states> {
// Rebuilds the checkpoint state from its serialized form.
// Deserializes `data` into a std::vector<curandState>, constructs a
// curand_states of the same size, copies the vector's contents host-to-device
// into it, and stores the result as the checkpoint state of `cpt`.
// NOTE(review): this listing is a diff view with the +/- markers stripped. Per
// the "7 additions & 4 deletions" header, the plain cudaMemcpy call below looks
// like the removed version and the cudaMemcpyAsync + cudaStreamSynchronize pair
// its replacement - confirm against the repository before treating both as live.
static void DeserializeCheckpoint(OpCheckpoint &cpt, const std::string &data) {
auto deserialized = SnapshotSerializer().Deserialize<std::vector<curandState>>(data);
curand_states states(deserialized.size());
CUDA_CALL(cudaMemcpy(states.states(),
deserialized.data(),
sizeof(curandState) * deserialized.size(),
cudaMemcpyHostToDevice));
CUDA_CALL(cudaMemcpyAsync(states.states(),
deserialized.data(),
sizeof(curandState) * deserialized.size(),
cudaMemcpyHostToDevice, cudaStreamDefault));
// Wait for the copy enqueued on the default stream to finish before the
// states are handed to the checkpoint.
CUDA_CALL(cudaStreamSynchronize(cudaStreamDefault));
cpt.MutableCheckpointState() = states;
}
};
Expand Down
5 changes: 3 additions & 2 deletions dali/operators/util/randomizer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ struct curand_states {
}

// Overwrites this object's states with those of `other` using a
// device-to-device copy of len_ curandState entries.
// NOTE(review): the byte count uses this object's len_; presumably `other`
// holds at least len_ states - confirm against callers.
// NOTE(review): this listing is a diff view with the +/- markers stripped. Per
// the "3 additions & 2 deletions" header, the plain cudaMemcpy call looks like
// the removed version and the cudaMemcpyAsync + cudaStreamSynchronize pair its
// replacement - confirm against the repository.
DALI_HOST inline void set(const curand_states &other) {
CUDA_CALL(cudaMemcpy(states_, other.states_, sizeof(curandState) * len_,
cudaMemcpyDeviceToDevice));
CUDA_CALL(cudaMemcpyAsync(states_, other.states_, sizeof(curandState) * len_,
cudaMemcpyDeviceToDevice, cudaStreamDefault));
// Wait for the copy enqueued on the default stream to finish before returning.
CUDA_CALL(cudaStreamSynchronize(cudaStreamDefault));
}

private:
Expand Down

0 comments on commit b98cba1

Please sign in to comment.