
Commit 4eb37ab

add reweight es naqs. not completed yet

When sampling, it uses self.es, and when getting ws, it uses self.psi. self.psi is updated from outside, but self.es needs to be optimized too, which is not implemented yet.

1 parent ed6876e commit 4eb37ab
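The scheme described above samples configurations from |es|^2 while amplitudes come from psi, so observable averages must be corrected by the ratio of the two distributions. A minimal sketch of that correction, where reweighted_mean and the toy tensors are illustrative and not part of the commit:

import torch

def reweighted_mean(local_values, psi_amplitudes, es_amplitudes):
    # Importance weights w(c) = |psi(c) / es(c)|^2 for configurations c
    # drawn from the |es(c)|^2 distribution.
    weights = (psi_amplitudes / es_amplitudes).abs() ** 2
    return (weights * local_values).sum() / weights.sum()

psi = torch.tensor([0.8 + 0.1j, 0.3 - 0.2j])  # psi(c) on two sampled configurations
es = torch.tensor([0.7 + 0.0j, 0.4 - 0.1j])   # es(c), the sampling amplitudes
print(reweighted_mean(torch.tensor([1.0, 2.0]), psi, es))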

File tree

3 files changed: +315 −6 lines changed

tetragono/tetragono/sampling_neural_state/observer.py (+54 −3)
@@ -18,9 +18,39 @@
 
 import numpy as np
 import torch
-from ..utility import allreduce_buffer, allreduce_number, show, showln
+from ..utility import allreduce_buffer, allreduce_number, show, showln, mpi_comm
 from .state import Configuration, index_tensor_element
 
+opt = None
+
+
+def torch_tensor_allgather(tensor):
+    from mpi4py import MPI
+    # Get the device of the input tensor
+    device = tensor.device
+
+    # Convert torch tensor to numpy array
+    np_array = tensor.cpu().detach().numpy()
+
+    # Initialize MPI
+    comm = mpi_comm
+    rank = comm.Get_rank()
+    size = comm.Get_size()
+
+    counts = comm.allgather(np_array.size)
+    first = comm.allgather(np_array.shape[0])
+    total_length = sum(first)
+    # Create a buffer to hold all gathered numpy arrays
+    gathered_np_arrays = np.empty((total_length, *np_array.shape[1:]), dtype=np_array.dtype)
+
+    # Perform allgather
+    comm.Allgatherv(np_array, [gathered_np_arrays, counts])
+
+    # Convert gathered numpy arrays back to torch tensor
+    gathered_tensor = torch.from_numpy(gathered_np_arrays).to(device)
+
+    return gathered_tensor
+
 
 class Observer():
     """
@@ -66,8 +96,7 @@ def __enter__(self):
         if self._enable_gradient:
             self._Delta = None
             self._EDelta = None
-            if self._enable_natural:
-                self._Deltas = []
+            self._Deltas = []  # temporarily using this list for something else
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         """
@@ -114,6 +143,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
             allreduce_buffer(self._Delta)
             allreduce_buffer(self._EDelta)
 
+            cs = torch.stack([c for c, e in self._Deltas])
+            es = torch.tensor([e for c, e in self._Deltas], dtype=torch.complex128, device=cs.device)
+            cs = torch_tensor_allgather(cs)
+            es = torch.view_as_complex(torch_tensor_allgather(torch.view_as_real(es)))
+            es = es - es.mean()  # anyway, this is the thing used for sampling; other factors may be added later, e.g. Delta multiplied in as well
+            with torch.enable_grad():
+                global opt
+                if opt is None:
+                    opt = torch.optim.Adam(self.owner.network.es.parameters(), 1e-2)
+                for _ in range(100):
+                    hes = self.owner.network.es(cs)
+                    error = hes / hes.norm() - es / es.norm()
+                    error = (error.abs()**2).mean()
+                    show(error.item())
+                    opt.zero_grad()
+                    error.backward()
+                    opt.step()
+                showln("es error", error.item())
+
     def __init__(
         self,
         owner,
@@ -395,6 +443,9 @@ def __call__(self, configurations, amplitudes, weights, multiplicities):
                             name].imag * reweight
                 if name == "energy" and self._enable_gradient:
                     Es = whole_result[batch_index][name]
+                    # train self.es
+                    # collect and optimize self.es
+                    self._Deltas.append((configurations[batch_index], Es))
                 if self.owner.Tensor.is_real:
                     Es = Es.real
 
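The loop added to __exit__ trains the sampler network es so that its normalized output on the gathered configurations matches the normalized, centered local energies. A self-contained toy analogue, where model, cs, and ts are stand-ins rather than tetragono names:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.SiLU(), torch.nn.Linear(16, 1))
cs = torch.randn(32, 4)  # stand-in for the gathered configurations
ts = torch.randn(32)     # stand-in for the centered local energies
opt = torch.optim.Adam(model.parameters(), 1e-2)
for _ in range(100):
    hs = model(cs).reshape(-1)
    # Same loss shape as the commit: mean |h/||h|| - t/||t|||^2.
    error = ((hs / hs.norm() - ts / ts.norm()).abs() ** 2).mean()
    opt.zero_grad()
    error.backward()
    opt.step()
print("final error", error.item())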
tetragono/tetragono/sampling_neural_state/state.py (+15 −3)
@@ -348,16 +348,28 @@ def holes(self, value):
         if self.Tensor.is_complex:
             with torch_grad(True):
                 value.real.backward(retain_graph=True)
-            real = torch.cat([param.grad.reshape([-1]) for param in self.network.parameters() if param.requires_grad])
+            real = torch.cat([
+                param.grad.reshape([-1])
+                for param in self.network.parameters()
+                if param.requires_grad and param.grad is not None
+            ])
             self.network.zero_grad()
             with torch_grad(True):
                 value.imag.backward()
-            imag = torch.cat([param.grad.reshape([-1]) for param in self.network.parameters() if param.requires_grad])
+            imag = torch.cat([
+                param.grad.reshape([-1])
+                for param in self.network.parameters()
+                if param.requires_grad and param.grad is not None
+            ])
             self.network.zero_grad()
             result = (real + 1j * imag)
         else:
             value.backward()
-            result = torch.cat([param.grad.reshape([-1]) for param in self.network.parameters() if param.requires_grad])
+            result = torch.cat([
+                param.grad.reshape([-1])
+                for param in self.network.parameters()
+                if param.requires_grad and param.grad is not None
+            ])
             self.network.zero_grad()
         result = result / value
         return result.detach_()
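The new `param.grad is not None` guard matters because parameters that never enter the backward graph keep grad equal to None, and torch.cat would raise on them. A generic illustration with made-up module names:

import torch

used = torch.nn.Linear(3, 1)
unused = torch.nn.Linear(3, 1)  # never touched by the forward pass
used(torch.randn(3)).sum().backward()

grads = torch.cat([
    p.grad.reshape([-1])
    for m in (used, unused)
    for p in m.parameters()
    if p.requires_grad and p.grad is not None
])
print(grads.shape)  # torch.Size([4]): only `used`'s weight and bias survive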
New file (+246)
@@ -0,0 +1,246 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright (C) 2024 Hao Zhang<[email protected]>
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+
+import torch
+
+
+class FakeLinear(torch.nn.Module):
+    # A bias-only stand-in for torch.nn.Linear, used when dim_in == 0.
+
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.bias = torch.nn.Parameter(torch.zeros([dim_out]))
+
+    def forward(self, x):
+        shape = x.shape[:-1]
+        prod = torch.tensor(shape).prod()
+        return self.bias.view([1, -1]).expand([prod, -1]).view([*shape, -1])
+
+
+def Linear(dim_in, dim_out):
+    if dim_in == 0:
+        return FakeLinear(dim_in, dim_out)
+    else:
+        return torch.nn.Linear(dim_in, dim_out)
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, dim_input, dim_output, hidden_size):
+        super().__init__()
+        self.dim_input = dim_input
+        self.dim_output = dim_output
+        self.hidden_size = hidden_size
+        self.depth = len(hidden_size)
+
+        # Alternating Linear / SiLU layers, ending with the output Linear.
+        self.model = torch.nn.Sequential(*(Linear(
+            dim_input if i == 0 else hidden_size[i - 1],
+            dim_output if i == self.depth else hidden_size[i],
+        ) if j == 0 else torch.nn.SiLU() for i in range(self.depth + 1) for j in range(2) if i != self.depth or j != 1))
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class WaveFunction(torch.nn.Module):
+
+    def __init__(
+        self,
+        *,
+        L1,
+        L2,
+        orbit_num,
+        physical_dim,
+        is_complex,
+        spin_up,
+        spin_down,
+        hidden_size,
+        ordering,
+    ):
+        super().__init__()
+        self.L1 = L1
+        self.L2 = L2
+        self.orbit_num = orbit_num
+        self.sites = L1 * L2 * orbit_num // 2
+        assert physical_dim == 2
+        assert is_complex == True
+        self.spin_up = spin_up
+        self.spin_down = spin_down
+        self.hidden_size = tuple(hidden_size)
+
+        self.amplitude = torch.nn.ModuleList([MLP(i * 2, 4, self.hidden_size) for i in range(self.sites)])
+        self.phase = torch.nn.ModuleList([MLP(i * 2, 4, self.hidden_size) for i in range(self.sites)])
+
+        if isinstance(ordering, int) and ordering == +1:
+            ordering = list(range(self.sites))
+        if isinstance(ordering, int) and ordering == -1:
+            ordering = list(reversed(range(self.sites)))
+        self.register_buffer('ordering', torch.tensor(ordering, dtype=torch.int64), persistent=True)
+        ordering_bak = torch.zeros(self.sites, dtype=torch.int64)
+        ordering_bak.scatter_(0, self.ordering, torch.arange(self.sites))
+        self.register_buffer('ordering_bak', ordering_bak, persistent=True)
+
+    def mask(self, x):
+        # x : batch * i * 2
+        i = x.size(1)
+        # number : batch * 2
+        number = x.sum(dim=1)
+
+        up_electron = number[:, 0]
+        down_electron = number[:, 1]
+        up_hole = i - up_electron
+        down_hole = i - down_electron
+
+        add_up_electron = up_electron < self.spin_up
+        add_down_electron = down_electron < self.spin_down
+        add_up_hole = up_hole < self.sites - self.spin_up
+        add_down_hole = down_hole < self.sites - self.spin_down
+
+        add_up = torch.stack([add_up_hole, add_up_electron], dim=-1).unsqueeze(-1)
+        add_down = torch.stack([add_down_hole, add_down_electron], dim=-1).unsqueeze(-2)
+        add = torch.logical_and(add_up, add_down)
+        return add
+
+    def normalize_amplitude(self, x):
+        param = -(2 * x).exp().sum(dim=[1, 2]).log() / 2
+        x = x + param.unsqueeze(-1).unsqueeze(-1)
+        return x
+
+    def forward(self, x):
+        device = next(self.parameters()).device
+        dtype = next(self.parameters()).dtype
+
+        batch_size = x.size(0)
+        x = x.reshape([batch_size, self.sites, 2])
+        x = torch.index_select(x, 1, self.ordering_bak)
+
+        xf = x.to(dtype=dtype)
+        arange = torch.arange(batch_size, device=device)
+        total_amplitude = 0
+        total_phase = 0
+        for i in range(self.sites):
+            amplitude = self.amplitude[i](xf[:, :i].reshape([batch_size, 2 * i])).reshape([batch_size, 2, 2])
+            phase = self.phase[i](xf[:, :i].reshape([batch_size, 2 * i])).reshape([batch_size, 2, 2])
+            amplitude = amplitude + torch.where(self.mask(x[:, :i]), 0, -torch.inf)
+            amplitude = self.normalize_amplitude(amplitude)
+            amplitude = amplitude[arange, x[:, i, 0], x[:, i, 1]]
+            phase = phase[arange, x[:, i, 0], x[:, i, 1]]
+            total_amplitude = total_amplitude + amplitude
+            total_phase = total_phase + phase
+        return (total_amplitude + 1j * total_phase).exp()
+
+    def binomial(self, count, possibility):
+        possibility = torch.clamp(possibility, min=0, max=1)
+        possibility = torch.where(count == 0, 0, possibility)
+        dist = torch.distributions.binomial.Binomial(count, possibility)
+        result = dist.sample()
+        result = result.to(dtype=torch.int64)
+        # Numerical error since result was cast to float.
+        return torch.clamp(result, min=torch.zeros_like(count), max=count)
+
+    def generate(self, batch_size, alpha=1):
+        # https://arxiv.org/pdf/2109.12606
+        device = next(self.parameters()).device
+        dtype = next(self.parameters()).dtype
+        assert alpha == 1
+
+        x = torch.empty([1, 0, 2], device=device, dtype=torch.int64)
+        multiplicity = torch.tensor([batch_size], dtype=torch.int64, device=device)
+        amplitude_phase = torch.tensor([0], dtype=dtype.to_complex(), device=device)
+        for i in range(self.sites):
+            local_batch_size = x.size(0)
+
+            xf = x.to(dtype=dtype)
+            amplitude = self.amplitude[i](xf.reshape([local_batch_size, 2 * i])).reshape([local_batch_size, 2, 2])
+            phase = self.phase[i](xf.reshape([local_batch_size, 2 * i])).reshape([local_batch_size, 2, 2])
+            amplitude = amplitude + torch.where(self.mask(x), 0, -torch.inf)
+            amplitude = self.normalize_amplitude(amplitude)
+            delta_amplitude_phase = (amplitude + 1j * phase).reshape([local_batch_size, 4])
+            probability = (2 * amplitude).exp().reshape([local_batch_size, 4])
+            probability = probability / probability.sum(dim=-1).unsqueeze(-1)
+
+            sample0123 = multiplicity
+            prob23 = probability[:, 2] + probability[:, 3]
+            prob01 = probability[:, 0] + probability[:, 1]
+            sample23 = self.binomial(sample0123, prob23)
+            sample3 = self.binomial(sample23, probability[:, 3] / prob23)
+            sample2 = sample23 - sample3
+            sample01 = sample0123 - sample23
+            sample1 = self.binomial(sample01, probability[:, 1] / prob01)
+            sample0 = sample01 - sample1
+
+            x0 = torch.cat([x, torch.tensor([[0, 0]], device=device).expand(local_batch_size, -1, -1)], dim=1)
+            x1 = torch.cat([x, torch.tensor([[0, 1]], device=device).expand(local_batch_size, -1, -1)], dim=1)
+            x2 = torch.cat([x, torch.tensor([[1, 0]], device=device).expand(local_batch_size, -1, -1)], dim=1)
+            x3 = torch.cat([x, torch.tensor([[1, 1]], device=device).expand(local_batch_size, -1, -1)], dim=1)
+
+            new_x = torch.cat([x0, x1, x2, x3])
+            new_multiplicity = torch.cat([sample0, sample1, sample2, sample3])
+            new_amplitude_phase = (amplitude_phase.unsqueeze(0) + delta_amplitude_phase.permute(1, 0)).reshape([-1])
+
+            selected = new_multiplicity != 0
+            x = new_x[selected]
+            multiplicity = new_multiplicity[selected]
+            amplitude_phase = new_amplitude_phase[selected]
+
+        real_amplitude = amplitude_phase.exp()
+        real_probability = (real_amplitude.conj() * real_amplitude).real
+        x = torch.index_select(x, 1, self.ordering)
+        return x.reshape([x.size(0), self.L1, self.L2, self.orbit_num]), real_amplitude, torch.ones_like(real_probability), torch.ones_like(multiplicity)
+
+
+class ReweightWaveFunction(torch.nn.Module):
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__()
+        self.psi = WaveFunction(*args, **kwargs)
+        # Stored inside a tuple so the sampler network is not registered as a
+        # submodule, keeping its parameters out of self.parameters().
+        self._es = (WaveFunction(*args, **kwargs).cuda(),)
+        self.es.load_state_dict(self.psi.state_dict())
+        self.es.cuda()
+
+    @property
+    def es(self):
+        return self._es[0]
+
+    def forward(self, x):
+        return self.psi(x)
+
+    def generate(self, batch_size, alpha=1):
+        configurations, _, weights, multiplicities = self.es.generate(batch_size, alpha)
+        amplitudes = self(configurations)
+        return configurations, amplitudes, weights, multiplicities
+
+
+def network(state, spin_up, spin_down, hidden_size, ordering=+1):
+    max_orbit_index = max(orbit for [l1, l2, orbit], edge in state.physics_edges)
+    max_physical_dim = max(edge.dimension for [l1, l2, orbit], edge in state.physics_edges)
+    network = ReweightWaveFunction(
+        L1=state.L1,
+        L2=state.L2,
+        orbit_num=max_orbit_index + 1,
+        physical_dim=max_physical_dim,
+        is_complex=state.Tensor.is_complex,
+        spin_up=spin_up,
+        spin_down=spin_down,
+        hidden_size=hidden_size,
+        ordering=ordering,
+    ).double()
+    return network
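A hypothetical direct construction, bypassing the tetragono state object; all sizes are made up, and a CUDA device is required because __init__ moves the sampler copy to CUDA:

net = ReweightWaveFunction(
    L1=2, L2=2, orbit_num=4, physical_dim=2, is_complex=True,
    spin_up=4, spin_down=4, hidden_size=[16, 16], ordering=+1,
).double().cuda()
configurations, amplitudes, weights, multiplicities = net.generate(1024)
# configurations are sampled from |es|^2 while amplitudes come from psi;
# per the commit message, the optimization of es itself is not wired up yet.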
