Skip to content

Commit 4ab4e07

Browse files
committed
add: hmc, phi4, gaussian, roll
1 parent 6be4e1c commit 4ab4e07

20 files changed

+208
-0
lines changed

.gitignore

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
.cache
2+
**/*.*~
3+
**/*.pyc
4+
**/.DS_Store
5+
**/test.py
6+
**/*.testSave
7+
docs/_build
8+
data

flow/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .flow import Flow

flow/flow.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import numpy as np
2+
import torch
3+
from torch import nn
4+
5+
class Flow(nn.Module):
6+
7+
def __init__(self, prior = None,name = "Flow"):
8+
super(Flow, self).__init__()
9+
self.name = name
10+
self.prior = prior
11+
12+
def __call__(self,*args,**kargs):
13+
return self.sample(*args,**kargs)
14+
15+
def sample(self,batchSize):
16+
raise NotImplementedError(str(type(self)))
17+
18+
def inference(self,x):
19+
raise NotImplementedError(str(type(self)))
20+
21+
def generate(self,z):
22+
raise NotImplementedError(str(type(self)))
23+
24+
def logProbability(self,x):
25+
z,logp = self.inference(x)
26+
if self.prior is not None:
27+
return self.prior.logProbability(z)+logp
28+
return logp
29+
30+
def save(self,saveDict):
31+
saveDict[self.name] = self.state_dict()
32+
return saveDict
33+
34+
def load(self,saveDict):
35+
self.load_state_dict(saveDict)
36+
return saveDict
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

flow/rnvp/rnvp.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch
2+
from torch import nn
3+
4+
from model import Flow
5+
6+
class RNVP(Flow):
7+
def __init__(self, maskList, tList, sList, prior = None, name = "RNVP"):
8+
super(RNVP,self).__init__(prior,name)
9+
self.maskList = nn.Parameter(maskList)
10+
self.tList = tList
11+
self.sList = sList

model/rnvp/rnvp.py

Whitespace-only changes.

model/template.py

Whitespace-only changes.

source/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .gaussian import Gaussian
2+
from .phi4 import Phi4

source/gaussian.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torch
2+
import numpy as np
3+
4+
from .source import Source
5+
6+
class Gaussian(Source):
7+
def __init__(self, nvars, sigma = 1, name="gaussian", requiresGrad = False):
8+
super(Gaussian,self).__init__(nvars,name)
9+
self.sigma = torch.nn.Parameter(torch.tensor([sigma],dtype=torch.float32),requires_grad=requiresGrad)
10+
11+
def sample(self, batchSize):
12+
size = [batchSize] + self.nvars
13+
return torch.randn(size).to(self.sigma)
14+
15+
def energy(self, z):
16+
return -(-0.5 * (z/self.sigma)**2-0.5*torch.log(2.*np.pi*self.sigma**2)).view(z.shape[0],-1).sum(dim=1)

source/phi4.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import numpy as np
2+
import torch
3+
4+
from .source import Source
5+
from utils import HMC
6+
from utils import roll
7+
8+
class Phi4(Source):
9+
def __init__(self,l,dims,kappa,lamb,name = None):
10+
if name is None:
11+
self.name = "phi4_l"+str(l)+"_d"+str(dims)+"_kappa"+str(kappa)+"_lamb"+str(lamb)
12+
else:
13+
self.name = name
14+
self.kappa = kappa
15+
self.lamb = lamb
16+
self.dims = dims
17+
nvars = []
18+
for _ in range(dims):
19+
nvars += [l]
20+
super(Phi4,self).__init__(nvars,name)
21+
22+
def sample(self, batchSize, thermalSteps = 50, interSteps=5, epsilon=0.1):
23+
inital = torch.randn([batchSize]+self.nvars,requires_grad=True)
24+
inital = HMC(self.energy,inital,thermalSteps,interSteps,epsilon)
25+
return inital.detach()
26+
27+
def energy(self,x):
28+
S = 0
29+
for i in range(self.dims):
30+
S += x*roll(x,[1],[i+1])
31+
#S += x*roll(x,[-1],[i+1])
32+
term1 = x**2
33+
term2 = (term1-1)**2
34+
for _ in range(self.dims):
35+
S = S.sum(-1)
36+
term1 = term1.sum(-1)
37+
term2 = term2.sum(-1)
38+
S *= -2*self.kappa
39+
term2 *= self.lamb
40+
S += term1 + term2
41+
return -S

source/source.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import numpy as np
2+
import torch
3+
from torch import nn
4+
5+
class Source(nn.Module):
6+
7+
def __init__(self, nvars,name = "Flow"):
8+
super(Source, self).__init__()
9+
self.name = name
10+
self.nvars = nvars
11+
12+
def __call__(self,*args,**kargs):
13+
return self.sample(*args,**kargs)
14+
15+
def sample(self,batchSize):
16+
raise NotImplementedError(str(type(self)))
17+
18+
def logProbability(self,x):
19+
return -self.energy(x)
20+
21+
def energy(self,x):
22+
raise NotImplementedError(str(type(self)))
23+
24+
def save(self,saveDict):
25+
saveDict[self.name] = self.state_dict()
26+
return saveDict
27+
28+
def load(self,saveDict):
29+
self.load_state_dict(saveDict)
30+
return saveDict

utils/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .mc import HMCsampler, HMC
2+
from .roll import roll
File renamed without changes.

utils/mc/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .hmc import HMCsampler, HMC

utils/mc/hmc.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import numpy as np
2+
import torch
3+
from torch import nn
4+
from torch.autograd import grad as torchgrad
5+
6+
torch.manual_seed(42)
7+
8+
def HMC(energy,x,length,steps,epsilon):
9+
def grad(z):
10+
return torchgrad(energy(z),z,grad_outputs=torch.ones(z.shape[0]))[0]
11+
12+
E = energy(x)
13+
g = grad(x)
14+
15+
for l in range(length):
16+
p = x.new_empty(size=x.size()).normal_()
17+
H = (0.5*p*p).view(p.shape[0], -1).sum(dim=1) + E
18+
xnew = x
19+
gnew = g
20+
for _ in range(steps):
21+
p = p- epsilon* gnew/2.
22+
xnew = xnew + epsilon * p
23+
gnew = grad(xnew)
24+
p = p- epsilon* gnew/2.
25+
Enew = energy(xnew)
26+
Hnew = (0.5*p*p).view(p.shape[0], -1).sum(dim=1) + Enew
27+
diff = H-Hnew
28+
accept = (diff.exp() >= diff.uniform_()).to(x)
29+
30+
E = accept*Enew + (1.-accept)*E
31+
accept = accept.view(x.shape[0], 1, 1)
32+
x = accept*xnew + (1.-accept)*x
33+
g = accept*gnew + (1.-accept)*g
34+
35+
return x
36+
37+
38+
class HMCsampler(nn.Module):
39+
def __init__(self,energy,nvars, epsilon=0.01, interSteps=10 , thermalSteps = 10):
40+
super(HMCsampler,self).__init__()
41+
self.nvars = nvars
42+
self.energy = energy
43+
self.interSteps = interSteps
44+
self.inital = HMC(self.energy,torch.randn(nvars),thermalSteps,interSteps)
45+
46+
def step(self):
47+
return HMC(self.energy,self.inital,1,interSteps,epsilon)

utils/roll.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import torch
2+
3+
def roll(x, step, axis):
4+
shape = x.shape
5+
for i,s in enumerate(step):
6+
if s >=0:
7+
x1 = x.narrow(axis[i],0,s)
8+
x2 = x.narrow(axis[i],s,shape[axis[i]]-s)
9+
else:
10+
x2 = x.narrow(axis[i],shape[axis[i]]+s,-s)
11+
x1 = x.narrow(axis[i],0,shape[axis[i]]+s)
12+
x = torch.cat([x2,x1],axis[i])
13+
return x

0 commit comments

Comments
 (0)