Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
Many methods currently default to creating tensors with
dtype=torch.float32
regardless of the data provided to them or the defaultdtype
configured PyTorch. This PR changes those methods to either keep the samedtype
as their arguments or fall back totorch.get_default_dtype()
if no reasonable choice is available.Similar changes are implemented for the
device
of said tensors.Since this PR is changing the default arguments of a number of methods, it should be considered a breaking change. The impact is likely negligable because the new fallback
torch.get_default_dtype()
aligns with the previous default oftorch.float32
if no explicit actions are taken by the user.Motivation and Context
Currently, it is rather cumbersome to track particles with double precision since many methods default to
torch.float32
. Implementing this change will increase the compatability applications that require tracking usingtorch.float64
(ortorch.float16
).Types of changes
Checklist
flake8
(required).pytest
tests pass (required).pytest
on a machine with a CUDA GPU and made sure all tests pass (required).Note: We are using a maximum length of 88 characters per line.