Skip to content

Commit

Permalink
Merge branch '10-selection-refactoring' into 'master'
Browse files Browse the repository at this point in the history
Resolve "Selection refactoring"

Closes orest-d#10

See merge request schlipf/py4vasp!10
  • Loading branch information
martin-schlipf committed Jan 31, 2020
2 parents 2f20538 + 592871b commit db9d2ea
Show file tree
Hide file tree
Showing 10 changed files with 464 additions and 196 deletions.
1 change: 1 addition & 0 deletions src/py4vasp/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .band import Band
from .dos import Dos
from .projectors import Projectors

import plotly.io as pio
import cufflinks as cf
Expand Down
87 changes: 12 additions & 75 deletions src/py4vasp/data/band.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import re
import functools
import itertools
import numpy as np
import plotly.graph_objects as go
from collections import namedtuple
from .projectors import Projectors


class Band:
_Index = namedtuple("_Index", "spin, atom, orbital")
_Atom = namedtuple("_Atom", "indices, label")
_Orbital = namedtuple("_Orbital", "indices, label")
_Spin = namedtuple("_Spin", "indices, label")

def __init__(self, raw_band):
self._raw = raw_band
self._fermi_energy = raw_band.fermi_energy
Expand All @@ -25,58 +19,14 @@ def __init__(self, raw_band):
self._num_lines = len(self._kpoints) // self._line_length
self._indices = raw_band.label_indices
self._labels = raw_band.labels
self._has_projectors = raw_band.projectors is not None
if self._has_projectors:
self._init_projectors(raw_band.projectors)
if raw_band.projectors is not None:
self._projectors = Projectors(raw_band.projectors)
self._projections = raw_band.projections

@classmethod
def from_file(cls, file):
return cls(file.band())

def _init_projectors(self, raw_proj):
self._projections = raw_proj.bands
ion_types = raw_proj.ion_types
ion_types = [type.decode().strip() for type in ion_types]
self._init_atom_dict(ion_types, raw_proj.number_ion_types)
orbitals = raw_proj.orbital_types
orbitals = [orb.decode().strip() for orb in orbitals]
self._init_orbital_dict(orbitals)
self._init_spin_dict()

def _init_atom_dict(self, ion_types, number_ion_types):
num_atoms = self._projections.shape[1]
all_atoms = self._Atom(indices=range(num_atoms), label=None)
self._atom_dict = {"*": all_atoms}
start = 0
for type, number in zip(ion_types, number_ion_types):
_range = range(start, start + number)
self._atom_dict[type] = self._Atom(indices=_range, label=type)
for i in _range:
# create labels like Si_1, Si_2, Si_3 (starting at 1)
label = type + "_" + str(_range.index(i) + 1)
self._atom_dict[str(i + 1)] = self._Atom(indices=[i], label=label)
start += number
# atoms may be preceeded by :
for key in self._atom_dict.copy():
self._atom_dict[key + ":"] = self._atom_dict[key]

def _init_orbital_dict(self, orbitals):
num_orbitals = self._projections.shape[2]
all_orbitals = self._Orbital(indices=range(num_orbitals), label=None)
self._orbital_dict = {"*": all_orbitals}
for i, orbital in enumerate(orbitals):
self._orbital_dict[orbital] = self._Orbital(indices=[i], label=orbital)
if "px" in self._orbital_dict:
self._orbital_dict["p"] = self._Orbital(indices=range(1, 4), label="p")
self._orbital_dict["d"] = self._Orbital(indices=range(4, 9), label="d")
self._orbital_dict["f"] = self._Orbital(indices=range(9, 16), label="f")

def _init_spin_dict(self):
labels = ["up", "down"] if self._spin_polarized else [None]
self._spin_dict = {
key: self._Spin(indices=[i], label=key) for i, key in enumerate(labels)
}

def read(self, selection=None):
kpoints = self._kpoints[:]
return {
Expand Down Expand Up @@ -124,12 +74,6 @@ def _shift_bands_by_fermi_energy(self):
else:
return {"bands": self._bands[0] - self._fermi_energy}

def _read_projections(self, selection):
if selection is None:
return {}
parts = self._parse_selection(selection)
return self._read_elements(parts)

def _scatter(self, name, kdists, lines):
# insert NaN to split separate lines
num_bands = lines.shape[-1]
Expand All @@ -146,24 +90,17 @@ def _kpoint_distances(self, kpoints):
)
return functools.reduce(concatenate_distances, kpoint_norms)

def _parse_selection(self, selection):
atom = self._atom_dict["*"]
selection = re.sub("\s*:\s*", ": ", selection)
for part in re.split("[ ,]+", selection):
if part in self._orbital_dict:
orbital = self._orbital_dict[part]
else:
atom = self._atom_dict[part]
orbital = self._orbital_dict["*"]
if ":" not in part: # exclude ":" because it starts a new atom
for spin in self._spin_dict.values():
yield atom, orbital, spin
def _read_projections(self, selection):
if selection is None:
return {}
return self._read_elements(selection)

