Skip to content

Commit

Permalink
March updates (#25)
Browse files Browse the repository at this point in the history
* gradient_bias bug fix.

* Configurable patch size.

* Support for different image file types.

* Buffer size argument.
  • Loading branch information
lahavlipson authored Mar 21, 2023
1 parent aa0c656 commit 4f2f0cc
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 8 deletions.
3 changes: 3 additions & 0 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def run(cfg, network, imagedir, calib, stride=1, skip=0, viz=False, timeit=False
parser.add_argument('--calib', type=str)
parser.add_argument('--stride', type=int, default=2)
parser.add_argument('--skip', type=int, default=0)
parser.add_argument('--buffer', type=int, default=2048)
parser.add_argument('--config', default="config/default.yaml")
parser.add_argument('--timeit', action='store_true')
parser.add_argument('--viz', action="store_true")
Expand All @@ -83,6 +84,7 @@ def run(cfg, network, imagedir, calib, stride=1, skip=0, viz=False, timeit=False
args = parser.parse_args()

cfg.merge_from_file(args.config)
cfg.BUFFER_SIZE = args.buffer

print("Running with config...")
print(cfg)
Expand All @@ -93,6 +95,7 @@ def run(cfg, network, imagedir, calib, stride=1, skip=0, viz=False, timeit=False
if args.save_reconstruction:
pred_traj, ply_data = pred_traj
ply_data.write(f"{name}.ply")
print(f"Saved {name}.ply")

if args.save_trajectory:
Path("saved_trajectories").mkdir(exist_ok=True)
Expand Down
3 changes: 3 additions & 0 deletions dpvo/dpvo.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ def __edges_back(self):
def __call__(self, tstamp, image, intrinsics):
""" track new frame """

if (self.n+1) >= self.N:
raise Exception(f'The buffer size is too small. You can increase it using "--buffer {self.N*2}"')

if self.viewer is not None:
self.viewer.update_image(image)

Expand Down
13 changes: 7 additions & 6 deletions dpvo/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def forward(self, images, patches_per_image=80, disps=None, gradient_bias=False,
imap = self.inet(images) / 4.0

b, n, c, h, w = fmap.shape
P = self.patch_size

# bias patch selection towards regions with high gradient
if gradient_bias:
Expand All @@ -121,19 +122,19 @@ def forward(self, images, patches_per_image=80, disps=None, gradient_bias=False,
y = torch.randint(1, h-1, size=[n, 3*patches_per_image], device="cuda")

coords = torch.stack([x, y], dim=-1).float()
g = altcorr.patchify(g, coords, 0).view(-1)
g = altcorr.patchify(g[0,:,None], coords, 0).view(n, 3 * patches_per_image)

ix = torch.argsort(g)
x = x[:, ix[-patches_per_image:]]
y = y[:, ix[-patches_per_image:]]
ix = torch.argsort(g, dim=1)
x = torch.gather(x, 1, ix[:, -patches_per_image:])
y = torch.gather(y, 1, ix[:, -patches_per_image:])

else:
x = torch.randint(1, w-1, size=[n, patches_per_image], device="cuda")
y = torch.randint(1, h-1, size=[n, patches_per_image], device="cuda")

coords = torch.stack([x, y], dim=-1).float()
imap = altcorr.patchify(imap[0], coords, 0).view(b, -1, DIM, 1, 1)
gmap = altcorr.patchify(fmap[0], coords, 1).view(b, -1, 128, 3, 3)
gmap = altcorr.patchify(fmap[0], coords, P//2).view(b, -1, 128, P, P)

if return_color:
clr = altcorr.patchify(images[0], 4*(coords + 0.5), 0).view(b, -1, 3)
Expand All @@ -142,7 +143,7 @@ def forward(self, images, patches_per_image=80, disps=None, gradient_bias=False,
disps = torch.ones(b, n, h, w, device="cuda")

grid, _ = coords_grid_with_index(disps, device=fmap.device)
patches = altcorr.patchify(grid[0], coords, 1).view(b, -1, 3, 3, 3)
patches = altcorr.patchify(grid[0], coords, P//2).view(b, -1, 3, P, P)

index = torch.arange(n, device="cuda").view(n, 1)
index = index.repeat(1, patches_per_image).reshape(-1)
Expand Down
4 changes: 3 additions & 1 deletion dpvo/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from multiprocessing import Process, Queue
from pathlib import Path
from itertools import chain

def image_stream(queue, imagedir, calib, stride, skip=0):
""" image generator """
Expand All @@ -16,7 +17,8 @@ def image_stream(queue, imagedir, calib, stride, skip=0):
K[1,1] = fy
K[1,2] = cy

image_list = sorted(Path(imagedir).glob('*.png'))[skip::stride]
img_exts = ["*.png", "*.jpeg", "*.jpg"]
image_list = sorted(chain.from_iterable(Path(imagedir).glob(e) for e in img_exts))[skip::stride]

for t, imfile in enumerate(image_list):
image = cv2.imread(str(imfile))
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def train(args):
loss = 0.0
for i, (v, x, y, P1, P2, kl) in enumerate(traj):
e = (x - y).norm(dim=-1)
e = e.reshape(-1, 9)[(v > 0.5).reshape(-1)].min(dim=-1).values
e = e.reshape(-1, net.P**2)[(v > 0.5).reshape(-1)].min(dim=-1).values

N = P1.shape[1]
ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
Expand Down

0 comments on commit 4f2f0cc

Please sign in to comment.