From 138294ab8be63202473c7038f468649a78ad74c6 Mon Sep 17 00:00:00 2001 From: "Kieran B. Spooner" Date: Fri, 5 Jul 2024 17:29:54 +0100 Subject: [PATCH] bug fixes and wideband update now compatible with matplotlib 3.9 (retains past compatability) bandmin and bandmax now work with wideband, maintainability enchancements --- tests/test_plot/test_heatmap.py | 5 ++- tp/cli/cli.py | 38 ++++++++++++--------- tp/cli/options.py | 51 ++++++++++++++++++++++++++++- tp/plot/frequency.py | 5 ++- tp/plot/heatmap.py | 10 ++++-- tp/plot/phonons.py | 58 +++++++++++++++++++++++++-------- tp/plot/utilities.py | 5 ++- 7 files changed, 138 insertions(+), 34 deletions(-) diff --git a/tests/test_plot/test_heatmap.py b/tests/test_plot/test_heatmap.py index 18eb32b..35acb84 100644 --- a/tests/test_plot/test_heatmap.py +++ b/tests/test_plot/test_heatmap.py @@ -111,7 +111,10 @@ def test_max(self, mock_colourbar): @patch.object(plt, 'colorbar') def test_colourmap(self, mock_colourbar): - cmap = mpl.cm.get_cmap('viridis') + try: + cmap = mpl.cm.get_cmap('viridis') + except AttributeError: + cmap = mpl.colormaps['viridis'] cbar = heatmap.add_heatmap(self.ax, self.x, self.y, self.c, xinterp=None, yinterp=None, xscale='linear', yscale='linear', cscale='linear', diff --git a/tp/cli/cli.py b/tp/cli/cli.py index 9e5b25e..898e00b 100644 --- a/tp/cli/cli.py +++ b/tp/cli/cli.py @@ -879,7 +879,10 @@ def avg_rates(mesh_h5, mfp, total, x, crt, exclude, doping, direction, colour = colour[0] try: - colours = mpl.cm.get_cmap(colour)(np.linspace(0, 1, nlines)) + try: + colours = mpl.cm.get_cmap(colour)(np.linspace(0, 1, nlines)) + except AttributeError: + colours = mpl.colormaps[colour](np.linspace(0, 1, nlines)) colours = [c for c in colours] except ValueError: if isinstance(colour, str) and colour == 'skelton': @@ -1369,7 +1372,10 @@ def kappa(kfile, efile, component, direction, tmin, tmax, dtype, doping, legend_title = defleg['title'] try: - colours = mpl.cm.get_cmap(colour[0])(np.linspace(0, 1, len(data))) + try: + colours = mpl.cm.get_cmap(colour[0])(np.linspace(0, 1, len(data))) + except AttributeError: + colours = mpl.colormaps[colour[0]](np.linspace(0, 1, len(data))) colours = [c for c in colours] except ValueError: if isinstance(colour[0], str) and colour[0] == 'skelton': @@ -1513,12 +1519,7 @@ def kappa_target(transport_file, zt, direction, interpolate, kind, colour, @plot.command('phonons', no_args_is_help=True) @adminsitrative_options @inputs_function('band_yaml') -@click.option('--bandmin', - help='Minimum band index.', - type=click.IntRange(0)) -@click.option('--bandmax', - help='Maximum band index.', - type=click.IntRange(0)) +@bandrange_options @click.option('-c', '--colour', help='Colourmap name or min and max colours or list of ' @@ -1959,7 +1960,11 @@ def transport(transport_file, kfile, quantity, direction, tmin, tmax, dtype, q: edata2[q]}) try: - colours = mpl.cm.get_cmap(colour[0])(np.linspace(0, 1, lendata)) + try: + colours = mpl.cm.get_cmap(colour[0])(np.linspace(0, 1, lendata)) + except AttributeError: + colours = mpl.colormaps[colour[0]](np.linspace(0, 1, lendata)) + colours = [c for c in colours] except ValueError: if isinstance(colour[0], str) and colour[0] == 'skelton': @@ -2203,6 +2208,7 @@ def waterfall(kappa_hdf5, y, x, projected, direction, temperature, colour, alpha @adminsitrative_options @inputs_function('band_yaml', nargs=1) @inputs_function('kappa_hdf5', nargs=1) +@bandrange_options @temperature_option @click.option('-p', '--poscar', @@ -2228,9 +2234,9 @@ def waterfall(kappa_hdf5, y, x, projected, direction, temperature, colour, alpha @plot_io_function('tp-wideband') @verbose_option -def wideband(band_yaml, kappa_hdf5, temperature, poscar, colour, smoothing, - style, xmin, xmax, ymin, ymax, large, save, show, extension, - output, verbose): +def wideband(band_yaml, kappa_hdf5, bandmin, bandmax, temperature, poscar, + colour, smoothing, style, xmin, xmax, ymin, ymax, large, save, + show, extension, output, verbose): """Plots a broadened phonon dispersion.""" axes = tp.axes.large if large else tp.axes.small @@ -2244,9 +2250,11 @@ def wideband(band_yaml, kappa_hdf5, temperature, poscar, colour, smoothing, fig, ax, _ = axes.one(style) - tp.plot.phonons.add_wideband(ax, kdata, pdata, temperature=temperature, - poscar=poscar, smoothing=smoothing, - colour=colour, verbose=verbose) + tp.plot.phonons.add_wideband(ax, kdata, pdata, bandmin=bandmin, + bandmax=bandmax, ymin=ymin, ymax=ymax, + temperature=temperature, poscar=poscar, + smoothing=smoothing, colour=colour, + verbose=verbose) if xmin is not None: if xmax is not None: diff --git a/tp/cli/options.py b/tp/cli/options.py index 4b83bb2..da014d3 100644 --- a/tp/cli/options.py +++ b/tp/cli/options.py @@ -7,6 +7,8 @@ # function for help and version display. # axes_limit_function # function for setting the axis limits. +# bandrange_options +# option for range of bands plotted. # direction_function: # function for picking the --direction (-d). # doping_type_option: @@ -127,6 +129,33 @@ def axes_limit_options(f): return f return axes_limit_options +def bandrange_options(f): + """Options for specifying bands plotted. + + Options + ------- + + --bandmin : float, optional + minimum band index (0-indexed, inclusive). + --bandmax : float, optional + maximum band index (0-indexed, exclusive). + + Returns + ------- + + decorator + band range options decorator. + """ + + f = click.option('--bandmin', + help="Minimum band index.", + type=click.IntRange(0))(f) + f = click.option('--bandmax', + help="Maximum band index.", + type=click.IntRange(0))(f) + + return f + def direction_function(multiple=False): """Function to create direction options. @@ -334,7 +363,27 @@ def dos_options(f): return dos_options def heatmap_options(f): - """Options for heatmaps.""" + """Options for heatmaps. + + Options + ------- + + --discrete/--continuous : bool, optional + discretise the colourmap. Default: --continuous + -l, --levels : int or array-like, optional + boundaries for discrete plots. Lists specify actual values + while integers specify maximum-1 number of boundaries. + --contours : float or array-like, optional + contour line values. + --contourcolours: str or array-like, optional + contour colours. Default: black. + + Returns + ------- + + decorator + heatmap options decorator. + """ f = click.option('--discrete/--continuous', help='Discretise colourmap. [default: continuous]', diff --git a/tp/plot/frequency.py b/tp/plot/frequency.py index ceb5a04..9b5f0d6 100644 --- a/tp/plot/frequency.py +++ b/tp/plot/frequency.py @@ -702,7 +702,10 @@ def add_waterfall(ax, data, quantity, xquantity='frequency', temperature=300, s = np.shape(data[xquantity]) try: - colour = mpl.cm.get_cmap(colour) + try: + colour = mpl.cm.get_cmap(colour) + except AttributeError: + colour = mpl.colormaps[colour] colours = [colour(i) for i in np.linspace(0, 1, s[1])] colours = np.tile(colours, (s[0], 1)) except ValueError: diff --git a/tp/plot/heatmap.py b/tp/plot/heatmap.py index c541d1c..0d98733 100644 --- a/tp/plot/heatmap.py +++ b/tp/plot/heatmap.py @@ -190,7 +190,10 @@ def add_heatmap(ax, x, y, c, xinterp=None, yinterp=None, kind='linear', # #rrggbb colour as the highlight colour for a tp.plot.colour.uniform. try: - colours = copy(mpl.cm.get_cmap(colour)) + try: + colours = copy(mpl.cm.get_cmap(colour)) + except AttributeError: + colours = copy(mpl.colormaps[colour]) except ValueError: if isinstance(colour, mpl.colors.ListedColormap): colours = colour @@ -282,7 +285,10 @@ def add_heatmap(ax, x, y, c, xinterp=None, yinterp=None, kind='linear', cmap = None try: cmap = contourcolours - contourcolours = mpl.cm.get_cmap(contourcolours)(ctnorm) + try: + contourcolours = mpl.cm.get_cmap(contourcolours)(ctnorm) + except AttributeError: + contourcolours = mpl.colormaps[contourcolours](ctnorm) plt.contour(x[:-1], y[:-1], c, contours, cmap=cmap, **contourkwargs) except ValueError: diff --git a/tp/plot/phonons.py b/tp/plot/phonons.py index 0131f04..4033a5f 100644 --- a/tp/plot/phonons.py +++ b/tp/plot/phonons.py @@ -53,8 +53,8 @@ workers = tp.settings.get_workers() def add_dispersion(ax, data, sdata=None, bandmin=None, bandmax=None, main=True, - label=None, colour='#800080', linestyle='solid', - marker=None, xmarkkwargs={}, **kwargs): + label=None, colour='#800080', linestyle='solid', marker=None, + xmarkkwargs={}, **kwargs): """Adds a phonon band structure to a set of axes. Labels, colours and linestyles can be given one for the whole @@ -195,6 +195,8 @@ def add_dispersion(ax, data, sdata=None, bandmin=None, bandmax=None, main=True, if main: if round(np.amin(f), 1) == 0: ax.set_ylim(bottom=0) + else: + ax.set_ylim(bottom=np.amin(f)) if sdata is None: sdata = data formatting(ax, sdata, 'frequency', **xmarkkwargs) @@ -291,7 +293,10 @@ def add_multi(ax, data, bandmin=None, bandmax=None, main=True, label=None, # line appearance try: - colours = mpl.cm.get_cmap(colour)(np.linspace(0, 1, len(data))) + try: + colours = mpl.cm.get_cmap(colour)(np.linspace(0, 1, len(data))) + except AttributeError: + colours = mpl.colormaps[colour](np.linspace(0, 1, len(data))) colours = [[c] for c in colours] except ValueError: if isinstance(colour, mpl.colors.ListedColormap): @@ -812,6 +817,8 @@ def add_projected_dispersion(ax, data, pdata, quantity, bandmin=None, if main: if round(np.amin(f), 1) == 0: ax.set_ylim(bottom=0) + else: + ax.set_ylim(bottom=np.amin(f)) formatting(ax, pdata, 'frequency', **xmarkkwargs) return cbar @@ -1058,9 +1065,10 @@ def add_alt_projected_dispersion(ax, data, pdata, quantity, projected, return cbar @tp.docstring_replace(workers=str(workers)) -def add_wideband(ax, kdata, pdata, temperature=300, poscar='POSCAR', main=True, - smoothing=5, colour='viridis', workers=workers, - xmarkkwargs={}, verbose=False, **kwargs): +def add_wideband(ax, kdata, pdata, temperature=300, bandmin=None, bandmax=None, + ymin=None, ymax=None, poscar='POSCAR', main=True, smoothing=5, + colour='viridis', workers=workers, xmarkkwargs={}, + verbose=False, **kwargs): """Plots a phonon dispersion with broadened bands. Requires a POSCAR. @@ -1096,6 +1104,14 @@ def add_wideband(ax, kdata, pdata, temperature=300, poscar='POSCAR', main=True, temperature : float, optional approximate temperature in K (finds closest). Default: 300. + bandmin : int, optional + zero-indexed minimum band index to plot. Default: None. + bandmax : int, optional + zero-indexed maximum band index to plot. Default: None. + ymin : float, optional + minimum y-value plotted. + ymax : float, optional + maximum y-value plotted. poscar : str, optional VASP POSCAR filepath. Default: POSCAR. @@ -1180,6 +1196,19 @@ def add_wideband(ax, kdata, pdata, temperature=300, poscar='POSCAR', main=True, c = np.array(kdata['gamma']) qk = kdata['qpoint'] + # data selection + + if bandmin is None: + bandmin = 0 + else: + bandmin = np.amax([0, bandmin]) + if bandmax is None: + bandmax = len(pdata['frequency'][0]) + else: + bandmax = np.amin([len(pdata['frequency'][0]), bandmax]) + + c = c[:,bandmin:bandmax] + # Phonopy data formatting qp = pdata['qpoint'] @@ -1199,20 +1228,18 @@ def add_wideband(ax, kdata, pdata, temperature=300, poscar='POSCAR', main=True, min_id = pool.map(geq, qpi) c2 = c[min_id, :] - x, indices = np.unique(x, return_index=True) - f = np.array(pdata['frequency'])[indices] - where = np.where(c2 == np.amax(c2)) - # interpolate + x, indices = np.unique(x, return_index=True) + f = np.array(pdata['frequency'])[indices,bandmin:bandmax] x2 = np.linspace(min(x), max(x), 2500) finterp = interp1d(x, f, kind='cubic', axis=0) f = finterp(x2) cinterp = interp1d(xi, c2, kind='cubic', axis=0) c2 = np.abs(cinterp(x2)) - fmax = np.amax(np.add(f, c2)) - fmin = np.amin(np.subtract(f, c2)) + fmax = np.amax(np.add(f, c2)) if ymax is None else ymax + fmin = np.amin(np.subtract(f, c2)) if ymin is None else ymin c2 = np.where(c2==0, np.nanmin(c2[np.nonzero(c2)]), c2) f2 = np.linspace(fmin, fmax, 2500) @@ -1231,7 +1258,10 @@ def add_wideband(ax, kdata, pdata, temperature=300, poscar='POSCAR', main=True, # the min and max values. try: - cmap = mpl.cm.get_cmap(colour) + try: + cmap = mpl.cm.get_cmap(colour) + except AttributeError: + cmap = mpl.colormaps[colour] except ValueError: if isinstance(colour, mpl.colors.ListedColormap): cmap = colour @@ -1256,6 +1286,8 @@ def add_wideband(ax, kdata, pdata, temperature=300, poscar='POSCAR', main=True, if main: if round(np.amin(f), 1) == 0: ax.set_ylim(bottom=0) + else: + ax.set_ylim(bottom=fmin) formatting(ax, pdata, 'frequency', **xmarkkwargs) return diff --git a/tp/plot/utilities.py b/tp/plot/utilities.py index 16539d0..13886e5 100644 --- a/tp/plot/utilities.py +++ b/tp/plot/utilities.py @@ -141,7 +141,10 @@ def parse_colours(colour): from copy import copy try: - cmap = copy(mpl.cm.get_cmap(colour)) + try: + cmap = copy(mpl.cm.get_cmap(colour)) + except AttributeError: + cmap = copy(mpl.colormaps[colour]) except ValueError: if isinstance(colour, mpl.colors.ListedColormap): cmap = copy(colour)