Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions glue/cirq/stimcirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ._cx_swap_gate import CXSwapGate
from ._cz_swap_gate import CZSwapGate
from ._det_annotation import DetAnnotation
from ._feedback_pauli import FeedbackPauli
from ._obs_annotation import CumulativeObservableAnnotation
from ._shift_coords_annotation import ShiftCoordsAnnotation
from ._stim_sampler import StimSampler
Expand All @@ -19,6 +20,7 @@
JSON_RESOLVERS_DICT = {
"CumulativeObservableAnnotation": CumulativeObservableAnnotation,
"DetAnnotation": DetAnnotation,
"FeedbackPauli": FeedbackPauli,
"MeasureAndOrResetGate": MeasureAndOrResetGate,
"ShiftCoordsAnnotation": ShiftCoordsAnnotation,
"SweepPauli": SweepPauli,
Expand Down
66 changes: 62 additions & 4 deletions glue/cirq/stimcirq/_cirq_to_stim.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import cirq
import stim

from ._i_error_gate import IErrorGate
from ._ii_error_gate import IIErrorGate
from ._ii_gate import IIGate


Expand Down Expand Up @@ -142,6 +140,52 @@ def cirq_circuit_to_stim_data(


StimTypeHandler = Callable[[stim.Circuit, cirq.Gate, List[int], str], None]
StimOpTypeHandler = Callable[[stim.Circuit, cirq.Operation, List[int], str, List[Tuple[str, int]]], None]


def _stim_append_classically_controlled_gate(
circuit: stim.Circuit,
op: cirq.ClassicallyControlledOperation,
targets: List[int],
tag: str,
measurement_key_lengths: List[Tuple[str, int]]):

if len(op.classical_controls) != 1:
raise NotImplementedError(f'Stim only supports single-control Pauli feedback, but got {op=}')
control, = op.classical_controls
if not isinstance(control, cirq.KeyCondition):
raise NotImplementedError(f'Stim only supports single-control Pauli feedback (i.e. a `cirq.KeyCondition` control), but got {control=}')
control: cirq.KeyCondition
gate = op.without_classical_controls().gate

if gate == cirq.X:
stim_gate = 'X'
elif gate == cirq.Y:
stim_gate = 'Y'
elif gate == cirq.Z:
stim_gate = 'Z'
else:
raise NotImplementedError(f'Stim only supports Pauli feedback, but got {op=}')
assert len(targets) == 1

skips_left = control.index
for offset in range(len(measurement_key_lengths)):
m_key, m_len = measurement_key_lengths[-1 - offset]
if m_len != 1:
raise NotImplementedError(f"multi-qubit measurement {m_key!r}")
if m_key == control.key:
if skips_left > 0:
skips_left -= 1
else:
rec_target = stim.target_rec(-1 - offset)
break
else:
raise ValueError(
f"{control!r} was processed before the measurement it referenced."
f" Make sure the referenced measurements keys are actually in the circuit, and come"
f" in an earlier moment (or earlier in the same moment's operation order)."
)
circuit.append(f"C{stim_gate}", [rec_target, targets[0]], tag=tag)


@functools.lru_cache(maxsize=1)
Expand Down Expand Up @@ -278,6 +322,14 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]:
}


@functools.lru_cache()
def op_type_to_stim_append_func() -> Dict[Type[cirq.Operation], StimOpTypeHandler]:
"""A dictionary mapping specific gate types to stim circuit appending functions."""
return {
cirq.ClassicallyControlledOperation: _stim_append_classically_controlled_gate,
}


