Skip to content

Commit

Permalink
Provide options in the cut-finder API to turn LO gate and wire cut fi…
Browse files Browse the repository at this point in the history
…nding off or on, expose min-reached flag. (#586)

* enable only wire cut finding

* edit tests

* explore adding new flags

* Handle multiple arguments when cutting both wires

* black, mypy, remove the erroneous example I added to the tutorial.

* doc string

* update doc string

* update tests

* reorganise tests

* add cut both wires test

* add release note

* edit release note

* un-expose LOCC cost functions everywhere

* add min reached flag.

* edit release note

* Change to upper case in doc string

* Italicize

Co-authored-by: Jim Garrison <[email protected]>

* Edit bool in release note

Co-authored-by: Jim Garrison <[email protected]>

* Update reference in release note

Co-authored-by: Jim Garrison <[email protected]>

* Edit reference in release note

Co-authored-by: Jim Garrison <[email protected]>

* pull changes

---------

Co-authored-by: Jim Garrison <[email protected]>
  • Loading branch information
ibrahim-shehzad and garrison authored May 30, 2024
1 parent df3d536 commit e29d561
Show file tree
Hide file tree
Showing 13 changed files with 589 additions and 346 deletions.
11 changes: 10 additions & 1 deletion circuit_knitting/cutting/automated_cut_finding.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def find_cuts(
``data`` field.
- sampling_overhead: The sampling overhead incurred from cutting the specified
gates and wires.
- minimum_reached: A bool indicating whether or not the search conclusively found
the minimum of cost function. ``minimum_reached = False`` could also mean that the
cost returned was actually the lowest possible cost but that the search was
not allowed to run long enough to prove that this was the case.
Raises:
ValueError: The input circuit contains a gate acting on more than 2 qubits.
Expand All @@ -63,6 +67,8 @@ def find_cuts(
seed=optimization.seed,
max_gamma=optimization.max_gamma,
max_backjumps=optimization.max_backjumps,
gate_lo=optimization.gate_lo,
wire_lo=optimization.wire_lo,
)

# Hard-code the optimizer to an LO-only optimizer
Expand Down Expand Up @@ -106,7 +112,7 @@ def find_cuts(
)
counter += 1

if action.action.get_name() == "CutBothWires": # pragma: no cover
if action.action.get_name() == "CutBothWires":
# There should be two wires specified in the action in this case
assert len(action.args) == 2
qubit_id2 = action.args[1][0] - 1
Expand All @@ -126,6 +132,7 @@ def find_cuts(
elif inst.operation.name == "cut_wire":
metadata["cuts"].append(("Wire Cut", i))
metadata["sampling_overhead"] = opt_out.upper_bound_gamma() ** 2
metadata["minimum_reached"] = optimizer.minimum_reached()

return circ_out, metadata

Expand All @@ -137,6 +144,8 @@ class OptimizationParameters:
seed: int | None = OptimizationSettings().seed
max_gamma: float = OptimizationSettings().max_gamma
max_backjumps: None | int = OptimizationSettings().max_backjumps
gate_lo: bool = OptimizationSettings().gate_lo
wire_lo: bool = OptimizationSettings().wire_lo


@dataclass
Expand Down
11 changes: 8 additions & 3 deletions circuit_knitting/cutting/cut_finding/cut_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,16 @@ def cut_optimization_cost_func(


def cut_optimization_upper_bound_cost_func(
goal_state, func_args: CutOptimizationFuncArgs
goal_state: DisjointSubcircuitsState, func_args: CutOptimizationFuncArgs
) -> tuple[float, float]:
"""Return the value of :math:`gamma` computed assuming all LO cuts."""
# pylint: disable=unused-argument
return (goal_state.upper_bound_gamma(), np.inf)
if goal_state is not None:
return (goal_state.upper_bound_gamma(), np.inf)
else:
raise ValueError(
"None state encountered: no cut state satisfying the specified constraints and settings could be found."
)


def cut_optimization_min_cost_bound_func(
Expand Down Expand Up @@ -125,7 +130,7 @@ def cut_optimization_goal_state_func(
# Global variable that holds the search-space functions for generating
# the cut optimization search space.
cut_optimization_search_funcs = SearchFunctions(
cost_func=cut_optimization_cost_func,
cost_func=cut_optimization_upper_bound_cost_func, # valid choice when considering only LO cuts.
upperbound_cost_func=cut_optimization_upper_bound_cost_func,
next_state_func=cut_optimization_next_state_func,
goal_state_func=cut_optimization_goal_state_func,
Expand Down
24 changes: 13 additions & 11 deletions circuit_knitting/cutting/cut_finding/disjoint_subcircuits_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,20 @@ class Action(NamedTuple):
args: list | tuple


class GateCutLocation(NamedTuple):
"""Named tuple for specification of gate cut location."""
class CutLocation(NamedTuple):
"""Named tuple for specifying cut locations.
This is used to specify instances of both :class:`CutTwoQubitGate` and :class:`CutBothWires`.
Both of these instances are fully specified by a gate reference.
"""

instruction_id: int
gate_name: str
qubits: Sequence


class WireCutLocation(NamedTuple):
"""Named tuple for specification of wire cut location.
"""Named tuple for specification of (single) wire cut locations.
Wire cuts are identified through the gates whose input wires are cut.
"""
Expand All @@ -64,10 +68,10 @@ class CutIdentifier(NamedTuple):
"""Named tuple for specification of location of :class:`CutTwoQubitGate` or :class:`CutBothWires` instances."""

cut_action: DisjointSearchAction
gate_cut_location: GateCutLocation
cut_location: CutLocation


class OneWireCutIdentifier(NamedTuple):
class SingleWireCutIdentifier(NamedTuple):
"""Named tuple for specification of location of :class:`CutLeftWire` or :class:`CutRightWire` instances."""

cut_action: DisjointSearchAction
Expand Down Expand Up @@ -130,15 +134,13 @@ def __init__(self, num_qubits: int | None = None, max_wire_cuts: int | None = No
if not (
num_qubits is None or (isinstance(num_qubits, int) and num_qubits >= 0)
):
raise ValueError("num_qubits must be either be None or a positive integer.")
raise ValueError("num_qubits must either be None or a positive integer.")

if not (
max_wire_cuts is None
or (isinstance(max_wire_cuts, int) and max_wire_cuts >= 0)
):
raise ValueError(
"max_wire_cuts must be either be None or a positive integer."
)
raise ValueError("max_wire_cuts must either be None or a positive integer.")

if num_qubits is None or max_wire_cuts is None:
self.wiremap: NDArray[np.int_] | None = None
Expand Down Expand Up @@ -213,7 +215,7 @@ def cut_actions_sublist(self) -> list[NamedTuple]:
for i in range(len(cut_actions)):
if cut_actions[i].action.get_name() in ("CutLeftWire", "CutRightWire"):
self.cut_actions_list.append(
OneWireCutIdentifier(
SingleWireCutIdentifier(
cut_actions[i].action.get_name(),
WireCutLocation(
cut_actions[i].gate_spec.instruction_id,
Expand All @@ -231,7 +233,7 @@ def cut_actions_sublist(self) -> list[NamedTuple]:
self.cut_actions_list.append(
CutIdentifier(
cut_actions[i].action.get_name(),
GateCutLocation(
CutLocation(
cut_actions[i].gate_spec.instruction_id,
cut_actions[i].gate_spec.gate.name,
cut_actions[i].gate_spec.gate.qubits,
Expand Down
36 changes: 19 additions & 17 deletions circuit_knitting/cutting/cut_finding/optimization_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ class OptimizationSettings:
max_gamma: float = 1024
max_backjumps: None | int = 10000
seed: int | None = None
LO: bool = True
LOCC_ancillas: bool = False
LOCC_no_ancillas: bool = False
gate_lo: bool = True
wire_lo: bool = True
gate_locc_ancillas: bool = False
wire_locc_ancillas: bool = False
wire_locc_no_ancillas: bool = False
engine_selections: dict[str, str] | None = None

def __post_init__(self):
Expand All @@ -57,12 +59,12 @@ def __post_init__(self):
if self.max_backjumps is not None and self.max_backjumps < 0:
raise ValueError("max_backjumps must be a positive semi-definite integer.")

self.gate_cut_LO = self.LO
self.gate_cut_LOCC_with_ancillas = self.LOCC_ancillas
self.gate_cut_lo = self.gate_lo
self.gate_cut_locc_with_ancillas = self.gate_locc_ancillas

self.wire_cut_LO = self.LO
self.wire_cut_LOCC_with_ancillas = self.LOCC_ancillas
self.wire_cut_LOCC_no_ancillas = self.LOCC_no_ancillas
self.wire_cut_lo = self.wire_lo
self.wire_cut_locc_with_ancillas = self.wire_locc_ancillas
self.wire_cut_locc_no_ancillas = self.wire_locc_no_ancillas
if self.engine_selections is None:
self.engine_selections = {"CutOptimization": "BestFirst"}

Expand Down Expand Up @@ -102,31 +104,31 @@ def set_gate_cut_types(self) -> None:
The default is to only include LO gate cuts, which are the
only cut types supported in this release.
"""
self.gate_cut_LO = self.LO
self.gate_cut_LOCC_with_ancillas = self.LOCC_ancillas
self.gate_cut_lo = self.gate_lo
self.gate_cut_locc_with_ancillas = self.gate_locc_ancillas

def set_wire_cut_types(self) -> None:
"""Select which wire-cut types to include in the optimization.
The default is to only include LO wire cuts, which are the
only cut types supported in this release.
"""
self.wire_cut_LO = self.LO
self.wire_cut_LOCC_with_ancillas = self.LOCC_ancillas
self.wire_cut_LOCC_no_ancillas = self.LOCC_no_ancillas
self.wire_cut_lo = self.wire_lo
self.wire_cut_locc_with_ancillas = self.wire_locc_ancillas
self.wire_cut_locc_no_ancillas = self.wire_locc_no_ancillas

def get_cut_search_groups(self) -> list[None | str]:
"""Return a list of action groups to include in the optimization."""
out: list
out = [None]

if self.gate_cut_LO or self.gate_cut_LOCC_with_ancillas:
if self.gate_cut_lo or self.gate_cut_locc_with_ancillas:
out.append("GateCut")

if (
self.wire_cut_LO
or self.wire_cut_LOCC_with_ancillas
or self.wire_cut_LOCC_no_ancillas
self.wire_cut_lo
or self.wire_cut_locc_with_ancillas
or self.wire_cut_locc_no_ancillas
):
out.append("WireCut")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
"version": "3.9.6"
}
},
"nbformat": 4,
Expand Down
38 changes: 20 additions & 18 deletions docs/circuit_cutting/tutorials/04_automatic_cut_finding.ipynb

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions releasenotes/notes/min-reached-finder-flag-aa6dd9021e165f80.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
features:
- |
A new ``minimum_reached`` field has been added to the metadata outputted by :func:`circuit_knitting.cutting.find_cuts` to check if the cut-finder found
a cut scheme that minimized the sampling overhead. Note that the search algorithm employed by the cut-finder is *guaranteed* to find
the optimal solution, that is, the solution with the minimum sampling overhead, provided it is allowed to run long enough.
The user is free to time-restrict the search by passing in suitable values for ``max_backjumps`` and/or ``max_gamma`` to
:class:`.OptimizationParameters`. If the search is terminated prematurely in this way, the metadata may indicate that the minimum
was not reached, even though the returned solution `was` actually the optimal solution. This would mean that the search that was performed was not
exhaustive enough to prove that the returned solution was optimal.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
When specifying instances of :class:`.OptimizationParameters` that are inputted to :meth:`circuit_knitting.cutting.find_cuts()`, the user can now control whether the
cut-finder looks only for gate cuts, only for wire cuts, or both, by setting the bools ``gate_lo`` and ``wire_lo`` appropriately. The default value
of both of these is set to ``True`` and so the default search considers the possibility of both gate and wire cuts.
3 changes: 1 addition & 2 deletions test/cutting/cut_finding/test_best_first_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,10 @@ def test_best_first_search(test_circuit: SimpleGateList):
op = CutOptimization(test_circuit, settings, constraint_obj)

out, _ = op.optimization_pass()

assert op.search_engine.get_stats(penultimate=True) is not None
assert op.search_engine.get_stats() is not None
assert op.get_upperbound_cost() == (27, inf)
assert op.minimum_reached() is False
assert op.minimum_reached() is True
assert out is not None
assert (out.lower_bound_gamma(), out.gamma_UB, out.get_max_width()) == (
27,
Expand Down
Loading

0 comments on commit e29d561

Please sign in to comment.