def _read_elements(self, parts):
def _read_elements(self, selection):
res = {}
for atom, orbital, spin in parts:
for select in self._projectors.parse_selection(selection):
atom, orbital, spin = self._projectors.select(*select)
label = self._merge_labels([atom.label, orbital.label, spin.label])
index = self._Index(spin.indices, atom.indices, orbital.indices)
index = (spin.indices, atom.indices, orbital.indices)
res[label] = self._read_element(index)
return res

Expand Down
80 changes: 9 additions & 71 deletions src/py4vasp/data/dos.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import re
import functools
import itertools
import numpy as np
import pandas as pd
from collections import namedtuple
from .projectors import Projectors


class Dos:
_Index = namedtuple("_Index", "spin, atom, orbital")
_Atom = namedtuple("_Atom", "indices, label")
_Orbital = namedtuple("_Orbital", "indices, label")
_Spin = namedtuple("_Spin", "indices, label")

def __init__(self, raw_dos):
self._raw = raw_dos
self._fermi_energy = raw_dos.fermi_energy
Expand All @@ -20,56 +14,13 @@ def __init__(self, raw_dos):
self._spin_polarized = self._dos.shape[0] == 2
self._has_partial_dos = raw_dos.projectors is not None
if self._has_partial_dos:
self._init_partial_dos(raw_dos.projectors)
self._projectors = Projectors(raw_dos.projectors)
self._projections = raw_dos.projections

@classmethod
def from_file(cls, file):
return cls(file.dos())

def _init_partial_dos(self, raw_proj):
self._partial_dos = raw_proj.dos
ion_types = raw_proj.ion_types
ion_types = [type.decode().strip() for type in ion_types]
self._init_atom_dict(ion_types, raw_proj.number_ion_types)
orbitals = raw_proj.orbital_types
orbitals = [orb.decode().strip() for orb in orbitals]
self._init_orbital_dict(orbitals)
self._init_spin_dict()

def _init_atom_dict(self, ion_types, number_ion_types):
num_atoms = self._partial_dos.shape[1]
all_atoms = self._Atom(indices=range(num_atoms), label=None)
self._atom_dict = {"*": all_atoms}
start = 0
for type, number in zip(ion_types, number_ion_types):
_range = range(start, start + number)
self._atom_dict[type] = self._Atom(indices=_range, label=type)
for i in _range:
# create labels like Si_1, Si_2, Si_3 (starting at 1)
label = type + "_" + str(_range.index(i) + 1)
self._atom_dict[str(i + 1)] = self._Atom(indices=[i], label=label)
start += number
# atoms may be preceeded by :
for key in self._atom_dict.copy():
self._atom_dict[key + ":"] = self._atom_dict[key]

def _init_orbital_dict(self, orbitals):
num_orbitals = self._partial_dos.shape[2]
all_orbitals = self._Orbital(indices=range(num_orbitals), label=None)
self._orbital_dict = {"*": all_orbitals}
for i, orbital in enumerate(orbitals):
self._orbital_dict[orbital] = self._Orbital(indices=[i], label=orbital)
if "px" in self._orbital_dict:
self._orbital_dict["p"] = self._Orbital(indices=range(1, 4), label="p")
self._orbital_dict["d"] = self._Orbital(indices=range(4, 9), label="d")
self._orbital_dict["f"] = self._Orbital(indices=range(9, 16), label="f")

def _init_spin_dict(self):
labels = ["up", "down"] if self._spin_polarized else [None]
self._spin_dict = {
key: self._Spin(indices=[i], label=key) for i, key in enumerate(labels)
}

def plot(self, selection=None):
df = self.to_frame(selection)
if self._spin_polarized:
Expand Down Expand Up @@ -114,40 +65,27 @@ def _read_partial_dos(self, selection):
if selection is None:
return {}
self._raise_error_if_partial_Dos_not_available()
parts = self._parse_filter(selection)
return self._read_elements(parts)
return self._read_elements(selection)

def _raise_error_if_partial_Dos_not_available(self):
if not self._has_partial_dos:
raise ValueError(
"Filtering requires partial DOS which was not found in HDF5 file."
)

def _parse_filter(self, selection):
atom = self._atom_dict["*"]
selection = re.sub("\s*:\s*", ": ", selection)
for part in re.split("[ ,]+", selection):
if part in self._orbital_dict:
orbital = self._orbital_dict[part]
else:
atom = self._atom_dict[part]
orbital = self._orbital_dict["*"]
if ":" not in part: # exclude ":" because it starts a new atom
for spin in self._spin_dict.values():
yield atom, orbital, spin

def _read_elements(self, parts):
def _read_elements(self, selection):
res = {}
for atom, orbital, spin in parts:
for select in self._projectors.parse_selection(selection):
atom, orbital, spin = self._projectors.select(*select)
label = self._merge_labels([atom.label, orbital.label, spin.label])
index = self._Index(spin.indices, atom.indices, orbital.indices)
index = (spin.indices, atom.indices, orbital.indices)
res[label] = self._read_element(index)
return res

def _merge_labels(self, labels):
return "_".join(filter(None, labels))

def _read_element(self, index):
sum_dos = lambda dos, i: dos + self._partial_dos[i]
sum_dos = lambda dos, i: dos + self._projections[i]
zero_dos = np.zeros(len(self._energies))
return functools.reduce(sum_dos, itertools.product(*index), zero_dos)
Loading

0 comments on commit db9d2ea

Please sign in to comment.