-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathansatz.py
181 lines (135 loc) · 5.89 KB
/
ansatz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
JAX Ansatze
===========
File containing circuit ansatze as JAX-compatible functions
"""
from __future__ import annotations
import jax
import numpy
import qujax
#################################
def _zero_ket(L : int):
    """
    Return the all-zeros computational basis state on `L` qubits.

    The result is a complex `jax.numpy` array of shape `(2,) * L` whose
    only non-zero entry is a 1 at index `(0, ..., 0)`.
    """
    n_amplitudes = 2 ** L
    # Build the state as a flat vector first, then fold it into one axis
    # of size 2 per qubit.
    flat_state = jax.numpy.zeros(n_amplitudes, dtype=complex).at[0].set(1)
    return jax.numpy.reshape(flat_state, (2,) * L)
# Flattened single-qubit zero state (two amplitudes); used as the default
# `initial_word_state` of `_linear_IQP_ansatz_2q`.
_zk2 = jax.numpy.ravel(_zero_ket(1))
def _linear_IQP_ansatz_2q(initial_word_state=_zk2):
    """
    Build the components of the linear IQP-style ansatz on two qubits.

    Returns the tuple
    `(initial_word_state, word_circuit, combining_circuit)`:

    * `word_circuit(left, angles)` applies
      `_rx(angles[2]) @ _rz(angles[1]) @ _rx(angles[0])` to
      `initial_word_state` and returns the resulting single-qubit
      state. `left` is accepted but not used.
    * `combining_circuit(left, right, angles)` forms the two-qubit
      vector `kron(right, left)`, applies a Hadamard to both qubits and
      the diagonal `_crz(angles[3])` gate, projects the first tensor
      factor onto its 0 component, renormalises, and returns the two
      remaining amplitudes.

    The angles should be supplied as a `jax.numpy.ndarray`; every angle
    is multiplied by 2*pi before it is used as a phase.
    Padding is represented as a row having all zeros, for which the
    word circuit returns `initial_word_state` unchanged.
    """
    import jax.numpy as jnp

    # Fixed (parameter-free) matrices.
    H = 1/jnp.sqrt(2)*jnp.array([[1, 1], [1, -1]])
    HH = jnp.kron(H, H)
    Id_local = jnp.eye(2)
    # Projector onto the 0 component of the first tensor factor,
    # identity on the second.
    proj = jnp.kron(jnp.array([[1., 0.], [0., 0.]]), Id_local)

    def _crz(t):
        t = t * (2*numpy.pi)
        return jnp.diag(jnp.array([1, jnp.exp(-1j * t),
                                   1, jnp.exp(1j * t)]))

    def _rz(t):
        t = t * (2*numpy.pi)
        return jnp.diag(jnp.array([jnp.exp(-1j*t), jnp.exp(1j*t)]))

    def _rx(t):
        t = t * (2*numpy.pi)
        return jnp.array([[jnp.cos(t), -1j*jnp.sin(t)],
                          [-1j*jnp.sin(t), jnp.cos(t)]])

    def _IQP_2q_word_circuit(left, angles):
        # `left` is ignored: every word starts from `initial_word_state`.
        return (_rx(angles[2]) @ _rz(angles[1]) @ _rx(angles[0])
                @ initial_word_state)

    def _IQP_2q_combining_circuit(left, right, angles):
        res = jnp.kron(right, left)
        res = HH @ res
        res = _crz(angles[3]) @ res
        res = proj @ res
        # Renormalise after the projection.
        # NOTE: if the projected vector has zero norm this produces
        # non-finite values.
        res *= 1/jnp.linalg.norm(res)
        # Keep the amplitudes belonging to the projected-onto row.
        return res.reshape(2, 2)[0, :]

    return initial_word_state, _IQP_2q_word_circuit, _IQP_2q_combining_circuit
def _trivial_combine(left, right, angles):
    """
    Combine two states by returning `right` unchanged.

    `left` and `angles` are accepted so the signature matches the other
    combining circuits, but neither is used.
    """
    combined = right
    return combined
def _hardware_efficient_ansatz(n_qubits, layers):
    """
    Build a layered Ry-Rz-CX circuit as a JAX-compatible function.

    One layer applies an 'Ry' and then an 'Rz' gate to every qubit, each
    with its own parameter (`2 * n_qubits` parameters per layer),
    followed by 'CX' gates on the qubit pairs `[q, q - 1]` for q running
    from `n_qubits - 1` down to 1.

    Returns a function `_hwa(left, angles)` that reshapes `angles` to
    `(layers, -1)`, doubles them, and applies one layer per row to
    `left` using `jax.lax.scan`.
    """
    # Qubits are visited from the highest index down to 0.
    descending = [n_qubits - 1 - i for i in range(n_qubits)]

    gate_names = (
        ['Ry'] * n_qubits
        + ['Rz'] * n_qubits
        + ['CX'] * (n_qubits - 1)
    )
    qubit_indices = (
        [[q] for q in descending]
        + [[q] for q in descending]
        + [[q, q - 1] for q in descending[:-1]]
    )
    # Ry gates read parameters 0..n_qubits-1, Rz gates read
    # n_qubits..2*n_qubits-1; the CX gates take no parameters.
    param_indices = (
        [[p] for p in range(2 * n_qubits)]
        + [[] for _ in range(n_qubits - 1)]
    )

    layer_func = qujax.get_params_to_statetensor_func(gate_names,
                                                      qubit_indices,
                                                      param_indices)

    def _apply_layer(statetensor, layer_angles):
        # scan carry: the statetensor; no per-step output is collected.
        return layer_func(layer_angles, statetensor), None

    def _hwa(left, angles):
        per_layer = 2 * angles.reshape(layers, -1)
        final_state, _ = jax.lax.scan(_apply_layer, left, per_layer)
        return final_state

    return _hwa
def _multi_cnot_and_measure(n_qubits):
    """
    Build a function that couples a state to one extra qubit through CX
    gates and returns the outcome weights of that extra qubit.

    The circuit uses qubit indices 0..n_qubits: a 'CX' gate is applied
    on the pair `[q, 0]` for every q from `n_qubits` down to 1.

    Returns a function `_mcn(left)` which stacks `left` with a zero
    array of the same shape along a new leading axis, runs the CX
    circuit on the result, and returns a length-2 array holding, for
    each value of the leading axis, the sum of squared magnitudes.
    """
    gate_names = ['CX'] * n_qubits
    qubit_indices = [[q, 0] for q in range(n_qubits, 0, -1)]
    param_indices = [[] for _ in range(n_qubits)]

    cx_circuit = qujax.get_params_to_statetensor_func(gate_names,
                                                      qubit_indices,
                                                      param_indices)

    def _mcn(left):
        jnp = jax.numpy
        # New leading axis: `left` at index 0, zeros at index 1.
        extended = jnp.stack((left, jnp.zeros_like(left)))
        # The circuit has no parameters, so the qujax function is called
        # with the input array as its only argument.
        evolved = cx_circuit(extended)
        amplitudes = evolved.reshape(2, -1)
        return jnp.sum(jnp.square(jnp.abs(amplitudes)), axis=1)

    return _mcn
def _reduced_hea(n_qubits, layers):
    """
    Build a layered Rz-CX circuit as a JAX-compatible function.

    One layer applies an 'Rz' gate to every qubit, each with its own
    parameter (`n_qubits` parameters per layer), followed by 'CX' gates
    on the qubit pairs `[q, q - 1]` for q running from `n_qubits - 1`
    down to 1.

    Returns a function `_hwa(left, angles)` that reshapes `angles` to
    `(layers, -1)`, doubles them, and applies one layer per row to
    `left` using `jax.lax.scan`.
    """
    n_entangling = n_qubits - 1

    gate_names = ['Rz'] * n_qubits + ['CX'] * n_entangling
    # Single-qubit gates visit qubits n_qubits-1 .. 0; the CX chain
    # visits the pairs [n_qubits-1, n_qubits-2] .. [1, 0].
    qubit_indices = [[q] for q in range(n_qubits - 1, -1, -1)]
    qubit_indices.extend([q, q - 1] for q in range(n_qubits - 1, 0, -1))
    # Rz gate number i reads parameter i; the CX gates take none.
    param_indices = [[p] for p in range(n_qubits)]
    param_indices.extend([] for _ in range(n_entangling))

    layer_func = qujax.get_params_to_statetensor_func(gate_names,
                                                      qubit_indices,
                                                      param_indices)

    def _apply_layer(statetensor, layer_angles):
        # scan carry: the statetensor; no per-step output is collected.
        return layer_func(layer_angles, statetensor), None

    def _hwa(left, angles):
        per_layer = 2 * angles.reshape(layers, -1)
        final_state, _ = jax.lax.scan(_apply_layer, left, per_layer)
        return final_state

    return _hwa
#Untested
def _qaoa(n_qubits, layers):
    """
    Build a layered Rxx-Rz circuit as a JAX-compatible function.

    One layer applies an 'Rxx' gate to each qubit pair `[q, q - 1]` for
    q running from `n_qubits - 1` down to 1, followed by an 'Rz' gate on
    every qubit. Every gate has its own parameter, so a layer consumes
    `2 * n_qubits - 1` angles: indices `0 .. n_qubits - 2` feed the Rxx
    gates and indices `n_qubits - 1 .. 2 * n_qubits - 2` feed the Rz
    gates.

    Returns a function `_qaoa_circuit(left, angles)` that reshapes
    `angles` to `(layers, -1)`, doubles them, and applies one layer per
    row to `left` using `jax.lax.scan`.
    """
    circuit_gates, circuit_qubit_inds, circuit_params_inds = [], [], []

    # NOTE(review): confirm that 'Rxx' is a gate name qujax recognises;
    # its built-in two-qubit XX rotation appears to be named 'XXPhase'.
    circuit_gates += ['Rxx'] * (n_qubits - 1)
    circuit_qubit_inds += [[n_qubits - 1 - i, n_qubits - 2 - i]
                           for i in range(n_qubits - 1)]
    # Exactly one parameter index per Rxx gate. There are n_qubits - 1 of
    # these gates, so the three sequences stay the same length.
    circuit_params_inds += [[i] for i in range(n_qubits - 1)]

    circuit_qubit_inds += [[n_qubits - 1 - i] for i in range(n_qubits)]
    # Rz parameters start right after the Rxx parameters.
    circuit_params_inds += [[i + len(circuit_gates)] for i in range(n_qubits)]
    circuit_gates += ['Rz'] * n_qubits

    param_to_st = qujax.get_params_to_statetensor_func(circuit_gates,
                                                       circuit_qubit_inds,
                                                       circuit_params_inds)

    def _param_to_st_scan(statetensor_in, params):
        # scan carry: the statetensor; no per-step output is collected.
        return param_to_st(params, statetensor_in), None

    def _qaoa_circuit(left, angles):
        angles = angles.reshape(layers, -1)
        angles = 2 * angles
        res, _ = jax.lax.scan(_param_to_st_scan, left, angles)
        return res

    return _qaoa_circuit