def _stim_append_measurement_gate(
circuit: stim.Circuit, gate: cirq.MeasurementGate, targets: List[int], tag: str
):
Expand Down Expand Up @@ -454,7 +506,8 @@ def process_circuit_operation_into_repeat_block(self, op: cirq.CircuitOperation,

def process_operations(self, operations: Iterable[cirq.Operation]) -> None:
g2f = gate_to_stim_append_func()
t2f = gate_type_to_stim_append_func()
tg2f = gate_type_to_stim_append_func()
to2f = op_type_to_stim_append_func()
for op in operations:
assert isinstance(op, cirq.Operation)
tag = self.tag_func(op)
Expand Down Expand Up @@ -500,11 +553,16 @@ def process_operations(self, operations: Iterable[cirq.Operation]) -> None:
continue

# Look for recognized gate types like cirq.DepolarizingChannel.
type_append_func = t2f.get(type(gate))
type_append_func = tg2f.get(type(gate))
if type_append_func is not None:
type_append_func(self.out, gate, targets, tag=tag)
continue

op_type_append_func = to2f.get(type(op))
if op_type_append_func is not None:
op_type_append_func(self.out, op, targets, tag, self.key_out)
continue

# Ask unrecognized operations to decompose themselves into simpler operations.
try:
self.process_operations(cirq.decompose_once(op))
Expand Down
67 changes: 67 additions & 0 deletions glue/cirq/stimcirq/_feedback_pauli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import Any, Dict, List, Tuple, Optional

import cirq
import stim


@cirq.value_equality
class FeedbackPauli(cirq.Gate):
"""A Pauli gate conditioned on a prior measurement."""

def __init__(
self,
*,
relative_measurement_index: Optional[int] = None,
pauli: cirq.Pauli,
):
r"""

Args:
relative_measurement_index: A negative integer identifying how many measurements ago is the measurement that
controls the Pauli operation.
pauli: The cirq Pauli operation to apply when the bit is True.
"""
if relative_measurement_index is not None and (relative_measurement_index >= 0 or not isinstance(relative_measurement_index, int)):
raise ValueError(f"{relative_measurement_index=} isn't a negative int (note {type(relative_measurement_index)=})")
self.relative_measurement_index = relative_measurement_index
self.pauli = pauli

def _is_parameterized_(self) -> bool:
return False

def _num_qubits_(self) -> int:
return 1

def _value_equality_values_(self) -> Any:
return self.pauli, self.relative_measurement_index

def _circuit_diagram_info_(self, args: Any) -> str:
return f"{self.pauli}^rec[{self.relative_measurement_index}]"

@staticmethod
def _json_namespace_() -> str:
return ''

def _json_dict_(self) -> Dict[str, Any]:
return {
'pauli': self.pauli,
'relative_measurement_index': self.relative_measurement_index,
}

def __repr__(self) -> str:
return (
f'stimcirq.FeedbackPauli('
f'relative_measurement_index={self.relative_measurement_index!r}, '
f'pauli={self.pauli!r})'
)

def _stim_conversion_(
self,
*,
edit_circuit: stim.Circuit,
tag: str,
targets: List[int],
**kwargs,
):
rec_target = stim.target_rec(self.relative_measurement_index)
edit_circuit.append(f"C{self.pauli}", [rec_target, targets[0]], tag=tag)
209 changes: 209 additions & 0 deletions glue/cirq/stimcirq/_feedback_pauli_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import cirq
import pytest
import stim
import stimcirq


def test_cirq_to_stim_to_cirq_classical_control():
q = cirq.LineQubit(0)
cirq_circuit = cirq.Circuit(
cirq.measure(q, key="test"),
cirq.X(q).with_classical_controls("test").with_tags("test2")
)
stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit)
assert stim_circuit == stim.Circuit("""
M 0
TICK
CX[test2] rec[-1] 0
TICK
""")
assert stimcirq.stim_circuit_to_cirq_circuit(stim_circuit) == cirq.Circuit(
cirq.measure(q, key="0"),
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(q).with_tags("test2")
)


def test_cirq_to_stim_to_cirq_feedback_pauli():
q = cirq.LineQubit(0)
cirq_circuit = cirq.Circuit(
cirq.measure(q, key="test"),
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(q).with_tags('test3')
)
stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit)
assert stim_circuit == stim.Circuit("""
M 0
TICK
CX[test3] rec[-1] 0
TICK
""")
assert stimcirq.stim_circuit_to_cirq_circuit(stim_circuit) == cirq.Circuit(
cirq.measure(q, key="0"),
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(q).with_tags('test3')
)


def test_stim_to_cirq_conversion():
with pytest.raises(NotImplementedError, match="wrong target"):
stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit("""
M 0
TICK
XCZ rec[-1] 3
"""))
with pytest.raises(NotImplementedError, match="wrong target"):
stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit("""
M 0
TICK
YCZ rec[-1] 3
"""))
with pytest.raises(NotImplementedError, match="wrong target"):
stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit("""
M 0
TICK
CY 3 rec[-1]
"""))
with pytest.raises(NotImplementedError, match="wrong target"):
stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit("""
M 0
TICK
CX 3 rec[-1]
"""))
with pytest.raises(NotImplementedError, match="Two classical"):
stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit("""
M 0 1
TICK
CZ rec[-1] rec[-2]
"""))

