Skip to content

Commit

Permalink
Merge pull request #104 from SMTG-Bham/bug-fix
Browse files Browse the repository at this point in the history
bug fixes and wideband update
  • Loading branch information
kbspooner committed Jul 5, 2024
2 parents 754f7b8 + 138294a commit d6cda08
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 34 deletions.
5 changes: 4 additions & 1 deletion tests/test_plot/test_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,10 @@ def test_max(self, mock_colourbar):

@patch.object(plt, 'colorbar')
def test_colourmap(self, mock_colourbar):
cmap = mpl.cm.get_cmap('viridis')
try:
cmap = mpl.cm.get_cmap('viridis')
except AttributeError:
cmap = mpl.colormaps['viridis']
cbar = heatmap.add_heatmap(self.ax, self.x, self.y, self.c,
xinterp=None, yinterp=None, xscale='linear',
yscale='linear', cscale='linear',
Expand Down
38 changes: 23 additions & 15 deletions tp/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,10 @@ def avg_rates(mesh_h5, mfp, total, x, crt, exclude, doping, direction,
colour = colour[0]

try:
colours = mpl.cm.get_cmap(colour)(np.linspace(0, 1, nlines))
try:
colours = mpl.cm.get_cmap(colour)(np.linspace(0, 1, nlines))
except AttributeError:
colours = mpl.colormaps[colour](np.linspace(0, 1, nlines))
colours = [c for c in colours]
except ValueError:
if isinstance(colour, str) and colour == 'skelton':
Expand Down Expand Up @@ -1369,7 +1372,10 @@ def kappa(kfile, efile, component, direction, tmin, tmax, dtype, doping,
legend_title = defleg['title']

try:
colours = mpl.cm.get_cmap(colour[0])(np.linspace(0, 1, len(data)))
try:
colours = mpl.cm.get_cmap(colour[0])(np.linspace(0, 1, len(data)))
except AttributeError:
colours = mpl.colormaps[colour[0]](np.linspace(0, 1, len(data)))
colours = [c for c in colours]
except ValueError:
if isinstance(colour[0], str) and colour[0] == 'skelton':
Expand Down Expand Up @@ -1513,12 +1519,7 @@ def kappa_target(transport_file, zt, direction, interpolate, kind, colour,
@plot.command('phonons', no_args_is_help=True)
@adminsitrative_options
@inputs_function('band_yaml')
@click.option('--bandmin',
help='Minimum band index.',
type=click.IntRange(0))
@click.option('--bandmax',
help='Maximum band index.',
type=click.IntRange(0))
@bandrange_options

@click.option('-c', '--colour',
help='Colourmap name or min and max colours or list of '
Expand Down Expand Up @@ -1959,7 +1960,11 @@ def transport(transport_file, kfile, quantity, direction, tmin, tmax, dtype,
q: edata2[q]})

try:
colours = mpl.cm.get_cmap(colour[0])(np.linspace(0, 1, lendata))
try:
colours = mpl.cm.get_cmap(colour[0])(np.linspace(0, 1, lendata))
except AttributeError:
colours = mpl.colormaps[colour[0]](np.linspace(0, 1, lendata))

colours = [c for c in colours]
except ValueError:
if isinstance(colour[0], str) and colour[0] == 'skelton':
Expand Down Expand Up @@ -2203,6 +2208,7 @@ def waterfall(kappa_hdf5, y, x, projected, direction, temperature, colour, alpha
@adminsitrative_options
@inputs_function('band_yaml', nargs=1)
@inputs_function('kappa_hdf5', nargs=1)
@bandrange_options

@temperature_option
@click.option('-p', '--poscar',
Expand All @@ -2228,9 +2234,9 @@ def waterfall(kappa_hdf5, y, x, projected, direction, temperature, colour, alpha
@plot_io_function('tp-wideband')
@verbose_option

def wideband(band_yaml, kappa_hdf5, temperature, poscar, colour, smoothing,
style, xmin, xmax, ymin, ymax, large, save, show, extension,
output, verbose):
def wideband(band_yaml, kappa_hdf5, bandmin, bandmax, temperature, poscar,
colour, smoothing, style, xmin, xmax, ymin, ymax, large, save,
show, extension, output, verbose):
"""Plots a broadened phonon dispersion."""

axes = tp.axes.large if large else tp.axes.small
Expand All @@ -2244,9 +2250,11 @@ def wideband(band_yaml, kappa_hdf5, temperature, poscar, colour, smoothing,

fig, ax, _ = axes.one(style)

tp.plot.phonons.add_wideband(ax, kdata, pdata, temperature=temperature,
poscar=poscar, smoothing=smoothing,
colour=colour, verbose=verbose)
tp.plot.phonons.add_wideband(ax, kdata, pdata, bandmin=bandmin,
bandmax=bandmax, ymin=ymin, ymax=ymax,
temperature=temperature, poscar=poscar,
smoothing=smoothing, colour=colour,
verbose=verbose)

if xmin is not None:
if xmax is not None:
Expand Down
51 changes: 50 additions & 1 deletion tp/cli/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# function for help and version display.
# axes_limit_function
# function for setting the axis limits.
# bandrange_options
# option for range of bands plotted.
# direction_function:
# function for picking the --direction (-d).
# doping_type_option:
Expand Down Expand Up @@ -127,6 +129,33 @@ def axes_limit_options(f):
return f
return axes_limit_options

def bandrange_options(f):
"""Options for specifying bands plotted.
Options
-------
--bandmin : float, optional
minimum band index (0-indexed, inclusive).
--bandmax : float, optional
maximum band index (0-indexed, exclusive).
Returns
-------
decorator
band range options decorator.
"""

f = click.option('--bandmin',
help="Minimum band index.",
type=click.IntRange(0))(f)
f = click.option('--bandmax',
help="Maximum band index.",
type=click.IntRange(0))(f)

return f

def direction_function(multiple=False):
"""Function to create direction options.
Expand Down Expand Up @@ -334,7 +363,27 @@ def dos_options(f):
return dos_options

def heatmap_options(f):
"""Options for heatmaps."""
"""Options for heatmaps.
Options
-------
--discrete/--continuous : bool, optional
discretise the colourmap. Default: --continuous
-l, --levels : int or array-like, optional
boundaries for discrete plots. Lists specify actual values
while integers specify maximum-1 number of boundaries.
--contours : float or array-like, optional
contour line values.
--contourcolours: str or array-like, optional
contour colours. Default: black.
Returns
-------
decorator
heatmap options decorator.
"""

f = click.option('--discrete/--continuous',
help='Discretise colourmap. [default: continuous]',
Expand Down
5 changes: 4 additions & 1 deletion tp/plot/frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,10 @@ def add_waterfall(ax, data, quantity, xquantity='frequency', temperature=300,

s = np.shape(data[xquantity])
try:
colour = mpl.cm.get_cmap(colour)
try:
colour = mpl.cm.get_cmap(colour)
except AttributeError:
colour = mpl.colormaps[colour]
colours = [colour(i) for i in np.linspace(0, 1, s[1])]
colours = np.tile(colours, (s[0], 1))
except ValueError:
Expand Down
10 changes: 8 additions & 2 deletions tp/plot/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,10 @@ def add_heatmap(ax, x, y, c, xinterp=None, yinterp=None, kind='linear',
# #rrggbb colour as the highlight colour for a tp.plot.colour.uniform.

try:
colours = copy(mpl.cm.get_cmap(colour))
try:
colours = copy(mpl.cm.get_cmap(colour))
except AttributeError:
colours = copy(mpl.colormaps[colour])
except ValueError:
if isinstance(colour, mpl.colors.ListedColormap):
colours = colour
Expand Down Expand Up @@ -282,7 +285,10 @@ def add_heatmap(ax, x, y, c, xinterp=None, yinterp=None, kind='linear',
cmap = None
try:
cmap = contourcolours
contourcolours = mpl.cm.get_cmap(contourcolours)(ctnorm)
try:
contourcolours = mpl.cm.get_cmap(contourcolours)(ctnorm)
except AttributeError:
contourcolours = mpl.colormaps[contourcolours](ctnorm)
plt.contour(x[:-1], y[:-1], c, contours, cmap=cmap,
**contourkwargs)
except ValueError:
Expand Down
58 changes: 45 additions & 13 deletions tp/plot/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@
workers = tp.settings.get_workers()

def add_dispersion(ax, data, sdata=None, bandmin=None, bandmax=None, main=True,
label=None, colour='#800080', linestyle='solid',
marker=None, xmarkkwargs={}, **kwargs):
label=None, colour='#800080', linestyle='solid', marker=None,
xmarkkwargs={}, **kwargs):
"""Adds a phonon band structure to a set of axes.
Labels, colours and linestyles can be given one for the whole
Expand Down Expand Up @@ -195,6 +195,8 @@ def add_dispersion(ax, data, sdata=None, bandmin=None, bandmax=None, main=True,
if main:
if round(np.amin(f), 1) == 0:
ax.set_ylim(bottom=0)
else:
ax.set_ylim(bottom=np.amin(f))
if sdata is None: sdata = data
formatting(ax, sdata, 'frequency', **xmarkkwargs)

Expand Down Expand Up @@ -291,7 +293,10 @@ def add_multi(ax, data, bandmin=None, bandmax=None, main=True, label=None,
# line appearance

try:
colours = mpl.cm.get_cmap(colour)(np.linspace(0, 1, len(data)))
try:
colours = mpl.cm.get_cmap(colour)(np.linspace(0, 1, len(data)))
except AttributeError:
colours = mpl.colormaps[colour](np.linspace(0, 1, len(data)))
colours = [[c] for c in colours]
except ValueError:
if isinstance(colour, mpl.colors.ListedColormap):
Expand Down Expand Up @@ -812,6 +817,8 @@ def add_projected_dispersion(ax, data, pdata, quantity, bandmin=None,
if main:
if round(np.amin(f), 1) == 0:
ax.set_ylim(bottom=0)
else:
ax.set_ylim(bottom=np.amin(f))
formatting(ax, pdata, 'frequency', **xmarkkwargs)

return cbar
Expand Down Expand Up @@ -1058,9 +1065,10 @@ def add_alt_projected_dispersion(ax, data, pdata, quantity, projected,
return cbar

@tp.docstring_replace(workers=str(workers))
def add_wideband(ax, kdata, pdata, temperature=300, poscar='POSCAR', main=True,
smoothing=5, colour='viridis', workers=workers,
xmarkkwargs={}, verbose=False, **kwargs):
def add_wideband(ax, kdata, pdata, temperature=300, bandmin=None, bandmax=None,
ymin=None, ymax=None, poscar='POSCAR', main=True, smoothing=5,
colour='viridis', workers=workers, xmarkkwargs={},
verbose=False, **kwargs):
"""Plots a phonon dispersion with broadened bands.
Requires a POSCAR.
Expand Down Expand Up @@ -1096,6 +1104,14 @@ def add_wideband(ax, kdata, pdata, temperature=300, poscar='POSCAR', main=True,
temperature : float, optional
approximate temperature in K (finds closest). Default: 300.
bandmin : int, optional
zero-indexed minimum band index to plot. Default: None.
bandmax : int, optional
zero-indexed maximum band index to plot. Default: None.
ymin : float, optional
minimum y-value plotted.
ymax : float, optional
maximum y-value plotted.
poscar : str, optional
VASP POSCAR filepath. Default: POSCAR.
Expand Down Expand Up @@ -1180,6 +1196,19 @@ def add_wideband(ax, kdata, pdata, temperature=300, poscar='POSCAR', main=True,
c = np.array(kdata['gamma'])
qk = kdata['qpoint']

# data selection

if bandmin is None:
bandmin = 0
else:
bandmin = np.amax([0, bandmin])
if bandmax is None:
bandmax = len(pdata['frequency'][0])
else:
bandmax = np.amin([len(pdata['frequency'][0]), bandmax])

c = c[:,bandmin:bandmax]

# Phonopy data formatting

qp = pdata['qpoint']
Expand All @@ -1199,20 +1228,18 @@ def add_wideband(ax, kdata, pdata, temperature=300, poscar='POSCAR', main=True,
min_id = pool.map(geq, qpi)
c2 = c[min_id, :]

x, indices = np.unique(x, return_index=True)
f = np.array(pdata['frequency'])[indices]
where = np.where(c2 == np.amax(c2))

# interpolate

x, indices = np.unique(x, return_index=True)
f = np.array(pdata['frequency'])[indices,bandmin:bandmax]
x2 = np.linspace(min(x), max(x), 2500)
finterp = interp1d(x, f, kind='cubic', axis=0)
f = finterp(x2)

cinterp = interp1d(xi, c2, kind='cubic', axis=0)
c2 = np.abs(cinterp(x2))
fmax = np.amax(np.add(f, c2))
fmin = np.amin(np.subtract(f, c2))
fmax = np.amax(np.add(f, c2)) if ymax is None else ymax
fmin = np.amin(np.subtract(f, c2)) if ymin is None else ymin
c2 = np.where(c2==0, np.nanmin(c2[np.nonzero(c2)]), c2)
f2 = np.linspace(fmin, fmax, 2500)

Expand All @@ -1231,7 +1258,10 @@ def add_wideband(ax, kdata, pdata, temperature=300, poscar='POSCAR', main=True,
# the min and max values.

try:
cmap = mpl.cm.get_cmap(colour)
try:
cmap = mpl.cm.get_cmap(colour)
except AttributeError:
cmap = mpl.colormaps[colour]
except ValueError:
if isinstance(colour, mpl.colors.ListedColormap):
cmap = colour
Expand All @@ -1256,6 +1286,8 @@ def add_wideband(ax, kdata, pdata, temperature=300, poscar='POSCAR', main=True,
if main:
if round(np.amin(f), 1) == 0:
ax.set_ylim(bottom=0)
else:
ax.set_ylim(bottom=fmin)
formatting(ax, pdata, 'frequency', **xmarkkwargs)

return
Expand Down
5 changes: 4 additions & 1 deletion tp/plot/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ def parse_colours(colour):
from copy import copy

try:
cmap = copy(mpl.cm.get_cmap(colour))
try:
cmap = copy(mpl.cm.get_cmap(colour))
except AttributeError:
cmap = copy(mpl.colormaps[colour])
except ValueError:
if isinstance(colour, mpl.colors.ListedColormap):
cmap = copy(colour)
Expand Down

0 comments on commit d6cda08

Please sign in to comment.