diff --git a/tests/test_plot/test_phonons.py b/tests/test_plot/test_phonons.py index 04be3a6..5476878 100644 --- a/tests/test_plot/test_phonons.py +++ b/tests/test_plot/test_phonons.py @@ -117,7 +117,7 @@ def test_bandmin(self): self.ax.set_xlabel.assert_called_once() self.ax.set_ylabel.assert_called_once() self.ax.set_xlim.assert_called_once_with(0, 1) - self.ax.set_ylim.assert_not_called() + self.ax.set_ylim.assert_called_once_with(bottom=0) def test_bandmax(self): self.ax.spines['bottom'].get_linewidth().return_value = 1 diff --git a/tp/plot/phonons.py b/tp/plot/phonons.py index 2a4e639..0131f04 100644 --- a/tp/plot/phonons.py +++ b/tp/plot/phonons.py @@ -346,11 +346,15 @@ def add_multi(ax, data, bandmin=None, bandmax=None, main=True, label=None, else: bandmax = np.amin([len(data[0]['frequency'][0]), bandmax]) - f = [d['frequency'] for d in data] - f = np.array(f)[:,:,bandmin:bandmax] - - if round(np.amin(f), 1) == 0: - ax.set_ylim(bottom=0) + if bandmin < 3: + f = [d['frequency'] for d in data] + noim = True + for ff in f: + ff = np.array(ff)[:,bandmin:bandmax] + if round(np.amin(ff), 1) < 0: + noim = False + if noim == True: + ax.set_ylim(bottom=0) formatting(ax, data[0], 'frequency', **xmarkkwargs) return