diff --git a/glue/cirq/stimcirq/__init__.py b/glue/cirq/stimcirq/__init__.py index d3a8e1aa..2af970b7 100644 --- a/glue/cirq/stimcirq/__init__.py +++ b/glue/cirq/stimcirq/__init__.py @@ -3,6 +3,7 @@ from ._cx_swap_gate import CXSwapGate from ._cz_swap_gate import CZSwapGate from ._det_annotation import DetAnnotation +from ._feedback_pauli import FeedbackPauli from ._obs_annotation import CumulativeObservableAnnotation from ._shift_coords_annotation import ShiftCoordsAnnotation from ._stim_sampler import StimSampler @@ -19,6 +20,7 @@ JSON_RESOLVERS_DICT = { "CumulativeObservableAnnotation": CumulativeObservableAnnotation, "DetAnnotation": DetAnnotation, + "FeedbackPauli": FeedbackPauli, "MeasureAndOrResetGate": MeasureAndOrResetGate, "ShiftCoordsAnnotation": ShiftCoordsAnnotation, "SweepPauli": SweepPauli, diff --git a/glue/cirq/stimcirq/_cirq_to_stim.py b/glue/cirq/stimcirq/_cirq_to_stim.py index 66a4a303..30b0c31f 100644 --- a/glue/cirq/stimcirq/_cirq_to_stim.py +++ b/glue/cirq/stimcirq/_cirq_to_stim.py @@ -6,8 +6,6 @@ import cirq import stim -from ._i_error_gate import IErrorGate -from ._ii_error_gate import IIErrorGate from ._ii_gate import IIGate @@ -142,6 +140,52 @@ def cirq_circuit_to_stim_data( StimTypeHandler = Callable[[stim.Circuit, cirq.Gate, List[int], str], None] +StimOpTypeHandler = Callable[[stim.Circuit, cirq.Operation, List[int], str, List[Tuple[str, int]]], None] + + +def _stim_append_classically_controlled_gate( + circuit: stim.Circuit, + op: cirq.ClassicallyControlledOperation, + targets: List[int], + tag: str, + measurement_key_lengths: List[Tuple[str, int]]): + + if len(op.classical_controls) != 1: + raise NotImplementedError(f'Stim only supports single-control Pauli feedback, but got {op=}') + control, = op.classical_controls + if not isinstance(control, cirq.KeyCondition): + raise NotImplementedError(f'Stim only supports single-control Pauli feedback (i.e. a `cirq.KeyCondition` control), but got {control=}') + control: cirq.KeyCondition + gate = op.without_classical_controls().gate + + if gate == cirq.X: + stim_gate = 'X' + elif gate == cirq.Y: + stim_gate = 'Y' + elif gate == cirq.Z: + stim_gate = 'Z' + else: + raise NotImplementedError(f'Stim only supports Pauli feedback, but got {op=}') + assert len(targets) == 1 + + skips_left = control.index + for offset in range(len(measurement_key_lengths)): + m_key, m_len = measurement_key_lengths[-1 - offset] + if m_len != 1: + raise NotImplementedError(f"multi-qubit measurement {m_key!r}") + if m_key == control.key: + if skips_left > 0: + skips_left -= 1 + else: + rec_target = stim.target_rec(-1 - offset) + break + else: + raise ValueError( + f"{control!r} was processed before the measurement it referenced." + f" Make sure the referenced measurements keys are actually in the circuit, and come" + f" in an earlier moment (or earlier in the same moment's operation order)." + ) + circuit.append(f"C{stim_gate}", [rec_target, targets[0]], tag=tag) @functools.lru_cache(maxsize=1) @@ -278,6 +322,14 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]: } +@functools.lru_cache() +def op_type_to_stim_append_func() -> Dict[Type[cirq.Operation], StimOpTypeHandler]: + """A dictionary mapping specific gate types to stim circuit appending functions.""" + return { + cirq.ClassicallyControlledOperation: _stim_append_classically_controlled_gate, + } + + def _stim_append_measurement_gate( circuit: stim.Circuit, gate: cirq.MeasurementGate, targets: List[int], tag: str ): @@ -454,7 +506,8 @@ def process_circuit_operation_into_repeat_block(self, op: cirq.CircuitOperation, def process_operations(self, operations: Iterable[cirq.Operation]) -> None: g2f = gate_to_stim_append_func() - t2f = gate_type_to_stim_append_func() + tg2f = gate_type_to_stim_append_func() + to2f = op_type_to_stim_append_func() for op in operations: assert isinstance(op, cirq.Operation) tag = self.tag_func(op) @@ -500,11 +553,16 @@ def process_operations(self, operations: Iterable[cirq.Operation]) -> None: continue # Look for recognized gate types like cirq.DepolarizingChannel. - type_append_func = t2f.get(type(gate)) + type_append_func = tg2f.get(type(gate)) if type_append_func is not None: type_append_func(self.out, gate, targets, tag=tag) continue + op_type_append_func = to2f.get(type(op)) + if op_type_append_func is not None: + op_type_append_func(self.out, op, targets, tag, self.key_out) + continue + # Ask unrecognized operations to decompose themselves into simpler operations. try: self.process_operations(cirq.decompose_once(op)) diff --git a/glue/cirq/stimcirq/_feedback_pauli.py b/glue/cirq/stimcirq/_feedback_pauli.py new file mode 100644 index 00000000..bd33ce43 --- /dev/null +++ b/glue/cirq/stimcirq/_feedback_pauli.py @@ -0,0 +1,67 @@ +from typing import Any, Dict, List, Tuple, Optional + +import cirq +import stim + + +@cirq.value_equality +class FeedbackPauli(cirq.Gate): + """A Pauli gate conditioned on a prior measurement.""" + + def __init__( + self, + *, + relative_measurement_index: Optional[int] = None, + pauli: cirq.Pauli, + ): + r""" + + Args: + relative_measurement_index: A negative integer identifying how many measurements ago is the measurement that + controls the Pauli operation. + pauli: The cirq Pauli operation to apply when the bit is True. + """ + if relative_measurement_index is not None and (relative_measurement_index >= 0 or not isinstance(relative_measurement_index, int)): + raise ValueError(f"{relative_measurement_index=} isn't a negative int (note {type(relative_measurement_index)=})") + self.relative_measurement_index = relative_measurement_index + self.pauli = pauli + + def _is_parameterized_(self) -> bool: + return False + + def _num_qubits_(self) -> int: + return 1 + + def _value_equality_values_(self) -> Any: + return self.pauli, self.relative_measurement_index + + def _circuit_diagram_info_(self, args: Any) -> str: + return f"{self.pauli}^rec[{self.relative_measurement_index}]" + + @staticmethod + def _json_namespace_() -> str: + return '' + + def _json_dict_(self) -> Dict[str, Any]: + return { + 'pauli': self.pauli, + 'relative_measurement_index': self.relative_measurement_index, + } + + def __repr__(self) -> str: + return ( + f'stimcirq.FeedbackPauli(' + f'relative_measurement_index={self.relative_measurement_index!r}, ' + f'pauli={self.pauli!r})' + ) + + def _stim_conversion_( + self, + *, + edit_circuit: stim.Circuit, + tag: str, + targets: List[int], + **kwargs, + ): + rec_target = stim.target_rec(self.relative_measurement_index) + edit_circuit.append(f"C{self.pauli}", [rec_target, targets[0]], tag=tag) diff --git a/glue/cirq/stimcirq/_feedback_pauli_test.py b/glue/cirq/stimcirq/_feedback_pauli_test.py new file mode 100644 index 00000000..2fc97f5a --- /dev/null +++ b/glue/cirq/stimcirq/_feedback_pauli_test.py @@ -0,0 +1,209 @@ +import cirq +import pytest +import stim +import stimcirq + + +def test_cirq_to_stim_to_cirq_classical_control(): + q = cirq.LineQubit(0) + cirq_circuit = cirq.Circuit( + cirq.measure(q, key="test"), + cirq.X(q).with_classical_controls("test").with_tags("test2") + ) + stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit) + assert stim_circuit == stim.Circuit(""" + M 0 + TICK + CX[test2] rec[-1] 0 + TICK + """) + assert stimcirq.stim_circuit_to_cirq_circuit(stim_circuit) == cirq.Circuit( + cirq.measure(q, key="0"), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(q).with_tags("test2") + ) + + +def test_cirq_to_stim_to_cirq_feedback_pauli(): + q = cirq.LineQubit(0) + cirq_circuit = cirq.Circuit( + cirq.measure(q, key="test"), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(q).with_tags('test3') + ) + stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit) + assert stim_circuit == stim.Circuit(""" + M 0 + TICK + CX[test3] rec[-1] 0 + TICK + """) + assert stimcirq.stim_circuit_to_cirq_circuit(stim_circuit) == cirq.Circuit( + cirq.measure(q, key="0"), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(q).with_tags('test3') + ) + + +def test_stim_to_cirq_conversion(): + with pytest.raises(NotImplementedError, match="wrong target"): + stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit(""" + M 0 + TICK + XCZ rec[-1] 3 + """)) + with pytest.raises(NotImplementedError, match="wrong target"): + stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit(""" + M 0 + TICK + YCZ rec[-1] 3 + """)) + with pytest.raises(NotImplementedError, match="wrong target"): + stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit(""" + M 0 + TICK + CY 3 rec[-1] + """)) + with pytest.raises(NotImplementedError, match="wrong target"): + stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit(""" + M 0 + TICK + CX 3 rec[-1] + """)) + with pytest.raises(NotImplementedError, match="Two classical"): + stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit(""" + M 0 1 + TICK + CZ rec[-1] rec[-2] + """)) + + assert stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit(""" + M 0 + TICK + ZCX rec[-1] 0 + ZCY rec[-1] 1 + ZCZ rec[-1] 2 + XCZ 3 rec[-1] + YCZ 4 rec[-1] + ZCZ 5 rec[-1] + """)) == cirq.Circuit( + cirq.Moment( + cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='0')), + ), + cirq.Moment( + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(cirq.LineQubit(0)), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y).on(cirq.LineQubit(1)), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Z).on(cirq.LineQubit(2)), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(cirq.LineQubit(3)), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y).on(cirq.LineQubit(4)), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Z).on(cirq.LineQubit(5)), + ), + ) + + +def test_stim_conversion(): + a, b, c = cirq.LineQubit.range(3) + + with pytest.raises(ValueError, match="earlier"): + stimcirq.cirq_circuit_to_stim_circuit( + cirq.Circuit(cirq.Moment(cirq.X(a).with_classical_controls("unknown"))) + ) + with pytest.raises(ValueError, match="earlier"): + stimcirq.cirq_circuit_to_stim_circuit( + cirq.Circuit( + cirq.Moment( + cirq.X(a).with_classical_controls("unknown"), cirq.measure(b, key="later") + ) + ) + ) + with pytest.raises(ValueError, match="earlier"): + stimcirq.cirq_circuit_to_stim_circuit( + cirq.Circuit( + cirq.Moment(cirq.X(a).with_classical_controls("unknown")), + cirq.Moment(cirq.measure(b, key="later")), + ) + ) + assert stimcirq.cirq_circuit_to_stim_circuit( + cirq.Circuit( + cirq.Moment(cirq.measure(b, key="earlier")), + cirq.Moment(cirq.X(b).with_classical_controls("earlier")), + ) + ) == stim.Circuit( + """ + QUBIT_COORDS(1) 0 + M 0 + TICK + CX rec[-1] 0 + TICK + """ + ) + + assert stimcirq.cirq_circuit_to_stim_circuit( + cirq.Circuit( + cirq.Moment(cirq.measure(a, key="a"), cirq.measure(b, key="b")), + cirq.Moment( + cirq.X(b).with_classical_controls("a"), + ), + cirq.Moment( + cirq.Z(b).with_classical_controls("b"), + ), + ) + ) == stim.Circuit( + """ + M 0 1 + TICK + CX rec[-2] 1 + TICK + CZ rec[-1] 1 + TICK + """ + ) + + +def test_diagram(): + a, b = cirq.LineQubit.range(2) + cirq.testing.assert_has_diagram( + cirq.Circuit( + cirq.measure(a, key="a"), + cirq.measure(b, key="b"), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli='Y').on(a), + ), + """ +0: ---M('a')---Y^rec[-1]--- + +1: ---M('b')--------------- + """, + use_unicode_characters=False, + ) + + +def test_repr(): + val = stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y) + assert eval(repr(val), {"cirq": cirq, "stimcirq": stimcirq}) == val + + +def test_equality(): + eq = cirq.testing.EqualsTester() + eq.add_equality_group( + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X)) + eq.add_equality_group(stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y)) + eq.add_equality_group( + stimcirq.FeedbackPauli(relative_measurement_index=-4, pauli=cirq.X), + ) + eq.add_equality_group(stimcirq.FeedbackPauli(relative_measurement_index=-10, pauli=cirq.Z)) + + +def test_json_serialization(): + c = cirq.Circuit( + stimcirq.FeedbackPauli(relative_measurement_index=-3, pauli=cirq.X).on(cirq.LineQubit(0)), + stimcirq.FeedbackPauli(relative_measurement_index=-5, pauli=cirq.Y).on(cirq.LineQubit(1)), + stimcirq.FeedbackPauli(relative_measurement_index=-7, pauli=cirq.Z).on(cirq.LineQubit(2)), + ) + json = cirq.to_json(c) + c2 = cirq.read_json(json_text=json, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) + assert c == c2 + + +def test_json_backwards_compat_exact(): + raw = stimcirq.FeedbackPauli(relative_measurement_index=-3, pauli=cirq.X) + packed = '{\n "cirq_type": "FeedbackPauli",\n "pauli": {\n "cirq_type": "_PauliX",\n "exponent": 1.0,\n "global_shift": 0.0\n },\n "relative_measurement_index": -3\n}' + assert cirq.to_json(raw) == packed + assert cirq.read_json(json_text=packed, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw diff --git a/glue/cirq/stimcirq/_stim_to_cirq.py b/glue/cirq/stimcirq/_stim_to_cirq.py index 543a8b2d..39239183 100644 --- a/glue/cirq/stimcirq/_stim_to_cirq.py +++ b/glue/cirq/stimcirq/_stim_to_cirq.py @@ -26,6 +26,7 @@ from ._obs_annotation import CumulativeObservableAnnotation from ._shift_coords_annotation import ShiftCoordsAnnotation from ._sweep_pauli import SweepPauli +from ._feedback_pauli import FeedbackPauli def _stim_targets_to_dense_pauli_string( @@ -414,9 +415,11 @@ def __call__( tracker.process_gate_instruction(gate=self.gate, instruction=instruction) class SweepableGateHandler: - def __init__(self, pauli_gate: cirq.Pauli, gate: cirq.Gate): + def __init__(self, pauli_gate: cirq.Pauli, gate: cirq.Gate, allow_first: bool, allow_second: bool): self.pauli_gate = pauli_gate self.gate = gate + self.allow_first = allow_first + self.allow_second = allow_second def __call__( self, tracker: 'CircuitTranslationTracker', instruction: stim.CircuitInstruction @@ -429,8 +432,12 @@ def __call__( for k in range(0, len(targets), 2): a = targets[k] b = targets[k + 1] + if not a.is_qubit_target and not self.allow_first: + raise NotImplementedError(f"Classical control is on the wrong target: instruction={instruction!r}") + if not b.is_qubit_target and not self.allow_second: + raise NotImplementedError(f"Classical control is on the wrong target: instruction={instruction!r}") if not a.is_qubit_target and not b.is_qubit_target: - raise NotImplementedError(f"instruction={instruction!r}") + raise NotImplementedError(f"Two classical controls: instruction={instruction!r}") if a.is_sweep_bit_target or b.is_sweep_bit_target: if b.is_sweep_bit_target: a, b = b, a @@ -442,6 +449,16 @@ def __call__( pauli=self.pauli_gate, ).on(cirq.LineQubit(b.value)).with_tags(*tags) ) + elif a.is_measurement_record_target or b.is_measurement_record_target: + if b.is_measurement_record_target: + a, b = b, a + assert not a.is_inverted_result_target + tracker.append_operation( + FeedbackPauli( + relative_measurement_index=a.value, + pauli=self.pauli_gate, + ).on(cirq.LineQubit(b.value)).with_tags(*tags) + ) else: if not a.is_qubit_target or not b.is_qubit_target: raise NotImplementedError(f"instruction={instruction!r}") @@ -592,17 +609,17 @@ def handler( "ISWAP_DAG": gate(cirq.ISWAP ** -1), "XCX": gate(cirq.PauliInteractionGate(cirq.X, False, cirq.X, False)), "XCY": gate(cirq.PauliInteractionGate(cirq.X, False, cirq.Y, False)), - "XCZ": sweep_gate(cirq.X, cirq.PauliInteractionGate(cirq.X, False, cirq.Z, False)), + "XCZ": sweep_gate(cirq.X, cirq.PauliInteractionGate(cirq.X, False, cirq.Z, False), False, True), "YCX": gate(cirq.PauliInteractionGate(cirq.Y, False, cirq.X, False)), "YCY": gate(cirq.PauliInteractionGate(cirq.Y, False, cirq.Y, False)), - "YCZ": sweep_gate(cirq.Y, cirq.PauliInteractionGate(cirq.Y, False, cirq.Z, False)), - "CX": sweep_gate(cirq.X, cirq.CNOT), - "CNOT": sweep_gate(cirq.X, cirq.CNOT), - "ZCX": sweep_gate(cirq.X, cirq.CNOT), - "CY": sweep_gate(cirq.Y, cirq.Y.controlled(1)), - "ZCY": sweep_gate(cirq.Y, cirq.Y.controlled(1)), - "CZ": sweep_gate(cirq.Z, cirq.CZ), - "ZCZ": sweep_gate(cirq.Z, cirq.CZ), + "YCZ": sweep_gate(cirq.Y, cirq.PauliInteractionGate(cirq.Y, False, cirq.Z, False), False, True), + "CX": sweep_gate(cirq.X, cirq.CNOT, True, False), + "CNOT": sweep_gate(cirq.X, cirq.CNOT, True, False), + "ZCX": sweep_gate(cirq.X, cirq.CNOT, True, False), + "CY": sweep_gate(cirq.Y, cirq.Y.controlled(1), True, False), + "ZCY": sweep_gate(cirq.Y, cirq.Y.controlled(1), True, False), + "CZ": sweep_gate(cirq.Z, cirq.CZ, True, True), + "ZCZ": sweep_gate(cirq.Z, cirq.CZ, True, True), "DEPOLARIZE1": noise(lambda p: cirq.DepolarizingChannel(p, 1)), "DEPOLARIZE2": noise(lambda p: cirq.DepolarizingChannel(p, 2)), "X_ERROR": noise(cirq.X.with_probability),