droid.py
import torch
import lietorch
import numpy as np

from .droid_net import DroidNet
from .depth_video import DepthVideo
from .motion_filter import MotionFilter
from .droid_frontend import DroidFrontend
from .droid_backend import DroidBackend
from .trajectory_filler import PoseTrajectoryFiller

from collections import OrderedDict
from torch.multiprocessing import Process

class Droid:
    def __init__(self, args):
        super(Droid, self).__init__()
        self.load_weights(args.weights)
        self.args = args
        self.disable_vis = args.disable_vis

        # store images, depth, poses, intrinsics (shared between processes)
        self.video = DepthVideo(args.image_size, args.buffer, stereo=args.stereo)

        # filter incoming frames so that there is enough motion
        self.filterx = MotionFilter(self.net, self.video, thresh=args.filter_thresh)

        # frontend process
        self.frontend = DroidFrontend(self.net, self.video, self.args)

        # backend process
        self.backend = DroidBackend(self.net, self.video, self.args)

        # visualizer
        if not self.disable_vis:
            from .visualization import droid_visualization
            self.visualizer = Process(
                target=droid_visualization,
                args=(self.video, 'cuda:0', self.args.vis_save))
            self.visualizer.start()

        # post processor - fill in poses for non-keyframes
        self.traj_filler = PoseTrajectoryFiller(self.net, self.video)

    def load_weights(self, weights):
        """ load trained model weights """

        print(weights)
        self.net = DroidNet()

        # strip the DataParallel "module." prefix from checkpoint keys
        state_dict = OrderedDict([
            (k.replace("module.", ""), v) for (k, v) in torch.load(weights).items()])

        # keep only the first two output channels of the update heads
        state_dict["update.weight.2.weight"] = state_dict["update.weight.2.weight"][:2]
        state_dict["update.weight.2.bias"] = state_dict["update.weight.2.bias"][:2]
        state_dict["update.delta.2.weight"] = state_dict["update.delta.2.weight"][:2]
        state_dict["update.delta.2.bias"] = state_dict["update.delta.2.bias"][:2]

        self.net.load_state_dict(state_dict)
        self.net.to("cuda:0").eval()

    def track(self, tstamp, image, depth=None, intrinsics=None):
        """ main thread - update map """

        with torch.no_grad():
            # check there is enough motion
            self.filterx.track(tstamp, image, depth, intrinsics)

            # local bundle adjustment
            self.frontend()

            # global bundle adjustment
            # self.backend()

    def terminate(self, stream=None):
        """ terminate the visualization process, return timestamps and poses [t, q] """

        del self.frontend

        # two passes of global bundle adjustment over the keyframe graph
        torch.cuda.empty_cache()
        print("#" * 32)
        self.backend(7)

        torch.cuda.empty_cache()
        print("#" * 32)
        self.backend(20)

        # interpolate poses for non-keyframes, invert the SE3 trajectory,
        # and move it to a numpy array on the CPU
        camera_trajectory = self.traj_filler(stream)
        camera_trajectory = camera_trajectory.inv().data.cpu().numpy()

        # prepend frame indices as timestamps: rows are [t, x, y, z, qx, qy, qz, qw]
        timestamps = np.arange(len(camera_trajectory)).reshape(-1, 1)
        traj_tum = np.concatenate([timestamps, camera_trajectory], axis=1)

        return traj_tum
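

# ----------------------------------------------------------------------------
# Illustrative usage sketch (hypothetical, not part of the original file).
# It assumes an `args` namespace carrying the fields accessed above (weights,
# image_size, buffer, stereo, disable_vis, filter_thresh, vis_save) plus any
# settings the frontend/backend consume, and a user-provided `image_stream`
# yielding (tstamp, image, intrinsics) tuples.
#
#   droid = Droid(args)
#
#   for tstamp, image, intrinsics in image_stream:
#       droid.track(tstamp, image, intrinsics=intrinsics)
#
#   # global refinement + non-keyframe pose interpolation; rows are
#   # [t, x, y, z, qx, qy, qz, qw]
#   traj_tum = droid.terminate(image_stream)
#   np.savetxt("trajectory_tum.txt", traj_tum, fmt="%.6f")
# ----------------------------------------------------------------------------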