assert stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit("""
M 0
TICK
ZCX rec[-1] 0
ZCY rec[-1] 1
ZCZ rec[-1] 2
XCZ 3 rec[-1]
YCZ 4 rec[-1]
ZCZ 5 rec[-1]
""")) == cirq.Circuit(
cirq.Moment(
cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='0')),
),
cirq.Moment(
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(cirq.LineQubit(0)),
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y).on(cirq.LineQubit(1)),
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Z).on(cirq.LineQubit(2)),
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(cirq.LineQubit(3)),
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y).on(cirq.LineQubit(4)),
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Z).on(cirq.LineQubit(5)),
),
)


def test_stim_conversion():
a, b, c = cirq.LineQubit.range(3)

with pytest.raises(ValueError, match="earlier"):
stimcirq.cirq_circuit_to_stim_circuit(
cirq.Circuit(cirq.Moment(cirq.X(a).with_classical_controls("unknown")))
)
with pytest.raises(ValueError, match="earlier"):
stimcirq.cirq_circuit_to_stim_circuit(
cirq.Circuit(
cirq.Moment(
cirq.X(a).with_classical_controls("unknown"), cirq.measure(b, key="later")
)
)
)
with pytest.raises(ValueError, match="earlier"):
stimcirq.cirq_circuit_to_stim_circuit(
cirq.Circuit(
cirq.Moment(cirq.X(a).with_classical_controls("unknown")),
cirq.Moment(cirq.measure(b, key="later")),
)
)
assert stimcirq.cirq_circuit_to_stim_circuit(
cirq.Circuit(
cirq.Moment(cirq.measure(b, key="earlier")),
cirq.Moment(cirq.X(b).with_classical_controls("earlier")),
)
) == stim.Circuit(
"""
QUBIT_COORDS(1) 0
M 0
TICK
CX rec[-1] 0
TICK
"""
)

assert stimcirq.cirq_circuit_to_stim_circuit(
cirq.Circuit(
cirq.Moment(cirq.measure(a, key="a"), cirq.measure(b, key="b")),
cirq.Moment(
cirq.X(b).with_classical_controls("a"),
),
cirq.Moment(
cirq.Z(b).with_classical_controls("b"),
),
)
) == stim.Circuit(
"""
M 0 1
TICK
CX rec[-2] 1
TICK
CZ rec[-1] 1
TICK
"""
)


def test_diagram():
a, b = cirq.LineQubit.range(2)
cirq.testing.assert_has_diagram(
cirq.Circuit(
cirq.measure(a, key="a"),
cirq.measure(b, key="b"),
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli='Y').on(a),
),
"""
0: ---M('a')---Y^rec[-1]---

1: ---M('b')---------------
""",
use_unicode_characters=False,
)


def test_repr():
val = stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y)
assert eval(repr(val), {"cirq": cirq, "stimcirq": stimcirq}) == val


def test_equality():
eq = cirq.testing.EqualsTester()
eq.add_equality_group(
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X),
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X))
eq.add_equality_group(stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y))
eq.add_equality_group(
stimcirq.FeedbackPauli(relative_measurement_index=-4, pauli=cirq.X),
)
eq.add_equality_group(stimcirq.FeedbackPauli(relative_measurement_index=-10, pauli=cirq.Z))


def test_json_serialization():
c = cirq.Circuit(
stimcirq.FeedbackPauli(relative_measurement_index=-3, pauli=cirq.X).on(cirq.LineQubit(0)),
stimcirq.FeedbackPauli(relative_measurement_index=-5, pauli=cirq.Y).on(cirq.LineQubit(1)),
stimcirq.FeedbackPauli(relative_measurement_index=-7, pauli=cirq.Z).on(cirq.LineQubit(2)),
)
json = cirq.to_json(c)
c2 = cirq.read_json(json_text=json, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER])
assert c == c2


def test_json_backwards_compat_exact():
raw = stimcirq.FeedbackPauli(relative_measurement_index=-3, pauli=cirq.X)
packed = '{\n "cirq_type": "FeedbackPauli",\n "pauli": {\n "cirq_type": "_PauliX",\n "exponent": 1.0,\n "global_shift": 0.0\n },\n "relative_measurement_index": -3\n}'
assert cirq.to_json(raw) == packed
assert cirq.read_json(json_text=packed, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
Loading
Loading