diff --git a/pulser-core/pulser/sampler/sampler.py b/pulser-core/pulser/sampler/sampler.py index d291987e0..d528d6b1c 100644 --- a/pulser-core/pulser/sampler/sampler.py +++ b/pulser-core/pulser/sampler/sampler.py @@ -56,5 +56,6 @@ def sample( list(seq.declared_channels.keys()), samples_list, seq.declared_channels, + seq._basis_ref, **optionals, ) diff --git a/pulser-core/pulser/sampler/samples.py b/pulser-core/pulser/sampler/samples.py index 8a8a05fc2..625ec0fe6 100644 --- a/pulser-core/pulser/sampler/samples.py +++ b/pulser-core/pulser/sampler/samples.py @@ -10,9 +10,10 @@ from pulser.channels.base_channel import Channel from pulser.channels.eom import BaseEOM from pulser.register import QubitId +from pulser.sequence._basis_ref import _QubitRef if TYPE_CHECKING: - from pulser.sequence._schedule import _EOMSettings + from pulser.sequence._schedule import _EOMSettings, _TimeSlot """Literal constants for addressing.""" _GLOBAL = "Global" @@ -58,7 +59,7 @@ def _default_to_regular(d: dict | defaultdict) -> dict: @dataclass -class _TargetSlot: +class _PulseTargetSlot: """Auxiliary class to store target information. Recopy of the sequence._TimeSlot but without the unrelevant `type` field, @@ -89,9 +90,11 @@ class ChannelSamples: amp: np.ndarray det: np.ndarray phase: np.ndarray - slots: list[_TargetSlot] = field(default_factory=list) + slots: list[_PulseTargetSlot] = field(default_factory=list) eom_blocks: list[_EOMSettings] = field(default_factory=list) - initial_targets: set[QubitId] = field(default_factory=set) + eom_start_buffers: list[tuple[int, int]] = field(default_factory=list) + eom_end_buffers: list[tuple[int, int]] = field(default_factory=list) + target_time_slots: list[_TimeSlot] = field(default_factory=list) def __post_init__(self) -> None: assert len(self.amp) == len(self.det) == len(self.phase) @@ -102,6 +105,15 @@ def __post_init__(self) -> None: for t1, t2 in zip(self.slots, self.slots[1:]): assert t1.tf <= t2.ti # no overlaps on a given channel + @property + def initial_targets(self) -> set[QubitId]: + """Returns the initial targets.""" + return ( + self.target_time_slots[0].targets + if self.target_time_slots + else set() + ) + def extend_duration(self, new_duration: int) -> ChannelSamples: """Extends the duration of the samples. @@ -160,6 +172,23 @@ def _generate_std_samples(self) -> ChannelSamples: return replace(self, **new_samples) + def get_eom_mode_intervals(self) -> list[tuple[int, int]]: + """Returns EOM mode intervals.""" + return [ + ( + block.ti, + block.tf if block.tf is not None else self.duration, + ) + for block in self.eom_blocks + ] + + def in_eom_mode(self, slot: _TimeSlot | _PulseTargetSlot) -> bool: + """States if a time slot is inside an EOM mode block.""" + return any( + start <= slot.ti < end + for start, end in self.get_eom_mode_intervals() + ) + def modulate( self, channel_obj: Channel, max_duration: Optional[int] = None ) -> ChannelSamples: @@ -292,6 +321,9 @@ class SequenceSamples: channels: list[str] samples_list: list[ChannelSamples] _ch_objs: dict[str, Channel] + _basis_ref: dict[str, dict[QubitId, _QubitRef]] = field( + default_factory=dict + ) _slm_mask: _SlmMask = field(default_factory=_SlmMask) _magnetic_field: np.ndarray | None = None _measurement: str | None = None @@ -396,3 +428,8 @@ def __repr__(self) -> str: for chname, cs in zip(self.channels, self.samples_list) ] return "\n\n".join(blocks) + + +# This is just to preserve backwards compatibility after the renaming of +# _TargetSlot to _PulseTarget slot +_TargetSlot = _PulseTargetSlot diff --git a/pulser-core/pulser/sequence/_schedule.py b/pulser-core/pulser/sequence/_schedule.py index 010e1dbe2..eb2e0f774 100644 --- a/pulser-core/pulser/sequence/_schedule.py +++ b/pulser-core/pulser/sequence/_schedule.py @@ -24,7 +24,7 @@ from pulser.channels.base_channel import Channel from pulser.pulse import Pulse from pulser.register.base_register import QubitId -from pulser.sampler.samples import ChannelSamples, _TargetSlot +from pulser.sampler.samples import ChannelSamples, _PulseTargetSlot from pulser.waveforms import ConstantWaveform @@ -137,8 +137,17 @@ def get_samples( channel_slots = [s for s in self.slots if isinstance(s.type, Pulse)] dt = self.get_duration() amp, det, phase = np.zeros(dt), np.zeros(dt), np.zeros(dt) - slots: list[_TargetSlot] = [] - initial_targets = self.slots[0].targets if self.slots else set() + slots: list[_PulseTargetSlot] = [] + target_time_slots: list[_TimeSlot] = [ + s for s in self.slots if s.type == "target" + ] + # Extracting the EOM Buffers + eom_intervals_ti = [block.ti for block in self.eom_blocks] + nb_eom_intervals = len(eom_intervals_ti) + eom_start_buffers = [(0, 0) for _ in range(nb_eom_intervals)] + eom_end_buffers = [(0, 0) for _ in range(nb_eom_intervals)] + in_eom_mode = False + eom_block_n = -1 for ind, s in enumerate(channel_slots): pulse = cast(Pulse, s.type) @@ -156,7 +165,7 @@ def get_samples( if ind < len(channel_slots) - 1 else fall_time ) - slots.append(_TargetSlot(s.ti, tf, s.targets)) + slots.append(_PulseTargetSlot(s.ti, tf, s.targets)) if ignore_detuned_delay_phase and self.is_detuned_delay(pulse): # The phase of detuned delays is not considered @@ -182,8 +191,40 @@ def get_samples( # the same, so the last phase is automatically kept till the end phase[t_start:] = pulse.phase + # Create EOM start and end buffers + for s in self.slots: + if s.ti == -1: + continue + + # If slot is not the first element in schedule + if self.in_eom_mode(s): + # EOM mode starts + if not in_eom_mode: + in_eom_mode = True + eom_block_n += 1 + elif in_eom_mode: + # Buffer when EOM mode is disabled and next slot has 0 amp + in_eom_mode = False + if amp[s.ti] == 0: + eom_end_buffers[eom_block_n] = (s.ti, s.tf) + if ( + eom_block_n + 1 < nb_eom_intervals + and s.tf == eom_intervals_ti[eom_block_n + 1] + and det[s.tf - 1] + == self.eom_blocks[eom_block_n + 1].detuning_off + ): + # Buffer if next is eom and final det matches det_off + eom_start_buffers[eom_block_n + 1] = (s.ti, s.tf) + return ChannelSamples( - amp, det, phase, slots, self.eom_blocks, initial_targets + amp, + det, + phase, + slots, + self.eom_blocks, + eom_start_buffers, + eom_end_buffers, + target_time_slots, ) @overload diff --git a/pulser-core/pulser/sequence/_seq_drawer.py b/pulser-core/pulser/sequence/_seq_drawer.py index 545aa32ae..536dc0977 100644 --- a/pulser-core/pulser/sequence/_seq_drawer.py +++ b/pulser-core/pulser/sequence/_seq_drawer.py @@ -29,7 +29,9 @@ from pulser import Register, Register3D from pulser.channels.base_channel import Channel from pulser.pulse import Pulse -from pulser.sampler.samples import ChannelSamples +from pulser.register.base_register import BaseRegister +from pulser.sampler.sampler import sample +from pulser.sampler.samples import ChannelSamples, SequenceSamples from pulser.waveforms import InterpolatedWaveform # Color scheme @@ -161,82 +163,52 @@ def _give_curves_from_samples( ] -def gather_data(seq: pulser.sequence.Sequence, gather_output: bool) -> dict: +def gather_data( + sampled_seq: SequenceSamples, shown_duration: Optional[int] = None +) -> dict: """Collects the whole sequence data for plotting. Args: - seq: The input sequence of operations on a device. - gather_output: Whether to gather the modulated output curves. + sampled_seq: The samples of a sequence of operations on a device. + shown_duration: If present, is the total duration to be shown in + the X axis. Returns: The data to plot. """ # The minimum time axis length is 100 ns - total_duration = max( - seq.get_duration(include_fall_time=gather_output), 100 - ) + total_duration = max(sampled_seq.max_duration, 100, shown_duration or 100) data: dict[str, Any] = {} - for ch, sch in seq._schedule.items(): - # List of interpolation points - interp_pts: defaultdict[str, list[list[float]]] = defaultdict(list) + for ch, ch_samples in sampled_seq.channel_samples.items(): target: dict[Union[str, tuple[int, int]], Any] = {} # Extracting the EOM Buffers eom_intervals = [ EOMSegment(eom_interval[0], eom_interval[1]) - for eom_interval in sch.get_eom_mode_intervals() + for eom_interval in ch_samples.get_eom_mode_intervals() ] - nb_eom_intervals = len(eom_intervals) - eom_start_buffers = [EOMSegment() for _ in range(nb_eom_intervals)] - eom_end_buffers = [EOMSegment() for _ in range(nb_eom_intervals)] - in_eom_mode = False - eom_block_n = -1 # Last eom interval is extended if eom mode not disabled at the end - if nb_eom_intervals > 0 and seq.get_duration() == eom_intervals[-1].tf: + if ( + len(eom_intervals) > 0 + and ch_samples.duration == eom_intervals[-1].tf + ): eom_intervals[-1].tf = total_duration # sampling the channel schedule - samples = sch.get_samples() - extended_samples = samples.extend_duration(total_duration) - for slot in sch: - if slot.ti == -1: - target["initial"] = slot.targets - continue - else: - # If slot is not the first element in schedule - if sch.in_eom_mode(slot): - # EOM mode starts - if not in_eom_mode: - in_eom_mode = True - eom_block_n += 1 - elif in_eom_mode: - # Buffer when EOM mode is disabled and next slot has 0 amp - in_eom_mode = False - if extended_samples.amp[slot.ti] == 0: - eom_end_buffers[eom_block_n] = EOMSegment( - slot.ti, slot.tf - ) - if ( - eom_block_n + 1 < nb_eom_intervals - and slot.tf == eom_intervals[eom_block_n + 1].ti - and extended_samples.det[slot.tf - 1] - == sch.eom_blocks[eom_block_n + 1].detuning_off - ): - # Buffer if next is eom and final det matches det_off - eom_start_buffers[eom_block_n + 1] = EOMSegment( - slot.ti, slot.tf - ) + extended_samples = ch_samples.extend_duration(total_duration) - if slot.type == "target": - target[(slot.ti, slot.tf - 1)] = slot.targets - continue - if slot.type == "delay": + eom_start_buffers = [ + EOMSegment(eom_interval[0], eom_interval[1]) + for eom_interval in ch_samples.eom_start_buffers + ] + eom_end_buffers = [ + EOMSegment(eom_interval[0], eom_interval[1]) + for eom_interval in ch_samples.eom_end_buffers + ] + + for time_slot in ch_samples.target_time_slots: + if time_slot.ti == -1: + target["initial"] = time_slot.targets continue - pulse = cast(Pulse, slot.type) - for wf_type in ["amplitude", "detuning"]: - wf = getattr(pulse, wf_type) - if isinstance(wf, InterpolatedWaveform): - pts = wf.data_points - pts[:, 0] += slot.ti - interp_pts[wf_type] += pts.tolist() + target[(time_slot.ti, time_slot.tf - 1)] = time_slot.targets # Store everything data[ch] = ChannelDrawContent( @@ -246,43 +218,39 @@ def gather_data(seq: pulser.sequence.Sequence, gather_output: bool) -> dict: eom_start_buffers, eom_end_buffers, ) - if interp_pts: - data[ch].interp_pts = dict(interp_pts) - if hasattr(seq, "_measurement"): - data["measurement"] = seq._measurement + + if sampled_seq._measurement is not None: + data["measurement"] = sampled_seq._measurement data["total_duration"] = total_duration return data -def draw_sequence( - seq: pulser.sequence.Sequence, +def _draw_channel_content( + sampled_seq: SequenceSamples, + register: Optional[BaseRegister] = None, sampling_rate: Optional[float] = None, draw_phase_area: bool = False, - draw_interp_pts: bool = True, draw_phase_shifts: bool = False, - draw_register: bool = False, draw_input: bool = True, draw_modulation: bool = False, draw_phase_curve: bool = False, -) -> tuple[Figure | None, Figure]: - """Draws the entire sequence. + shown_duration: Optional[int] = None, +) -> tuple[Figure | None, Figure, Any, dict]: + """Draws samples of a sequence. Args: - seq: The input sequence of operations on a device. + sampled_seq: The input samples of a sequence of operations. + register: If present, draw the register before the pulse + sequence, with a visual indication (square halo) around the qubits + masked by the SLM. sampling_rate: Sampling rate of the effective pulse used by the solver. If present, plots the effective pulse alongside the input pulse. draw_phase_area: Whether phase and area values need to be shown as text on the plot, defaults to False. If `draw_phase_curve=True`, phase values are ommited. - draw_interp_pts: When the sequence has pulses with waveforms of - type InterpolatedWaveform, draws the points of interpolation on - top of the respective waveforms (defaults to True). draw_phase_shifts: Whether phase shift and reference information should be added to the plot, defaults to False. - draw_register: Whether to draw the register before the pulse - sequence, with a visual indication (square halo) around the qubits - masked by the SLM, defaults to False. draw_input: Draws the programmed pulses on the channels, defaults to True. draw_modulation: Draws the expected channel output, defaults to @@ -290,6 +258,7 @@ def draw_sequence( is skipped unless 'draw_input=False'. draw_phase_curve: Draws the changes in phase in its own curve (ignored if the phase doesn't change throughout the channel). + shown_duration: Total duration to be shown in the X axis. """ def phase_str(phi: float) -> str: @@ -302,13 +271,14 @@ def phase_str(phi: float) -> str: else: return rf"{value:.2g}$\pi$" - n_channels = len(seq.declared_channels) + n_channels = len(sampled_seq.channels) if not n_channels: raise RuntimeError("Can't draw an empty sequence.") - data = gather_data(seq, gather_output=draw_modulation) + + data = gather_data(sampled_seq, shown_duration) total_duration = data["total_duration"] time_scale = 1e3 if total_duration > 1e4 else 1 - for ch in seq._schedule: + for ch in sampled_seq.channels: if np.count_nonzero(data[ch].samples.det) > 0: data[ch].curves_on["detuning"] = True if draw_phase_curve and np.count_nonzero(data[ch].samples.phase) > 0: @@ -322,11 +292,11 @@ def phase_str(phi: float) -> str: eom_box = dict(boxstyle="round", facecolor="lightsteelblue") # Draw masked register - if draw_register: - pos = np.array(seq.register._coords) - if isinstance(seq.register, Register3D): + if register: + pos = np.array(register._coords) + if isinstance(register, Register3D): labels = "xyz" - fig_reg, axes_reg = seq.register._initialize_fig_axes_projection( + fig_reg, axes_reg = register._initialize_fig_axes_projection( pos, blockade_radius=35, draw_half_radius=True, @@ -336,12 +306,12 @@ def phase_str(phi: float) -> str: for ax_reg, (ix, iy) in zip( axes_reg, combinations(np.arange(3), 2) ): - seq.register._draw_2D( + register._draw_2D( ax=ax_reg, pos=pos, - ids=seq.register._ids, + ids=register._ids, plane=(ix, iy), - masked_qubits=seq._slm_mask_targets, + masked_qubits=sampled_seq._slm_mask.targets, ) ax_reg.set_title( "Masked register projected onto\n the " @@ -350,22 +320,22 @@ def phase_str(phi: float) -> str: + "-plane" ) - elif isinstance(seq.register, Register): - fig_reg, ax_reg = seq.register._initialize_fig_axes( + elif isinstance(register, Register): + fig_reg, ax_reg = register._initialize_fig_axes( pos, blockade_radius=35, draw_half_radius=True, ) - seq.register._draw_2D( + register._draw_2D( ax=ax_reg, pos=pos, - ids=seq.register._ids, - masked_qubits=seq._slm_mask_targets, + ids=register._ids, + masked_qubits=sampled_seq._slm_mask.targets, ) ax_reg.set_title("Masked register", pad=10) ratios = [ - SIZE_PER_WIDTH[data[ch].n_axes_on] for ch in seq.declared_channels + SIZE_PER_WIDTH[data[ch].n_axes_on] for ch in sampled_seq.channels ] fig = plt.figure( constrained_layout=False, @@ -374,7 +344,7 @@ def phase_str(phi: float) -> str: gs = fig.add_gridspec(n_channels, 1, hspace=0.075, height_ratios=ratios) ch_axes = {} - for i, (ch, gs_) in enumerate(zip(seq.declared_channels, gs)): + for i, (ch, gs_) in enumerate(zip(sampled_seq.channels, gs)): ax = fig.add_subplot(gs_) for side in ("top", "bottom", "left", "right"): ax.spines[side].set_color("none") @@ -412,8 +382,8 @@ def phase_str(phi: float) -> str: t_max = final_t * 1.05 for ch, axes in ch_axes.items(): - ch_obj = seq.declared_channels[ch] ch_data = data[ch] + ch_obj = sampled_seq._ch_objs[ch] ch_eom_intervals = data[ch].eom_intervals ch_eom_start_buffers = data[ch].eom_start_buffers ch_eom_end_buffers = data[ch].eom_end_buffers @@ -490,49 +460,54 @@ def phase_str(phi: float) -> str: if draw_phase_area: top = False # Variable to track position of box, top or center. print_phase = not draw_phase_curve and any( - seq_.type.phase != 0 - for seq_ in seq._schedule[ch] - if isinstance(seq_.type, Pulse) + np.any(ch_data.samples.phase[slot.ti : slot.tf] != 0) + for slot in ch_data.samples.slots ) - for pulse_num, seq_ in enumerate(seq._schedule[ch]): - # Select only `Pulse` objects - if isinstance(seq_.type, Pulse): - if sampling_rate: - area_val = ( - np.sum(yseff[0][seq_.ti : seq_.tf]) * 1e-3 / np.pi - ) - else: - area_val = seq_.type.amplitude.integral / np.pi - phase_val = seq_.type.phase - x_plot = (seq_.ti + seq_.tf) / 2 / time_scale - if ( - seq._schedule[ch][pulse_num - 1].type == "target" - or not top - ): - y_plot = np.max(seq_.type.amplitude.samples) / 2 - top = True # Next box at the top. - elif top: - y_plot = np.max(seq_.type.amplitude.samples) - top = False # Next box at the center. - area_fmt = ( - r"A: $\pi$" - if round(area_val, 2) == 1 - else rf"A: {area_val:.2g}$\pi$" + + for slot in ch_data.samples.slots: + if sampling_rate: + area_val = ( + np.sum(yseff[0][slot.ti : slot.tf]) * 1e-3 / np.pi ) - if not print_phase: - txt = area_fmt - else: - phase_fmt = rf"$\phi$: {phase_str(phase_val)}" - txt = "\n".join([phase_fmt, area_fmt]) - axes[0].text( - x_plot, - y_plot, - txt, - fontsize=10, - ha="center", - va="center", - bbox=area_ph_box, + else: + area_val = ( + np.sum(ch_data.samples.amp[slot.ti : slot.tf]) + * 1e-3 + / np.pi ) + phase_val = ch_data.samples.phase[slot.tf - 1] + x_plot = (slot.ti + slot.tf) / 2 / time_scale + target_slot_tf_list = [ + target_slot.tf + for target_slot in sampled_seq.channel_samples[ + ch + ].target_time_slots + ] + if slot.ti in target_slot_tf_list or not top: + y_plot = np.max(ch_data.samples.amp[slot.ti : slot.tf]) / 2 + top = True # Next box at the top. + elif top: + y_plot = np.max(ch_data.samples.amp[slot.ti : slot.tf]) + top = False # Next box at the center. + area_fmt = ( + r"A: $\pi$" + if round(area_val, 2) == 1 + else rf"A: {area_val:.2g}$\pi$" + ) + if not print_phase: + txt = area_fmt + else: + phase_fmt = rf"$\phi$: {phase_str(phase_val)}" + txt = "\n".join([phase_fmt, area_fmt]) + axes[0].text( + x_plot, + y_plot, + txt, + fontsize=10, + ha="center", + va="center", + bbox=area_ph_box, + ) target_regions = [] # [[start1, [targets1], end1],...] for coords in ch_data.target: @@ -543,7 +518,7 @@ def phase_str(phi: float) -> str: if coords == "initial": x = t_min + final_t * 0.005 target_regions.append([0, targets]) - if seq.declared_channels[ch].addressing == "Global": + if ch_obj.addressing == "Global": axes[0].text( x, amp_top * 0.98, @@ -563,7 +538,7 @@ def phase_str(phi: float) -> str: ha="left", bbox=q_box, ) - phase = seq._basis_ref[basis][targets[0]].phase[0] + phase = sampled_seq._basis_ref[basis][targets[0]].phase[0] if phase and draw_phase_shifts: msg = r"$\phi=$" + phase_str(phase) axes[0].text( @@ -580,7 +555,7 @@ def phase_str(phi: float) -> str: target_regions.append( [tf + 1 / time_scale, targets] ) # New one - phase = seq._basis_ref[basis][targets[0]].phase[ + phase = sampled_seq._basis_ref[basis][targets[0]].phase[ tf * time_scale + 1 ] for ax in axes: @@ -605,6 +580,7 @@ def phase_str(phi: float) -> str: fontsize=12, bbox=ph_box, ) + # Terminate the last open regions if target_regions: target_regions[-1].append(final_t) @@ -616,7 +592,7 @@ def phase_str(phi: float) -> str: end = cast(float, end) # All targets have the same ref, so we pick q = targets_[0] - ref = seq._basis_ref[basis][q].phase + ref = sampled_seq._basis_ref[basis][q].phase if end != total_duration - 1 or "measurement" in data: end += 1 / time_scale for t_, delta in ref.changes(start, end, time_scale=time_scale): @@ -653,11 +629,11 @@ def phase_str(phi: float) -> str: bbox=eom_box, ) # Draw the SLM mask - if seq._slm_mask_targets and seq._slm_mask_time: - tf_m = seq._slm_mask_time[1] + if sampled_seq._slm_mask.targets and sampled_seq._slm_mask.end: + tf_m = sampled_seq._slm_mask.end for ax in axes: ax.axvspan(0, tf_m, color="black", alpha=0.1, zorder=-100) - tgt_strs = [str(q) for q in seq._slm_mask_targets] + tgt_strs = [str(q) for q in sampled_seq._slm_mask.targets] tgt_txt_x = final_t * 0.005 tgt_txt_y = axes[-1].get_ylim()[0] tgt_str = "\n".join(tgt_strs) @@ -707,6 +683,130 @@ def phase_str(phi: float) -> str: if ax_lims[i][0] < 0: ax.axhline(0, **hline_kwargs) + return (fig_reg if register else None, fig, ch_axes, data) + + +def draw_samples( + sampled_seq: SequenceSamples, + register: Optional[BaseRegister] = None, + sampling_rate: Optional[float] = None, + draw_phase_area: bool = False, + draw_phase_shifts: bool = False, + draw_phase_curve: bool = False, +) -> tuple[Figure | None, Figure]: + """Draws a SequenceSamples. + + Args: + sampled_seq: The input samples of a sequence of operations. + register: If present, draw the register before the pulse + sequence samples, with a visual indication (square halo) + around the qubits masked by the SLM. + sampling_rate: Sampling rate of the effective pulse used by + the solver. If present, plots the effective pulse alongside the + input pulse. + draw_phase_area: Whether phase and area values need to be shown + as text on the plot, defaults to False. If `draw_phase_curve=True`, + phase values are ommited. + draw_phase_shifts: Whether phase shift and reference information + should be added to the plot, defaults to False. + draw_phase_curve: Draws the changes in phase in its own curve (ignored + if the phase doesn't change throughout the channel). + """ + slot_tfs = [ + ch_samples.slots[-1].tf + for ch_samples in sampled_seq.channel_samples.values() + ] + max_slot_tf = max(slot_tfs) if len(slot_tfs) > 0 else None + (fig_reg, fig, ch_axes, data) = _draw_channel_content( + sampled_seq, + register, + sampling_rate, + draw_phase_area, + draw_phase_shifts, + draw_input=True, + draw_modulation=False, + draw_phase_curve=draw_phase_curve, + shown_duration=max_slot_tf, + ) + + return (fig_reg, fig) + + +def draw_sequence( + seq: pulser.sequence.Sequence, + sampling_rate: Optional[float] = None, + draw_phase_area: bool = False, + draw_interp_pts: bool = True, + draw_phase_shifts: bool = False, + draw_register: bool = False, + draw_input: bool = True, + draw_modulation: bool = False, + draw_phase_curve: bool = False, +) -> tuple[Figure | None, Figure]: + """Draws the entire sequence. + + Args: + seq: The input sequence of operations on a device. + sampling_rate: Sampling rate of the effective pulse used by + the solver. If present, plots the effective pulse alongside the + input pulse. + draw_phase_area: Whether phase and area values need to be shown + as text on the plot, defaults to False. If `draw_phase_curve=True`, + phase values are ommited. + draw_interp_pts: When the sequence has pulses with waveforms of + type InterpolatedWaveform, draws the points of interpolation on + top of the respective waveforms (defaults to True). + draw_phase_shifts: Whether phase shift and reference information + should be added to the plot, defaults to False. + draw_register: Whether to draw the register before the pulse + sequence, with a visual indication (square halo) around the qubits + masked by the SLM, defaults to False. + draw_input: Draws the programmed pulses on the channels, defaults + to True. + draw_modulation: Draws the expected channel output, defaults to + False. If the channel does not have a defined 'mod_bandwidth', this + is skipped unless 'draw_input=False'. + draw_phase_curve: Draws the changes in phase in its own curve (ignored + if the phase doesn't change throughout the channel). + """ + # Sample the sequence and get the data to plot + shown_duration = seq.get_duration(include_fall_time=draw_modulation) + sampled_seq = sample(seq) + + (fig_reg, fig, ch_axes, data) = _draw_channel_content( + sampled_seq, + seq.register if draw_register else None, + sampling_rate, + draw_phase_area, + draw_phase_shifts, + draw_input, + draw_modulation, + draw_phase_curve, + shown_duration, + ) + + # Gather additional data for sequence specific drawing + for ch, sch in seq._schedule.items(): + interp_pts: defaultdict[str, list[list[float]]] = defaultdict(list) + + for slot in sch: + if slot.ti == -1 or slot.type in ["target", "delay"]: + continue + + pulse = cast(Pulse, slot.type) + for wf_type in ["amplitude", "detuning"]: + wf = getattr(pulse, wf_type) + if isinstance(wf, InterpolatedWaveform): + pts = wf.data_points + pts[:, 0] += slot.ti + interp_pts[wf_type] += pts.tolist() + + if interp_pts: + data[ch].interp_pts = dict(interp_pts) + + for ch, axes in ch_axes.items(): + ch_data = data[ch] + if draw_interp_pts: for qty in ("amplitude", "detuning"): if qty in ch_data.interp_pts and ch_data.curves_on[qty]: @@ -714,4 +814,4 @@ def phase_str(phi: float) -> str: pts = np.array(ch_data.interp_pts[qty]) axes[ind].scatter(pts[:, 0], pts[:, 1], color=COLORS[ind]) - return (fig_reg if draw_register else None, fig) + return (fig_reg, fig) diff --git a/pulser-core/requirements.txt b/pulser-core/requirements.txt index 2621a2ccb..b932eddef 100644 --- a/pulser-core/requirements.txt +++ b/pulser-core/requirements.txt @@ -1,5 +1,5 @@ jsonschema matplotlib # Numpy 1.20 introduces type hints, 1.24.0 breaks matplotlib < 3.6.1 -numpy >= 1.20, != 1.24.0 +numpy >= 1.20, != 1.24.0, <1.25 scipy diff --git a/pulser-simulation/pulser_simulation/simulation.py b/pulser-simulation/pulser_simulation/simulation.py index bd3a1a67c..40916ebba 100644 --- a/pulser-simulation/pulser_simulation/simulation.py +++ b/pulser-simulation/pulser_simulation/simulation.py @@ -32,7 +32,7 @@ from pulser.devices._device_datacls import BaseDevice from pulser.register.base_register import BaseRegister, QubitId from pulser.result import SampledResult -from pulser.sampler.samples import SequenceSamples, _TargetSlot +from pulser.sampler.samples import SequenceSamples, _PulseTargetSlot from pulser.sequence._seq_drawer import draw_sequence from pulser_simulation.qutip_result import QutipResult from pulser_simulation.simconfig import SimConfig @@ -500,7 +500,7 @@ def _extract_samples(self) -> None: samples = self.samples_obj.to_nested_dict(all_local=local_noises) def add_noise( - slot: _TargetSlot, + slot: _PulseTargetSlot, samples_dict: Mapping[QubitId, dict[str, np.ndarray]], is_global_pulse: bool, ) -> None: diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 837d3d44b..c29c7f372 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -556,7 +556,7 @@ def test_switch_device_up( @pytest.mark.parametrize("mappable_reg", [False, True]) @pytest.mark.parametrize("parametrized", [False, True]) -def test_switch_device_eom(reg, mappable_reg, parametrized): +def test_switch_device_eom(reg, mappable_reg, parametrized, patch_plt_show): # Sequence with EOM blocks seq = init_seq( reg, @@ -602,6 +602,9 @@ def test_switch_device_eom(reg, mappable_reg, parametrized): assert og_eom_block.rabi_freq == mod_eom_block.rabi_freq assert og_eom_block.detuning_off != mod_eom_block.detuning_off + # Test drawing in eom mode + seq.draw() + def test_target(reg, device): seq = Sequence(reg, device) diff --git a/tests/test_sequence_sampler.py b/tests/test_sequence_sampler.py index fb253fd0c..9ae7e58a8 100644 --- a/tests/test_sequence_sampler.py +++ b/tests/test_sequence_sampler.py @@ -23,6 +23,7 @@ from pulser.devices import Device, MockDevice from pulser.pulse import Pulse from pulser.sampler import sample +from pulser.sequence._seq_drawer import draw_samples from pulser.waveforms import BlackmanWaveform, RampWaveform # Helpers @@ -239,6 +240,9 @@ def test_eom_modulation(mod_device, disable_eom): input_samples = sample( seq, extended_duration=full_duration ).channel_samples["ch0"] + assert input_samples.in_eom_mode(input_samples.slots[-1]) == ( + not disable_eom + ) mod_samples = sample(seq, modulation=True, extended_duration=full_duration) chan = seq.declared_channels["ch0"] for qty in ("amp", "det"): @@ -399,6 +403,22 @@ def test_phase_sampling(mod_device): np.testing.assert_array_equal(expected_phase, got_phase) +@pytest.mark.parametrize("modulation", [True, False]) +@pytest.mark.parametrize("draw_phase_area", [True, False]) +@pytest.mark.parametrize("draw_phase_shifts", [True, False]) +@pytest.mark.parametrize("draw_phase_curve", [True, False]) +def test_draw_samples( + mod_seq, modulation, draw_phase_area, draw_phase_curve, draw_phase_shifts +): + sampled_seq = sample(mod_seq, modulation=modulation) + draw_samples( + sampled_seq, + draw_phase_area=draw_phase_area, + draw_phase_shifts=draw_phase_shifts, + draw_phase_curve=draw_phase_curve, + ) + + # Fixtures