Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate draw_samples and draw_sequence #533

Merged
merged 32 commits into from
Jun 19, 2023
Merged
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
0f32556
separate draw_samples and draw_sequence
dakk Jun 7, 2023
daafc55
rebase
dakk Jun 7, 2023
4501e77
restore target rendering
dakk Jun 7, 2023
91166cf
fix linters and types
dakk Jun 7, 2023
2fb542d
move measurement drawing and draw_interp_pts to draw_sequenece
dakk Jun 8, 2023
76276e1
- add time_slots of type target to ChannelSamples
dakk Jun 9, 2023
9ee018f
Fix samples.py typing
dakk Jun 9, 2023
63ab9e4
remove useless comments
dakk Jun 9, 2023
807c8e8
- fix _seq_drawer parameters names
dakk Jun 9, 2023
0bcbf57
fix linting
dakk Jun 9, 2023
ab4f2e3
move optional register drawing to draw_samples
dakk Jun 9, 2023
09c580e
splitting of drawing of target regions
dakk Jun 12, 2023
1363f2c
fix linters
dakk Jun 12, 2023
b7b725f
move draw_phase_shifts outside the loop
dakk Jun 12, 2023
3f4afd4
remove duplicate code from draw_sequence
dakk Jun 12, 2023
6757303
separate draw_channel_content from draw_samples
dakk Jun 12, 2023
2bd97b6
restore _seq_drawer boxes definition position
dakk Jun 12, 2023
70b3d2c
minor edits on seq_drawer
dakk Jun 12, 2023
652484d
adapt draw_phase_area for using sampled_seq in _seq_drawer
dakk Jun 13, 2023
e89d060
- move gather_data to draw_channel_content
dakk Jun 13, 2023
0309a42
- add _basis_ref to SequenceSamples
dakk Jun 13, 2023
3f56a89
- move phase_str into _draw_channel_content
dakk Jun 13, 2023
69e9a1f
refactoring of _seq_drawer
dakk Jun 13, 2023
2a7cf7e
Refactoring of _seq_drawer.py
dakk Jun 13, 2023
ae3b974
fix typo
dakk Jun 13, 2023
4706247
fix EOM drawing in draw_sequence
dakk Jun 13, 2023
33b598e
remove useless if
dakk Jun 13, 2023
38834ab
- test_draw_samples
dakk Jun 16, 2023
f67d642
- add eom_start_buffers and eom_end_buffers in ChannelSamples
dakk Jun 19, 2023
a0c3692
preserve backward compatibility for _TargetSlot
dakk Jun 19, 2023
6c22203
Pin numpy version to < 1.25
dakk Jun 19, 2023
b39a25d
use eom_blocks for eom_intervals_ti creation
dakk Jun 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 137 additions & 110 deletions pulser-core/pulser/sequence/_seq_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ def gather_data(
if shown_duration is not None:
total_duration = shown_duration
else:
total_duration = max(sampled_seq.max_duration, 100)
total_duration = sampled_seq.max_duration
total_duration = max(total_duration, 100)
data: dict[str, Any] = {}
for ch, ch_samples in sampled_seq.channel_samples.items():
target: dict[Union[str, tuple[int, int]], Any] = {}
Expand Down Expand Up @@ -254,11 +255,23 @@ def gather_data(
return data


def _phase_str(phi: float) -> str:
"""Formats a phase value for printing."""
value = (((phi + np.pi) % (2 * np.pi)) - np.pi) / np.pi
if value == -1:
return r"$\pi$"
elif value == 0:
return "0" # pragma: no cover - just for safety
else:
return rf"{value:.2g}$\pi$"


def _draw_channel_content(
data: dict,
dakk marked this conversation as resolved.
Show resolved Hide resolved
sampled_seq: SequenceSamples,
register: Optional[BaseRegister] = None,
sampling_rate: Optional[float] = None,
draw_phase_area: bool = False,
draw_input: bool = True,
draw_modulation: bool = False,
a-corni marked this conversation as resolved.
Show resolved Hide resolved
draw_phase_curve: bool = False,
Expand All @@ -274,6 +287,9 @@ def _draw_channel_content(
sampling_rate: Sampling rate of the effective pulse used by
the solver. If present, plots the effective pulse alongside the
input pulse.
draw_phase_area: Whether phase and area values need to be shown
as text on the plot, defaults to False. If `draw_phase_curve=True`,
phase values are ommited.
draw_input: Draws the programmed pulses on the channels, defaults
to True.
draw_modulation: Draws the expected channel output, defaults to
Expand All @@ -287,6 +303,7 @@ def _draw_channel_content(

# Boxes for qubit and phase text
q_box = dict(boxstyle="round", facecolor="orange")
area_ph_box = dict(boxstyle="round", facecolor="ghostwhite", alpha=0.7)
eom_box = dict(boxstyle="round", facecolor="lightsteelblue")
slm_box = dict(boxstyle="round", alpha=0.4, facecolor="grey", hatch="//")
dakk marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -462,6 +479,79 @@ def _draw_channel_content(
special_kwargs = dict(labelpad=10) if i == 0 else {}
ax.set_ylabel(LABELS[i], fontsize=14, **special_kwargs)

if draw_phase_area:
top = False # Variable to track position of box, top or center.
print_phase = not draw_phase_curve and any(
any(
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
sampled_seq.channel_samples[ch].phase[slot.ti : slot.tf]
== 0
)
for slot in sampled_seq.channel_samples[ch].slots
)

for slot in sampled_seq.channel_samples[ch].slots:
if sampling_rate:
area_val = (
np.sum(yseff[0][slot.ti : slot.tf]) * 1e-3 / np.pi
)
else:
area_val = (
np.sum(
sampled_seq.channel_samples[ch].amp[
slot.ti : slot.tf
]
)
* 1e-3
/ np.pi
)
phase_val = sampled_seq.channel_samples[ch].phase[
slot.ti : slot.tf
][-1]
x_plot = (slot.ti + slot.tf) / 2 / time_scale
if (
slot.ti
in [
target_slot.tf
for target_slot in sampled_seq.channel_samples[
ch
].target_time_slots
]
or not top
):
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
y_plot = (
np.max(
sampled_seq.channel_samples[ch].amp[
slot.ti : slot.tf
]
)
/ 2
)
top = True # Next box at the top.
elif top:
y_plot = np.max(
sampled_seq.channel_samples[ch].amp[slot.ti : slot.tf]
)
top = False # Next box at the center.
area_fmt = (
r"A: $\pi$"
if round(area_val, 2) == 1
else rf"A: {area_val:.2g}$\pi$"
)
if not print_phase:
txt = area_fmt
else:
phase_fmt = rf"$\phi$: {_phase_str(phase_val)}"
txt = "\n".join([phase_fmt, area_fmt])
axes[0].text(
x_plot,
y_plot,
txt,
fontsize=10,
ha="center",
va="center",
bbox=area_ph_box,
)

target_regions = [] # [[start1, [targets1], end1],...]
for coords in ch_data.target:
targets = list(ch_data.target[coords])
Expand Down Expand Up @@ -545,13 +635,51 @@ def _draw_channel_content(
bbox=slm_box,
)

hline_kwargs = dict(linestyle="-", linewidth=0.5, color="grey")
if "measurement" in data:
msg = f"Basis: {data['measurement']}"
if len(axes) == 1:
mid_ax = axes[0]
mid_point = (amp_top + amp_bottom) / 2
fontsize = 12
else:
mid_ax = axes[-1]
mid_point = (
ax_lims[-1][1]
if len(axes) == 2
else ax_lims[-1][0] + sum(ax_lims[-1]) * 1.5
)
fontsize = 14

for ax in axes:
ax.axvspan(final_t, t_max, color="midnightblue", alpha=1)

mid_ax.text(
final_t * 1.025,
mid_point,
msg,
ha="center",
va="center",
fontsize=fontsize,
color="white",
rotation=90,
)
hline_kwargs["xmax"] = 0.95

for i, ax in enumerate(axes):
if i > 0:
ax.axhline(ax_lims[i][1], **hline_kwargs)
if ax_lims[i][0] < 0:
ax.axhline(0, **hline_kwargs)

return (fig_reg if register else None, fig, ch_axes, data)


def draw_samples(
sampled_seq: SequenceSamples,
register: Optional[BaseRegister] = None,
sampling_rate: Optional[float] = None,
draw_phase_area: bool = False,
draw_input: bool = True,
draw_modulation: bool = False,
draw_phase_curve: bool = False,
Expand All @@ -566,6 +694,9 @@ def draw_samples(
sampling_rate: Sampling rate of the effective pulse used by
the solver. If present, plots the effective pulse alongside the
input pulse.
draw_phase_area: Whether phase and area values need to be shown
as text on the plot, defaults to False. If `draw_phase_curve=True`,
phase values are ommited.
draw_input: Draws the programmed pulses on the channels, defaults
to True.
draw_modulation: Draws the expected channel output, defaults to
Expand Down Expand Up @@ -625,19 +756,8 @@ def draw_sequence(
draw_phase_curve: Draws the changes in phase in its own curve (ignored
if the phase doesn't change throughout the channel).
"""

def phase_str(phi: float) -> str:
"""Formats a phase value for printing."""
value = (((phi + np.pi) % (2 * np.pi)) - np.pi) / np.pi
if value == -1:
return r"$\pi$"
elif value == 0:
return "0" # pragma: no cover - just for safety
else:
return rf"{value:.2g}$\pi$"

# Sample the sequence and get the data to plot
sampled_seq = sample(seq, modulation=draw_modulation)
sampled_seq = sample(seq)
a-corni marked this conversation as resolved.
Show resolved Hide resolved
shown_duration = seq.get_duration(include_fall_time=draw_modulation)
dakk marked this conversation as resolved.
Show resolved Hide resolved
data = gather_data(sampled_seq, shown_duration)

Expand All @@ -661,14 +781,14 @@ def phase_str(phi: float) -> str:
data[ch].interp_pts = dict(interp_pts)

# Boxes for qubit and phase text
area_ph_box = dict(boxstyle="round", facecolor="ghostwhite", alpha=0.7)
ph_box = dict(boxstyle="round", facecolor="ghostwhite")

(fig_reg, fig, ch_axes, data) = _draw_channel_content(
data,
sampled_seq,
seq.register if draw_register else None,
sampling_rate,
draw_phase_area,
draw_input,
draw_modulation,
draw_phase_curve,
Expand All @@ -679,7 +799,6 @@ def phase_str(phi: float) -> str:
t = np.arange(total_duration) / time_scale
final_t = t[-1]
t_min = -final_t * 0.03
t_max = final_t * 1.05

for ch, axes in ch_axes.items():
ch_obj = seq.declared_channels[ch]
Expand Down Expand Up @@ -718,61 +837,6 @@ def phase_str(phi: float) -> str:
]
ax_lims = [ax_lims[i] for i in ch_data.curves_on_indices()]

if draw_phase_area:
top = False # Variable to track position of box, top or center.
print_phase = not draw_phase_curve and any(
seq_.type.phase != 0
for seq_ in seq._schedule[ch]
if isinstance(seq_.type, Pulse)
)
for pulse_num, seq_ in enumerate(seq._schedule[ch]):
# Select only `Pulse` objects
if isinstance(seq_.type, Pulse):
if sampling_rate:
area_val = (
np.sum(yseff[0][seq_.ti : seq_.tf]) * 1e-3 / np.pi
)
else:
area_val = (
np.sum(
sampled_seq.channel_samples[ch].amp[
slot.ti : slot.tf
]
)
* 1e-3
/ np.pi
)
phase_val = seq_.type.phase
x_plot = (seq_.ti + seq_.tf) / 2 / time_scale
if (
seq._schedule[ch][pulse_num - 1].type == "target"
or not top
):
y_plot = np.max(seq_.type.amplitude.samples) / 2
top = True # Next box at the top.
elif top:
y_plot = np.max(seq_.type.amplitude.samples)
top = False # Next box at the center.
area_fmt = (
r"A: $\pi$"
if round(area_val, 2) == 1
else rf"A: {area_val:.2g}$\pi$"
)
if not print_phase:
txt = area_fmt
else:
phase_fmt = rf"$\phi$: {phase_str(phase_val)}"
txt = "\n".join([phase_fmt, area_fmt])
axes[0].text(
x_plot,
y_plot,
txt,
fontsize=10,
ha="center",
va="center",
bbox=area_ph_box,
)

# Draw target regions phase_shifts
if draw_phase_shifts:
Copy link
Collaborator

@a-corni a-corni Jun 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the end, why don't you pass _basis_ref to SequenceSamples when you sample a Sequence (just like _measurement and _magnetic_field) ?
Then you will be able to move this part of code in draw_channel_content.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I can do this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did; if I didn't miss anything should be ok now; it is still missing coverage, but I wait for your approval first

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, the coverage issues raises on EOM related parts: 68, 84-95, 211-213, 216-218, 230.
If before this pull request it was working, I think that using ChannelSamples.target_time_slots is not enough for handling this part.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am noticing an issue in the tutorial Output Modulation and EOM. The buffers before and after the EOM mode is switched on (the time slots that are just before and just after the eom interval) are not represented...
Obtained:
image
Should be:
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see; I think the problem is in the gather_data part where we create eom buffers starting from ChannelSamples.target_time_slots; the drawing part is almost unchanged

Copy link
Contributor Author

@dakk dakk Jun 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I further investigate it, but I cannot figure out what is the problem.

EDIT:
Just fixed

target_regions = [] # [[start1, [targets1], end1],...]
Expand All @@ -785,7 +849,7 @@ def phase_str(phi: float) -> str:
if seq.declared_channels[ch].addressing != "Global":
phase = seq._basis_ref[basis][targets[0]].phase[0]
if phase and draw_phase_shifts:
msg = r"$\phi=$" + phase_str(phase)
msg = r"$\phi=$" + _phase_str(phase)
axes[0].text(
0,
max_amp * 1.1,
Expand All @@ -804,7 +868,7 @@ def phase_str(phi: float) -> str:
tf * time_scale + 1
]
if phase:
msg = r"$\phi=$" + phase_str(phase)
msg = r"$\phi=$" + _phase_str(phase)
wrd_len = len(max(tgt_strs, key=len))
x = tf + final_t * 0.01 * (wrd_len + 1)
axes[0].text(
Expand Down Expand Up @@ -833,7 +897,7 @@ def phase_str(phi: float) -> str:
conf = dict(linestyle="--", linewidth=1.5, color="black")
for ax in axes:
ax.axvline(t_, **conf)
msg = "\u27F2 " + phase_str(delta)
msg = "\u27F2 " + _phase_str(delta)
axes[0].text(
t_ - final_t * 8e-3,
max_amp * 1.1,
Expand All @@ -843,43 +907,6 @@ def phase_str(phi: float) -> str:
bbox=ph_box,
)

hline_kwargs = dict(linestyle="-", linewidth=0.5, color="grey")
if "measurement" in data:
msg = f"Basis: {data['measurement']}"
if len(axes) == 1:
mid_ax = axes[0]
mid_point = (amp_top + amp_bottom) / 2
fontsize = 12
else:
mid_ax = axes[-1]
mid_point = (
ax_lims[-1][1]
if len(axes) == 2
else ax_lims[-1][0] + sum(ax_lims[-1]) * 1.5
)
fontsize = 14

for ax in axes:
ax.axvspan(final_t, t_max, color="midnightblue", alpha=1)

mid_ax.text(
final_t * 1.025,
mid_point,
msg,
ha="center",
va="center",
fontsize=fontsize,
color="white",
rotation=90,
)
hline_kwargs["xmax"] = 0.95

for i, ax in enumerate(axes):
if i > 0:
ax.axhline(ax_lims[i][1], **hline_kwargs)
if ax_lims[i][0] < 0:
ax.axhline(0, **hline_kwargs)

if draw_interp_pts:
for qty in ("amplitude", "detuning"):
if qty in ch_data.interp_pts and ch_data.curves_on[qty]:
Expand Down