|
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2024 Hao Zhang<[email protected]>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
import torch


# Stands in for a linear layer whose input dimension is zero.
# torch.nn.Linear cannot initialize a weight with in_features == 0, so this
# module carries only a learnable bias and broadcasts it over the batch.
class FakeLinear(torch.nn.Module):

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # dim_in is accepted for interface compatibility but is always 0 here.
        self.bias = torch.nn.Parameter(torch.zeros([dim_out]))

    def forward(self, x):
        # x : ... * 0; return the bias expanded to shape ... * dim_out.
        shape = x.shape[:-1]
        prod = torch.tensor(shape).prod()
        return self.bias.view([1, -1]).expand([prod, -1]).view([*shape, -1])


def Linear(dim_in, dim_out):
    # Dispatch on the input dimension: torch.nn.Linear for the normal case,
    # FakeLinear for the degenerate dim_in == 0 case.
    if dim_in == 0:
        return FakeLinear(dim_in, dim_out)
    else:
        return torch.nn.Linear(dim_in, dim_out)
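# For example (shapes assumed for illustration): Linear(0, 4) returns a
# FakeLinear that maps an input of shape [batch, 0] to its bias expanded to
# [batch, 4], while Linear(8, 4) is an ordinary torch.nn.Linear(8, 4).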


class MLP(torch.nn.Module):

    def __init__(self, dim_input, dim_output, hidden_size):
        super().__init__()
        self.dim_input = dim_input
        self.dim_output = dim_output
        self.hidden_size = hidden_size
        self.depth = len(hidden_size)

        # Alternate Linear and SiLU layers, with an activation after every
        # layer except the final output layer.
        layers = []
        for i in range(self.depth + 1):
            layers.append(Linear(
                dim_input if i == 0 else hidden_size[i - 1],
                dim_output if i == self.depth else hidden_size[i],
            ))
            if i != self.depth:
                layers.append(torch.nn.SiLU())
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
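
# For example (illustrative sizes, not values used elsewhere in this file):
# MLP(4, 4, [16, 16]) builds the stack
#     Linear(4, 16) -> SiLU -> Linear(16, 16) -> SiLU -> Linear(16, 4)
# and MLP(0, 4, [16, 16]) starts with a FakeLinear(0, 16) instead.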


class WaveFunction(torch.nn.Module):

    def __init__(
        self,
        *,
        L1,
        L2,
        orbit_num,
        physical_dim,
        is_complex,
        spin_up,
        spin_down,
        hidden_size,
        ordering,
    ):
        super().__init__()
        self.L1 = L1
        self.L2 = L2
        self.orbit_num = orbit_num
        # Each autoregressive site groups two physical orbitals (spin up and
        # spin down), so there are L1 * L2 * orbit_num // 2 sites.
        self.sites = L1 * L2 * orbit_num // 2
        # Only two-dimensional physical edges (empty/occupied) and complex
        # wave functions are supported.
        assert physical_dim == 2
        assert is_complex
        self.spin_up = spin_up
        self.spin_down = spin_down
        self.hidden_size = tuple(hidden_size)

        # One conditional network pair per site: site i sees the 2 * i
        # occupations already fixed and emits 2 * 2 = 4 outputs, one for each
        # (up, down) occupation choice of the current site.
        self.amplitude = torch.nn.ModuleList([MLP(i * 2, 4, self.hidden_size) for i in range(self.sites)])
        self.phase = torch.nn.ModuleList([MLP(i * 2, 4, self.hidden_size) for i in range(self.sites)])

        # ordering may be an explicit permutation of the sites, +1 for the
        # natural order, or -1 for the reversed order.
        if isinstance(ordering, int) and ordering == +1:
            ordering = list(range(self.sites))
        if isinstance(ordering, int) and ordering == -1:
            ordering = list(reversed(range(self.sites)))
        self.register_buffer('ordering', torch.tensor(ordering, dtype=torch.int64), persistent=True)
        # ordering_bak is the inverse permutation of ordering.
        ordering_bak = torch.zeros(self.sites, dtype=torch.int64)
        ordering_bak.scatter_(0, self.ordering, torch.arange(self.sites))
        self.register_buffer('ordering_bak', ordering_bak, persistent=True)
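        # For example (an illustrative permutation, not a default): with
        # ordering = [2, 0, 1] the scatter above yields
        # ordering_bak = [1, 2, 0], so ordering_bak[ordering[k]] == k.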

    def mask(self, x):
        # Given the occupations already placed on the first i sites, decide
        # which of the four (up, down) choices for site i keep the total
        # particle numbers reachable.
        # x : batch * i * 2
        i = x.size(1)
        # number : batch * 2
        number = x.sum(dim=1)

        up_electron = number[:, 0]
        down_electron = number[:, 1]
        up_hole = i - up_electron
        down_hole = i - down_electron

        # An occupation may be appended only while its running total stays
        # below the target: spin_up/spin_down electrons, the rest holes.
        add_up_electron = up_electron < self.spin_up
        add_down_electron = down_electron < self.spin_down
        add_up_hole = up_hole < self.sites - self.spin_up
        add_down_hole = down_hole < self.sites - self.spin_down

        # add : batch * 2 * 2, indexed by the (up, down) occupation pair.
        add_up = torch.stack([add_up_hole, add_up_electron], dim=-1).unsqueeze(-1)
        add_down = torch.stack([add_down_hole, add_down_electron], dim=-1).unsqueeze(-2)
        add = torch.logical_and(add_up, add_down)
        return add

    def normalize_amplitude(self, x):
        # x : batch * 2 * 2 log-amplitudes; shift them so that the four
        # conditional probabilities exp(2 * x) sum to one.
        param = -(2 * x).exp().sum(dim=[1, 2]).log() / 2
        x = x + param.unsqueeze(-1).unsqueeze(-1)
        return x
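        # Concretely, for log-amplitudes a this computes
        #     c = -log(sum_{s,t} exp(2 * a[s, t])) / 2
        # and returns a + c, so that sum_{s,t} exp(2 * (a[s, t] + c)) == 1,
        # i.e. the squared amplitudes form a probability distribution; masked
        # entries at -inf contribute zero weight.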

    def forward(self, x):
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype

        # x : batch * L1 * L2 * orbit_num, flattened to batch * sites * 2 and
        # permuted into the autoregressive order.
        batch_size = x.size(0)
        x = x.reshape([batch_size, self.sites, 2])
        x = torch.index_select(x, 1, self.ordering_bak)

        xf = x.to(dtype=dtype)
        arange = torch.arange(batch_size, device=device)
        total_amplitude = 0
        total_phase = 0
        for i in range(self.sites):
            # Condition site i on the occupations of the previous i sites.
            amplitude = self.amplitude[i](xf[:, :i].reshape([batch_size, 2 * i])).reshape([batch_size, 2, 2])
            phase = self.phase[i](xf[:, :i].reshape([batch_size, 2 * i])).reshape([batch_size, 2, 2])
            # Forbidden choices get log-amplitude -inf before normalization.
            amplitude = amplitude + torch.where(self.mask(x[:, :i]), 0, -torch.inf)
            amplitude = self.normalize_amplitude(amplitude)
            # Select the entries matching the actual occupation of site i.
            amplitude = amplitude[arange, x[:, i, 0], x[:, i, 1]]
            phase = phase[arange, x[:, i, 0], x[:, i, 1]]
            total_amplitude = total_amplitude + amplitude
            total_phase = total_phase + phase
        return (total_amplitude + 1j * total_phase).exp()

    def binomial(self, count, possibility):
        # Sample Binomial(count, possibility) element-wise, guarding the
        # probability against values outside [0, 1] (and NaN, which can only
        # occur in branches whose count is zero).
        possibility = torch.clamp(possibility, min=0, max=1)
        possibility = torch.where(count == 0, 0, possibility)
        dist = torch.distributions.binomial.Binomial(count, possibility)
        result = dist.sample()
        result = result.to(dtype=torch.int64)
        # Clamp away rounding error: sample() returns floating-point counts.
        return torch.clamp(result, min=torch.zeros_like(count), max=count)

    def generate(self, batch_size, alpha=1):
        # Batched autoregressive sampling that merges identical samples and
        # propagates their multiplicities, following
        # https://arxiv.org/pdf/2109.12606
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        assert alpha == 1

        # x keeps one row per distinct partial configuration, multiplicity
        # counts how many of the batch_size samples follow it, and
        # amplitude_phase accumulates its log-amplitude and phase.
        x = torch.empty([1, 0, 2], device=device, dtype=torch.int64)
        multiplicity = torch.tensor([batch_size], dtype=torch.int64, device=device)
        amplitude_phase = torch.tensor([0], dtype=dtype.to_complex(), device=device)
        for i in range(self.sites):
            local_batch_size = x.size(0)

            xf = x.to(dtype=dtype)
            amplitude = self.amplitude[i](xf.reshape([local_batch_size, 2 * i])).reshape([local_batch_size, 2, 2])
            phase = self.phase[i](xf.reshape([local_batch_size, 2 * i])).reshape([local_batch_size, 2, 2])
            amplitude = amplitude + torch.where(self.mask(x), 0, -torch.inf)
            amplitude = self.normalize_amplitude(amplitude)
            delta_amplitude_phase = (amplitude + 1j * phase).reshape([local_batch_size, 4])
            probability = (2 * amplitude).exp().reshape([local_batch_size, 4])
            probability = probability / probability.sum(dim=-1).unsqueeze(-1)

            # Split each multiplicity over the four local choices by chaining
            # binomials, which together realize a multinomial draw.
            sample0123 = multiplicity
            prob23 = probability[:, 2] + probability[:, 3]
            prob01 = probability[:, 0] + probability[:, 1]
            sample23 = self.binomial(sample0123, prob23)
            sample3 = self.binomial(sample23, probability[:, 3] / prob23)
            sample2 = sample23 - sample3
            sample01 = sample0123 - sample23
            sample1 = self.binomial(sample01, probability[:, 1] / prob01)
            sample0 = sample01 - sample1
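            # A sketch of why this is a multinomial draw: for n samples with
            # probabilities (p0, p1, p2, p3), first n23 ~ B(n, p2 + p3)
            # splits off the choices {2, 3}; then n3 ~ B(n23, p3 / (p2 + p3))
            # and n1 ~ B(n - n23, p1 / (p0 + p1)) split each half, so
            # (n0, n1, n2, n3) ~ Multinomial(n, (p0, p1, p2, p3)).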

            # Extend every distinct configuration by each of the four
            # occupation choices for site i.
            x0 = torch.cat([x, torch.tensor([[0, 0]], device=device).expand(local_batch_size, -1, -1)], dim=1)
            x1 = torch.cat([x, torch.tensor([[0, 1]], device=device).expand(local_batch_size, -1, -1)], dim=1)
            x2 = torch.cat([x, torch.tensor([[1, 0]], device=device).expand(local_batch_size, -1, -1)], dim=1)
            x3 = torch.cat([x, torch.tensor([[1, 1]], device=device).expand(local_batch_size, -1, -1)], dim=1)

            # The branch order below matches the cat order [x0, x1, x2, x3].
            new_x = torch.cat([x0, x1, x2, x3])
            new_multiplicity = torch.cat([sample0, sample1, sample2, sample3])
            new_amplitude_phase = (amplitude_phase.unsqueeze(0) + delta_amplitude_phase.permute(1, 0)).reshape([-1])

            # Keep only the branches that received at least one sample.
            selected = new_multiplicity != 0
            x = new_x[selected]
            multiplicity = new_multiplicity[selected]
            amplitude_phase = new_amplitude_phase[selected]

        real_amplitude = amplitude_phase.exp()
        real_probability = (real_amplitude.conj() * real_amplitude).real
        x = torch.index_select(x, 1, self.ordering)
        # Each distinct configuration is returned once; weights and
        # multiplicities are returned as ones, so downstream reweighting
        # relies on the amplitudes alone.
        return x.reshape([x.size(0), self.L1, self.L2, self.orbit_num]), real_amplitude, torch.ones_like(real_probability), torch.ones_like(multiplicity)


class ReweightWaveFunction(torch.nn.Module):

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.psi = WaveFunction(*args, **kwargs)
        # The sampling network is stored inside a tuple so that it is hidden
        # from Module registration: it is a frozen snapshot of psi used only
        # to generate configurations, not a trainable submodule.
        self._es = (WaveFunction(*args, **kwargs),)
        self.es.load_state_dict(self.psi.state_dict())
        self.es.cuda()

    @property
    def es(self):
        return self._es[0]

    def forward(self, x):
        return self.psi(x)

    def generate(self, batch_size, alpha=1):
        # Sample from the snapshot network but evaluate amplitudes with the
        # current psi, so the returned amplitudes carry gradients.
        configurations, _, weights, multiplicities = self.es.generate(batch_size, alpha)
        amplitudes = self(configurations)
        return configurations, amplitudes, weights, multiplicities


def network(state, spin_up, spin_down, hidden_size, ordering=+1):
    # Infer the orbital count and physical dimension from the edges of the
    # tensor-network state.
    max_orbit_index = max(orbit for [l1, l2, orbit], edge in state.physics_edges)
    max_physical_dim = max(edge.dimension for [l1, l2, orbit], edge in state.physics_edges)
    network = ReweightWaveFunction(
        L1=state.L1,
        L2=state.L2,
        orbit_num=max_orbit_index + 1,
        physical_dim=max_physical_dim,
        is_complex=state.Tensor.is_complex,
        spin_up=spin_up,
        spin_down=spin_down,
        hidden_size=hidden_size,
        ordering=ordering,
    ).double()
    return network
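

if __name__ == "__main__":
    # A minimal smoke-test sketch; the lattice sizes, fillings, and hidden
    # sizes below are assumed for illustration and are not defaults of this
    # project. WaveFunction is used directly so the test stays on CPU
    # (ReweightWaveFunction moves its sampler to CUDA).
    torch.manual_seed(0)
    model = WaveFunction(
        L1=2,
        L2=2,
        orbit_num=2,
        physical_dim=2,
        is_complex=True,
        spin_up=2,
        spin_down=2,
        hidden_size=[16, 16],
        ordering=+1,
    ).double()
    configurations, amplitudes, weights, multiplicities = model.generate(1000)
    # Re-evaluating the sampled configurations reproduces the amplitudes
    # computed during sampling.
    assert torch.allclose(model(configurations), amplitudes)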