Skip to content

Commit

Permalink
added [torch] install option
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Sep 17, 2024
1 parent ad77668 commit 1148a0e
Show file tree
Hide file tree
Showing 13 changed files with 151 additions and 74 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install uv
uv pip install ${{ matrix.uv-arg }} --system -r deps/dev_requirements.txt
uv pip install ${{ matrix.uv-arg }} --system -r deps/requirements.txt
uv pip install ${{ matrix.uv-arg }} --system -r deps/torch_geometric_requirements.txt
uv pip install ${{ matrix.uv-arg }} --system -r deps/torch_requirements.txt
Expand Down
16 changes: 13 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

**Ramannoodle** is a Python API for efficiently calculating Raman spectra from first principles calculations. Ramannoodle supports molecular-dynamics- and phonon-based Raman calculations and includes interfaces with VASP.

Ramannoodle is designed from the ground up to be:
Ramannoodle aims to be:

1. **EFFICIENT**

Expand All @@ -33,9 +33,19 @@ Ramannoodle includes interfaces with:

Ramannoodle can be installed via pip:

`
```
$ pip install ramannoodle
`
```

Due to idiosyncrasies with PyTorch's build system, installing ramannoodle's machine learning modules is slightly more involved. First, PyTorch must be installed ([pip commands](https://pytorch.org/get-started/locally/)). Then, corresponding torch-scatter and torch-sparse packages must be installed. Finally, Ramannoodle can then be installed with the appropriate options.

For example, installation on a Linux system using PyTorch 2.4.1 (cpu implementation) is done as follows:

```
$ pip install torch==2.4.1+cpu --index-url https://download.pytorch.org/whl/cpu
$ pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
$ pip install ramannoodle[torch]
```

## Documentation

Expand Down
5 changes: 5 additions & 0 deletions deps/dev_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
flake8 == 7.1.0
pre-commit == 3.7.1
pylint == 3.2.6
pytest == 8.3.1
setuptools == 74.1.2
31 changes: 10 additions & 21 deletions deps/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,23 +1,12 @@
# numpy, scipy recommendations: https://scientific-python.org/specs/spec-0000/

defusedxml >= 0.6.0;python_version=='3.10' # minimum working
defusedxml >= 0.6.0;python_version=='3.11' # minimum working
defusedxml >= 0.6.0;python_version=='3.12' # minimum working
flake8 == 7.1.0
numpy >= 1.24.0;python_version=='3.10' # minimum recommended
numpy >= 1.24.0;python_version=='3.11' # minimum recommended
numpy >= 1.26.0;python_version=='3.12' # minimum working
pre-commit == 3.7.1
pylint == 3.2.6
pytest == 8.3.1
scipy >= 1.10.0;python_version=='3.10' # minimum recommended
scipy >= 1.10.0;python_version=='3.11' # minimum recommended
scipy >= 1.11.2;python_version=='3.12' # minimum working
setuptools == 74.1.2
spglib >= 1.16.4;python_version=='3.10' # minimum working
spglib >= 1.16.4;python_version=='3.11' # minimum working
spglib >= 1.16.4;python_version=='3.12' # minimum working
tabulate >= 0.8.8;python_version=='3.10' # minimum working
tabulate >= 0.8.8;python_version=='3.11' # minimum working
tabulate >= 0.8.8;python_version=='3.12' # minimum working
tqdm >= 2.0
defusedxml >= 0.6.0 # min working
numpy >= 1.24.0;python_version=='3.10' # min recommended
numpy >= 1.24.0;python_version=='3.11' # min recommended
numpy >= 1.26.0;python_version=='3.12' # min working
scipy >= 1.10.0;python_version=='3.10' # min recommended
scipy >= 1.10.0;python_version=='3.11' # min recommended
scipy >= 1.11.2;python_version=='3.12' # min working
spglib >= 1.16.4 # min working
tabulate >= 0.8.8 # min working
tqdm >= 2.0 # min working
36 changes: 16 additions & 20 deletions deps/torch_geometric_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
aiohttp >= 3.8.0;python_version=='3.10'
aiohttp >= 3.8.3;python_version=='3.11'
aiohttp >= 3.9.0;python_version=='3.12'
dill >= 0.3.4
frozenlist >= 1.2.0;python_version=='3.10'
frozenlist >= 1.3.3;python_version=='3.11'
frozenlist >= 1.4.1;python_version=='3.12'
fsspec>= 2021.4.0;python_version=='3.10'
fsspec>= 2021.4.0;python_version=='3.11'
fsspec>=2021.4.0;python_version=='3.12'
jinja2 >= 3.0.2
pyparsing >= 3.0.0
scikit-learn >= 1.2.0;python_version=='3.10'
scikit-learn >= 1.2.0;python_version=='3.11'
scikit-learn >= 1.3.0;python_version=='3.12' and sys_platform=="darwin"
scikit-learn >= 1.3.0;python_version=='3.12' and sys_platform=="linux"
scikit-learn >= 1.4.0;python_version=='3.12' and sys_platform=="win32"
torch_geometric >= 2.3.0;python_version=='3.10'
torch_geometric >= 2.3.0;python_version=='3.11'
torch_geometric >= 2.3.0;python_version=='3.12'
aiohttp >= 3.8.0;python_version=='3.10' # min working
aiohttp >= 3.8.3;python_version=='3.11' # min working
aiohttp >= 3.9.0;python_version=='3.12' # min working
dill >= 0.3.4 # min working
frozenlist >= 1.2.0;python_version=='3.10' # min working
frozenlist >= 1.3.3;python_version=='3.11' # min working
frozenlist >= 1.4.1;python_version=='3.12' # min working
fsspec>= 2021.4.0;python_version=='3.10' # min working
jinja2 >= 3.0.2 # min working
pyparsing >= 3.0.0 # min working
scikit-learn >= 1.2.0;python_version=='3.10' # min working
scikit-learn >= 1.2.0;python_version=='3.11' # min working
scikit-learn >= 1.3.0;python_version=='3.12' and sys_platform=='darwin' # min working
scikit-learn >= 1.3.0;python_version=='3.12' and sys_platform=='linux' # min working
scikit-learn >= 1.4.0;python_version=='3.12' and sys_platform=='win32' # min working
torch_geometric >= 2.3.0 # min working
6 changes: 3 additions & 3 deletions deps/torch_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
--index-url https://download.pytorch.org/whl/cpu
torch==2.4.1;sys_platform=="darwin"
torch==2.4.1+cpu;sys_platform=="linux"
torch==2.4.1+cpu;sys_platform=="win32"
torch==2.4.1;sys_platform=='darwin'
torch==2.4.1+cpu;sys_platform=='linux'
torch==2.4.1+cpu;sys_platform=='win32'
46 changes: 28 additions & 18 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,37 @@ requires-python = ">=3.10"
keywords = ["raman", "spectrum", "vasp", "dft", "phonons", "molecular", "dynamics", "polarizability" ]
license = {text = "MIT"}
dependencies = [
"numpy >= 1.24.0;python_version=='3.10'", # minimum recommended
"numpy >= 1.24.0;python_version=='3.11'", # minimum recommended
"numpy >= 1.26.0;python_version=='3.12'", # minimum working
"scipy >= 1.10.0;python_version=='3.10'", # minimum recommended
"scipy >= 1.10.0;python_version=='3.11'", # minimum recommended
"scipy >= 1.11.2;python_version=='3.12'", # minimum working
"spglib >= 1.16.4;python_version=='3.10'", # minimum working
"spglib >= 1.16.4;python_version=='3.11'", # minimum working
"spglib >= 1.16.4;python_version=='3.12'", # minimum working
"defusedxml >= 0.6.0;python_version=='3.10'", # minimum working
"defusedxml >= 0.6.0;python_version=='3.11'", # minimum working
"defusedxml >= 0.6.0;python_version=='3.12'", # minimum working
"tabulate >= 0.8.8;python_version=='3.10'", # minimum working
"tabulate >= 0.8.8;python_version=='3.11'", # minimum working
"tabulate >= 0.8.8;python_version=='3.12'", # minimum working
"torch >= 2.4.0",
"torch-geometric >= 2.5.3",
"torch-sparse >= 0.6.18",
"defusedxml >= 0.6.0", # min working
"numpy >= 1.24.0;python_version=='3.10'", # min recommended
"numpy >= 1.24.0;python_version=='3.11'", # min recommended
"numpy >= 1.26.0;python_version=='3.12'", # min working
"scipy >= 1.10.0;python_version=='3.10'", # min recommended
"scipy >= 1.10.0;python_version=='3.11'", # min recommended
"scipy >= 1.11.2;python_version=='3.12'", # min working
"spglib >= 1.16.4", # min working
"tabulate >= 0.8.8", # min working
"tqdm >= 2.0", # min working
]

[project.optional-dependencies]
torch = [
"aiohttp >= 3.8.0;python_version=='3.10'", # min working
"aiohttp >= 3.8.3;python_version=='3.11'", # min working
"aiohttp >= 3.9.0;python_version=='3.12'", # min working
"dill >= 0.3.4", # min working
"frozenlist >= 1.2.0;python_version=='3.10'", # min working
"frozenlist >= 1.3.3;python_version=='3.11'", # min working
"frozenlist >= 1.4.1;python_version=='3.12'", # min working
"fsspec>= 2021.4.0", # min working
"jinja2 >= 3.0.2", # min working
"pyparsing >= 3.0.0", # min working
"scikit-learn >= 1.2.0;python_version=='3.10'", # min working
"scikit-learn >= 1.2.0;python_version=='3.11'", # min working
"scikit-learn >= 1.3.0;python_version=='3.12' and sys_platform=='darwin'", # min working
"scikit-learn >= 1.3.0;python_version=='3.12' and sys_platform=='linux'", # min working
"scikit-learn >= 1.4.0;python_version=='3.12' and sys_platform=='win32'", # min working
"torch_geometric >= 2.3.0", # min working
]

[project.urls]
Documentation = "https://ramannoodle.readthedocs.io/en/latest/"
Expand Down
8 changes: 6 additions & 2 deletions ramannoodle/io/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@

from ramannoodle.structure.reference import ReferenceStructure
import ramannoodle.io.vasp as vasp_io
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset

try:
from ramannoodle.polarizability.torch import dataset
except ModuleNotFoundError:
import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore

# These map between file formats and appropriate IO functions.
_PHONON_READERS = {
Expand Down Expand Up @@ -189,7 +193,7 @@ def read_structure_and_polarizability(
def read_polarizability_dataset(
filepaths: str | Path | list[str] | list[Path],
file_format: str,
) -> PolarizabilityDataset:
) -> dataset.PolarizabilityDataset:
"""Read polarizability dataset from files.
Parameters
Expand Down
14 changes: 11 additions & 3 deletions ramannoodle/io/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
IncompatibleStructureException,
)
from ramannoodle.globals import ATOM_SYMBOLS
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset

try:
from ramannoodle.polarizability.torch import dataset
except ModuleNotFoundError:
import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore


def _skip_file_until_line_contains(file: TextIO, content: str) -> str:
Expand Down Expand Up @@ -95,7 +99,7 @@ def _read_polarizability_dataset(
[str | Path],
tuple[NDArray[np.float64], list[int], NDArray[np.float64], NDArray[np.float64]],
],
) -> PolarizabilityDataset:
) -> dataset.PolarizabilityDataset:
"""Read polarizability dataset from OUTCAR files.
Parameters
Expand All @@ -114,7 +118,11 @@ def _read_polarizability_dataset(
File has an unexpected format.
IncompatibleFileException
File is incompatible with the dataset.
ModuleNotFoundError
Torch installation could not be found.
"""
if not dataset.TORCH_PRESENT:
raise ModuleNotFoundError("torch installation not found")
filepaths = pathify_as_list(filepaths)

lattices: list[NDArray[np.float64]] = []
Expand Down Expand Up @@ -143,7 +151,7 @@ def _read_polarizability_dataset(
positions_list.append(positions)
polarizabilities.append(polarizability)

return PolarizabilityDataset(
return dataset.PolarizabilityDataset(
np.array(lattices),
atomic_numbers_list,
np.array(positions_list),
Expand Down
8 changes: 6 additions & 2 deletions ramannoodle/io/vasp/outcar.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
from ramannoodle.dynamics.phonon import Phonons
from ramannoodle.dynamics.trajectory import Trajectory
from ramannoodle.structure.reference import ReferenceStructure
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset

try:
from ramannoodle.polarizability.torch import dataset
except ModuleNotFoundError:
import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore


# Utilities for OUTCAR. Warning: some of these functions partially read files.
Expand Down Expand Up @@ -400,7 +404,7 @@ def read_structure_and_polarizability(

def read_polarizability_dataset(
filepaths: str | Path | list[str] | list[Path],
) -> PolarizabilityDataset:
) -> dataset.PolarizabilityDataset:
"""Read polarizability dataset from OUTCAR files.
Parameters
Expand Down
8 changes: 6 additions & 2 deletions ramannoodle/io/vasp/vasprun.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from ramannoodle.dynamics.phonon import Phonons
from ramannoodle.dynamics.trajectory import Trajectory
from ramannoodle.structure.reference import ReferenceStructure
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset

try:
from ramannoodle.polarizability.torch import dataset
except ModuleNotFoundError:
import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore


def _get_root_element(file: TextIO) -> Element:
Expand Down Expand Up @@ -195,7 +199,7 @@ def read_structure_and_polarizability(

def read_polarizability_dataset(
filepaths: str | Path | list[str] | list[Path],
) -> PolarizabilityDataset:
) -> dataset.PolarizabilityDataset:
"""Read polarizability dataset from OUTCAR files.
Parameters
Expand Down
2 changes: 2 additions & 0 deletions ramannoodle/polarizability/torch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from ramannoodle.exceptions import verify_ndarray_shape, verify_list_len, get_type_error
import ramannoodle.polarizability.torch.utils as rn_torch_utils

TORCH_PRESENT = True


def _scale_and_flatten_polarizabilities(
polarizabilities: Tensor,
Expand Down
44 changes: 44 additions & 0 deletions ramannoodle/polarizability/torch/dummy_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Dummy polarizability PyTorch dataset.
Used when torch installation cannot be found.
:meta private:
"""

import numpy as np
from numpy.typing import NDArray

TORCH_PRESENT = False


class PolarizabilityDataset: # pylint: disable=too-few-public-methods
"""PyTorch dataset of atomic structures and polarizabilities.
Polarizabilities are scaled and flattened into vectors containing the six
independent tensor components.
Parameters
----------
lattices
| (Å) 3D array with shape (S,3,3) where S is the number of samples.
atomic_numbers
| List of length S containing lists of length N, where N is the number of atoms.
positions
| (fractional) 3D array with shape (S,N,3).
polarizabilities
| 3D array with shape (S,3,3).
scale_mode
| Supports ``"standard"`` (standard scaling), ``"stddev"`` (division by
| standard deviation), and ``"none"`` (no scaling).
"""

def __init__( # pylint: disable=too-many-arguments
self,
lattices: NDArray[np.float64],
atomic_numbers: list[list[int]],
positions: NDArray[np.float64],
polarizabilities: NDArray[np.float64],
scale_mode: str = "standard",
):
raise ModuleNotFoundError("torch installation not found")

0 comments on commit 1148a0e

Please sign in to comment.