
Improve default dtype selection #254

Draft · Hespe wants to merge 4 commits into master

Conversation

@Hespe (Member) commented on Sep 17, 2024

Description

Many methods currently default to creating tensors with dtype=torch.float32, regardless of the data passed to them or the default dtype configured in PyTorch. This PR changes those methods to either keep the dtype of their arguments or fall back to torch.get_default_dtype() if no reasonable choice is available. Analogous changes are made for the device of those tensors.
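As an illustration of the intended behaviour, here is a minimal sketch of the fallback pattern (the helper name `resolve_dtype_and_device` and its parameters are hypothetical and not part of the actual diff):

```python
import torch


def resolve_dtype_and_device(argument, dtype=None, device=None):
    """Pick dtype/device from an explicit override, the argument, or the PyTorch defaults."""
    if isinstance(argument, torch.Tensor):
        # Keep the dtype and device of the provided tensor unless overridden explicitly.
        dtype = dtype if dtype is not None else argument.dtype
        device = device if device is not None else argument.device
    else:
        # No tensor to infer from: fall back to the configured global default dtype.
        dtype = dtype if dtype is not None else torch.get_default_dtype()
    return dtype, device


# Example: after torch.set_default_dtype(torch.float64), a call without tensor
# arguments yields float64 instead of the previously hard-coded float32, while
# an explicit float32 tensor keeps its dtype.
torch.set_default_dtype(torch.float64)
assert resolve_dtype_and_device(1.0)[0] == torch.float64
assert resolve_dtype_and_device(torch.zeros(3, dtype=torch.float32))[0] == torch.float32
```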

Since this PR changes the default arguments of a number of methods, it should be considered a breaking change. The impact is likely negligible, because the new fallback torch.get_default_dtype() coincides with the previous default of torch.float32 unless the user has explicitly changed the default dtype.

Motivation and Context

Currently, it is rather cumbersome to track particles with double precision, since many methods default to torch.float32. This change improves compatibility with applications that require tracking in torch.float64 (or torch.float16).

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code and checked that formatting passes (required).
  • I have fixed all issues found by flake8 (required).
  • I have ensured that all pytest tests pass (required).
  • I have run pytest on a machine with a CUDA GPU and made sure all tests pass (required).
  • I have checked that the documentation builds (required).

Note: We are using a maximum length of 88 characters per line.

@Hespe Hespe added the enhancement New feature or request label Sep 17, 2024
@Hespe Hespe linked an issue Sep 17, 2024 that may be closed by this pull request
@Hespe (Member, Author) commented on Sep 17, 2024

I've noticed that the docstrings of some of the methods I have edited are not consistent with the method signatures, especially in the case of ParticleBeam. This is in addition to the issues already noted in #238.

Development

Successfully merging this pull request may close these issues.

Fall back to PyTorch default dtype if no explicit type is provided