Skip to content

Commit

Permalink
Merge pull request #16 from moldyn/updated-plotting
Browse files Browse the repository at this point in the history
Update plotting
  • Loading branch information
dieJaegerIn authored Jul 2, 2024
2 parents 89d6494 + 830d847 commit b2f5c53
Show file tree
Hide file tree
Showing 7 changed files with 296 additions and 81 deletions.
77 changes: 47 additions & 30 deletions docs/tutorials/work.ipynb

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions src/dcTMD/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from dcTMD.dcTMD import WorkEstimator, ForceEstimator
from dcTMD.io import load_pullf, write_output
from dcTMD.storing import save
from dcTMD.utils import plotting
import matplotlib.pyplot as plt

MODES = ('work', 'force')

Expand Down Expand Up @@ -174,6 +176,12 @@ def main( # noqa: WPS211, WPS216
outname = f'{outname}_{mode}'
write_output(outname, estimator)

if plot:
plotting.plot_dcTMD_results(
estimator,
)
plt.savefig(f'{outname}.png')


if __name__ == '__main__':
main() # pragma: no cover
143 changes: 114 additions & 29 deletions src/dcTMD/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scipy.stats import probplot, norm


"""
def fig_sizeA4width():
# Convert cm to inches
# A4 width
Expand All @@ -35,28 +36,50 @@ def fig_sizehalfA4width():
fig_height = fig_height_cm * inches_per_cm
fig_size = (fig_width, fig_height)
return fig_size
"""


def plot_dcTMD_results(x, workestimator, friction):
"""Plot dcTMD results overview in two subplots."""
def plot_dcTMD_results(
estimator,
friction=None,
x=None,
figsize=(4, 4),
):
"""Plot dcTMD results overview in two subplots.
Top subplot constains free energy, dissipative work and mean work.
Bottom subplot contains friction vs. position.
"""
fig, axs = plt.subplots(ncols=1,
nrows=2,
sharex=True,
figsize=fig_sizehalfA4width(),
figsize=figsize,
)
plot_dG_Wdiss(x, workestimator, axs[0])
if x is None:
x = estimator.position_
plot_dG_Wdiss(estimator, axs[0], x=x)
if friction is None:
friction = estimator.friction_
if hasattr(estimator, 'friction_smooth_'):
friction = estimator.friction_smooth_
plot_Gamma(x, friction, axs[1])
axs[0].legend(loc='lower left', mode='expand',
bbox_to_anchor=(0, 0.9, 1, 0.2),
frameon=False,
ncol=3,)
axs[0].legend(
loc='lower left',
mode='expand',
bbox_to_anchor=(0, 1.05, 1, 0.2), # Adjust top margin
bbox_transform=axs[0].transAxes, # Use axis coordinates
frameon=False,
ncol=3,
)
axs[0].set_xlabel("")
plt.tight_layout()
return fig, axs


def plot_dG_Wdiss(x, workestimator, ax):
"""Plot free energy, dissipative work and mean work against position."""
def plot_dG_Wdiss(workestimator, ax, x=None):
"""Plot free energy, dissipative work and mean work vs position."""
if x is None:
x = workestimator.position_
ax.plot(x, workestimator.dG_, label=r'$\Delta G$')
ax.plot(x, workestimator.W_mean_, label=r'W$_{\mathrm{mean}}$')
ax.plot(x, workestimator.W_diss_, label=r'W$_{\mathrm{diss}}$')
Expand All @@ -67,27 +90,78 @@ def plot_dG_Wdiss(x, workestimator, ax):


def plot_Gamma(x, friction, ax, label=None):
"""Plot friction factor against position."""
ax.plot(x, friction, label=rf"{label}")
"""Plot friction factor vs position."""
if label is None:
ax.plot(x, friction)
else:
ax.plot(x, friction, label=label)
ax.set(xlabel=r'position $x$ [nm]',
ylabel=r'$\Gamma$ [kJ/mol/(nm$^2$/ps)]',
ylabel=r'$\Gamma$ [kJ nm$^2$/(mol ps)]',
xlim=[min(x), max(x)],
)


def plot_dG(x, dG, ax, label=None):
"""Plot free energy against position."""
ax.plot(x, dG, label=rf"{label}")
"""Plot free energy vs position."""
if label is None:
line, = ax.plot(x, dG)
line, = ax.plot(x, dG, label=label)
ax.set(xlabel=r'position $x$ [nm]',
ylabel=r'$\Delta G$ [kJ/mol]',
xlim=[min(x), max(x)],
)
return line


def plot_dG_werrors(workestimator, ax, labeldG=None):
"""Plot free energy with errors against position."""
if hasattr(workestimator, 's_dG_'):
x = workestimator.position_
dG = workestimator.dG_
sdG = workestimator.s_dG_
line = plot_dG(x, dG, ax, label=labeldG)
color = line.get_color()
if len(sdG) == 2:
ax.plot(
x,
sdG[0],
facecolor=color,
ls='dotted',
)
ax.plot(
x,
sdG[1],
facecolor=color,
ls='dotted',
)
ax.fill_between(
x,
sdG[0],
sdG[1],
facecolor=color,
alpha=0.3
)
else:
ax.fill_between(
x,
dG - sdG,
dG + sdG,
facecolor=color,
alpha=0.3
)
else:
print(f'no errors are determined for {workestimator}')
print('use estimate_free_energy_errors() to determine errors')
return


def plot_worklines(x, workset, ax):
"""Plot work of trajectories individually."""
for w in workset:
ax.plot(x, w, color='#777', alpha=.5, lw=.5)
def plot_worklines(workset, ax, x=None, color='#777', res=1):
"""Line plots of work of the individual trajectories
in the workset."""
if x is None:
x = workset.position_
for w in workset.work_:
ax.plot(x[::res], w[::res], color=color, alpha=.3, lw=.5)

ax.set(xlabel=r'position $x$ [nm]',
ylabel=r'work $W$ [kJ/mol]',
Expand All @@ -96,6 +170,8 @@ def plot_worklines(x, workset, ax):


def plot_histo_normaldist(data, ax, color='blue', label=None):
"""Plots a histogram of the data and
a normal distribution fitted to the data."""
data = data.flatten()
ax.hist(data,
bins='fd',
Expand All @@ -105,7 +181,7 @@ def plot_histo_normaldist(data, ax, color='blue', label=None):
alpha=0.5,
orientation='horizontal',
color=color,
label=rf'{label}',
label=label,
ec=color,
zorder=3,
)
Expand All @@ -122,32 +198,41 @@ def plot_histo_normaldist(data, ax, color='blue', label=None):
)


def plot_worknormalitychecks(x, workset, index, colors=None):
def plot_worknormalitychecks(
workset,
index,
x=None,
colors=None,
figsize=(2, 6),
):
"""Plots the work values of trajectories individually.
Also adds histograms and normality plots for the indices given in `index`.
"""
fig, axs = plt.subplots(ncols=3,
nrows=1,
figsize=fig_sizeA4width()
)
plot_worklines(x, workset, axs[0])
fig, axs = plt.subplots(
ncols=3,
nrows=1,
figsize=figsize,
)
if x is None:
x = workset.position_
plot_worklines(workset, axs[0], x=x)

if not colors:
cmap = plt.get_cmap('Dark2')
colors = cmap.colors

for j, idx in enumerate(index):
data = workset[:, idx].flatten()
work = workset.work_[:, idx].flatten()
axs[1].set_title(r'Histogram at $x$')
plot_histo_normaldist(data, axs[1], colors[j])
plot_histo_normaldist(work, axs[1], colors[j])
axs[0].axvline(x[idx],
color=colors[j],
zorder=3,
label=rf'$x={x[idx]}$nm',
)

probplot(data, plot=axs[2], fit=True)
probplot(work, plot=axs[2], fit=True)
axs[2].get_lines()[j * 2].set_color(colors[j])
axs[2].set_title('Normality plot')

Expand Down
5 changes: 3 additions & 2 deletions tests/create_testdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
n_resamples = 100
seed = 42
sigma = 0.1
mode = 'nearest'
bootstrapmode = 'std'

pullf_files = 'testdata/pullf_filenames.dat'
Expand All @@ -33,7 +34,7 @@
forceestimator = ForceEstimator(temperature)
forceestimator.fit(forceset)
# smooth friction
forceestimator.smooth_friction(sigma)
forceestimator.smooth_friction(sigma, mode=mode)
save('testdata/forceestimator', forceestimator)

# create ForceSet instance
Expand All @@ -50,7 +51,7 @@
workeestimator = WorkEstimator(temperature)
workeestimator.fit(workset)
# smooth friction
workeestimator.smooth_friction(sigma)
workeestimator.smooth_friction(sigma, mode=mode)
# error estimation vis bootstrapping
workeestimator.estimate_free_energy_errors(n_resamples, bootstrapmode, seed)
save('testdata/workeestimator', workeestimator)
88 changes: 84 additions & 4 deletions tests/test___main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
TEST_FILE_DIR = join(HERE, 'testdata')


def test_main(tmpdir):
"""Test main CLI."""
def test_main_work_mode(tmpdir):
"""Test main CLI in work mode."""
# Create temporary directory for output files
output = tmpdir.join('test')
output = tmpdir.join('test_work')
# Set up test command line arguments
args = [
'--mode', 'work',
'--file', f'{TEST_FILE_DIR}/*.xvg',
'--outname', f'{output}',
'--outname', str(output),
'--temperature', '300',
'--velocity', '0.01',
'--res', '1',
Expand All @@ -33,3 +33,83 @@ def test_main(tmpdir):
runner = CliRunner()
clirunner_result = runner.invoke(main, args)
assert clirunner_result.exit_code == 0


def test_main_force_mode(tmpdir):
"""Test main CLI in force mode."""
# Create temporary directory for output files
output = tmpdir.join('test_force')
# Set up test command line arguments
args = [
'--mode', 'force',
'--file', f'{TEST_FILE_DIR}/*.xvg',
'--outname', str(output),
'--temperature', '300',
'--velocity', '0.01',
'--res', '1',
'--sigma', '0.1',
'--verbose',
'--plot',
'--save_dataset',
]
runner = CliRunner()
clirunner_result = runner.invoke(main, args)
assert clirunner_result.exit_code == 0


def test_main_no_plot(tmpdir):
"""Test main CLI without plot option."""
# Create temporary directory for output files
output = tmpdir.join('test_no_plot')
# Set up test command line arguments
args = [
'--mode', 'work',
'--file', f'{TEST_FILE_DIR}/*.xvg',
'--outname', str(output),
'--temperature', '300',
'--velocity', '0.01',
'--res', '1',
'--sigma', '0.1',
'--verbose',
'--save_dataset',
]
runner = CliRunner()
clirunner_result = runner.invoke(main, args)
assert clirunner_result.exit_code == 0


def test_main_no_save_dataset(tmpdir):
"""Test main CLI without save_dataset option."""
# Create temporary directory for output files
output = tmpdir.join('test_no_save')
# Set up test command line arguments
args = [
'--mode', 'work',
'--file', f'{TEST_FILE_DIR}/*.xvg',
'--outname', str(output),
'--temperature', '300',
'--velocity', '0.01',
'--res', '1',
'--sigma', '0.1',
'--verbose',
'--plot',
]
runner = CliRunner()
clirunner_result = runner.invoke(main, args)
assert clirunner_result.exit_code == 0


def test_main_minimal(tmpdir):
"""Test main CLI with minimal options."""
# Create temporary directory for output files
output = tmpdir.join('test_minimal')
# Set up test command line arguments
args = [
'--file', f'{TEST_FILE_DIR}/*.xvg',
'--outname', str(output),
'--temperature', '300',
'--velocity', '0.01',
]
runner = CliRunner()
clirunner_result = runner.invoke(main, args)
assert clirunner_result.exit_code == 0
Loading

0 comments on commit b2f5c53

Please sign in to comment.