Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Making TorchForce CUDA-graph aware (#103)
* Add CUDA graph draft * Initialize energy and force tensors in the GPU. * Add comment on graph capture * Catch torch exception if the model fails to capture. * Replay graph just after construction Finish capturing before rethrowing if an exception occurred during capture * Add python-side test script for CUDA graphs * Implement properties * Update the Python bindings * Unify the API for properties * Pass the propery map to the constructor * Skip graph tests if no GPU is present * Guard CUDA graph behavior with the CUDA_GRAPH_ENABLE macro * Check validity of the useCUDAGraphs property * Add missing bracket to openmmtorch.i * Fix bug in useCUDAgraph selection * Update tests * Add test for get/setProperty * Update documentation with new functionality * Add a CUDA graph test for a model that returns only energy * Add contributors * Reset pos grads after graph capture. Make energy and force tensors persistent. * Add tests that execute the model many times to catch bugs related with CUDA graph capture * Run formatter * Warmup model for several steps * Include gradient reset into the graph * Do not reset energy and force tensors before graph capture * Remove unnecessary line * Add tests for larger number of particles * Remove unnecessary compilation guard now that Pytorch 1.10 is not supported * Simplify getTensorPointer now that Pytorch 1.7 is not supported * Change addForcesToOpenMM to addForces * Change execute_graph to executeGraph * Wrap graph warming up in a try/catch block * Add correctness test for modules that only provide energy * Revert "Add correctness test for modules that only provide energy" This reverts commit d20f4bf. * Explicit conversion to correct type in getTensorPointer * Added a new property for TorchForce, CUDAGraphWarmupSteps. * Clarify docs * Document properties * Throw if requested property does not exist * Change getProperty(string) to getProperties() * Add getProperties to python wrappers * Fix formatting * Set default properties * Update tests * Update some comments --------- Co-authored-by: Raimondas Galvelis <r.galvelis@acellera.com>
- Loading branch information