Skip to content

Commit

Permalink
update metric3d ViT.giant2 model
Browse files Browse the repository at this point in the history
  • Loading branch information
Jianxff committed May 3, 2024
1 parent 4c48e87 commit 70d7d04
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 22 deletions.
13 changes: 5 additions & 8 deletions depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Running Metric3D')
parser.add_argument("--images", help='dir for image files', type=str, required=True)
parser.add_argument("--focal", help='focal length', type=float, default=None)
parser.add_argument("--calib", help='calibration file, overwrite focal', type=str, default=None)
parser.add_argument("--calib", help='calibration file, overwrite focal', type=str, required=True)
parser.add_argument("--out", help='dir for output depth', type=str, default='')
parser.add_argument("--ckpt", type=str, default='./weights/metric_depth_vit_large_800k.pth', help='checkpoint file')
parser.add_argument("--model-name", type=str, default='v2-L', choices=['v2-L', 'v2-S'], help='model type')
parser.add_argument("--ckpt", type=str, default='./weights/metric_depth_vit_giant2_800k.pth', help='checkpoint file')
parser.add_argument("--model-name", type=str, default='v2-g', choices=['v2-L', 'v2-S', 'v2-g'], help='model type')
args = parser.parse_args()

if args.calib:
calib = np.loadtxt(args.calib)
args.focal = (abs(calib[0]) + abs(calib[1])) / 2
calib = np.loadtxt(args.calib)

metric = Metric(
checkpoint=args.ckpt,
Expand All @@ -38,7 +35,7 @@
os.makedirs(color_dir, exist_ok=True)

for image in tqdm(images):
depth = metric(image, args.focal)
depth = metric(rgb_image=image, calib=calib)

# save orignal depth
np.save(str(out_dir / f'{image.stem}.npy'), depth)
Expand Down
17 changes: 9 additions & 8 deletions modules/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class Metric3D:

def __init__(
self,
checkpoint: Union[str, Path] = './weights/metric_depth_vit_giant2_800k.pth',
model_name: str = 'v2-g',
checkpoint: Union[str, Path] = './weights/metric_depth_vit_large_800k.pth',
model_name: str = 'v2-L',
) -> None:
checkpoint = Path(checkpoint).resolve()
cfg:Config = self._load_config_(model_name, checkpoint)
Expand All @@ -44,7 +44,7 @@ def __init__(
def __call__(
self,
rgb_image: Union[np.ndarray, Image.Image, str, Path],
focal: Optional[float] = None,
calib: Union[str, Path, np.ndarray]
) -> np.ndarray:
# read image
if isinstance(rgb_image, (str, Path)):
Expand All @@ -53,9 +53,9 @@ def __call__(
rgb_image = np.array(rgb_image)
# get intrinsic
h, w = rgb_image.shape[:2]
if focal is None:
focal = np.max([h, w])
intrinsic = [focal, focal, w/2, h/2]
if isinstance(calib, (str, Path)):
calib = np.loadtxt(calib)
intrinsic = calib[:4]
# transform image
rgb_input, cam_models_stacks, pad, label_scale_factor = \
transform_test_data_scalecano(rgb_image, intrinsic, self.cfg_.data_basic)
Expand All @@ -82,11 +82,12 @@ def _load_config_(
checkpoint: Union[str, Path],
) -> Config:
config_path = metric3d_path / 'mono/configs/HourglassDecoder'
assert model_name in ['v2-L', 'v2-S'], f"Model {model_name} not supported"
assert model_name in ['v2-L', 'v2-S', 'v2-g'], f"Model {model_name} not supported"
# load config file
cfg = Config.fromfile(
str(config_path / 'vit.raft5.large.py') if model_name == 'v2-L'
else str(config_path / 'vit.raft5.small.py')
else str(config_path / 'vit.raft5.small.py') if model_name == 'v2-S'
else str(config_path / 'vit.raft5.giant2.py')
)
cfg.load_from = str(checkpoint)
# load data info
Expand Down
5 changes: 2 additions & 3 deletions scripts/test_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
parser.add_argument("--global-ba-frontend", type=int, help="frequency to run global ba on frontend", default=0)

parser.add_argument("--metric-ckpt", type=str, help='checkpoint file', default='./weights/metric_depth_vit_large_800k.pth')
parser.add_argument("--metric-model", type=str, help='model type', default='v2-L', choices=['v2-L', 'v2-S'])
parser.add_argument("--metric-model", type=str, help='model type', default='v2-L', choices=['v2-L', 'v2-S', 'v2-g'])
parser.add_argument("--droid-ckpt", type=str, help="checkpoint file", default='./weights/droid.pth')

args = parser.parse_args()
Expand Down Expand Up @@ -57,7 +57,6 @@
calib = np.loadtxt(calib_file)
intr = calib[:4]
distort = calib[4:] if len(calib) > 4 else None
fxy = (intr[0] + intr[1]) / 2

# metric 3d ###############################################################
metric = Metric(
Expand All @@ -70,7 +69,7 @@
if os.path.exists(str(depth_dir / f'{image.stem}.npy')) and not args.overwrite:
continue

depth = metric(image, fxy)
depth = metric(image, intr)
# save orignal depth
np.save(str(depth_dir / f'{image.stem}.npy'), depth)
# save colormap
Expand Down
5 changes: 2 additions & 3 deletions scripts/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
parser.add_argument("--global-ba-frontend", type=int, help="frequency to run global ba on frontend", default=0)

parser.add_argument("--metric-ckpt", type=str, help='checkpoint file', default='./weights/metric_depth_vit_large_800k.pth')
parser.add_argument("--metric-model", type=str, help='model type', default='v2-L', choices=['v2-L', 'v2-S'])
parser.add_argument("--metric-model", type=str, help='model type', default='v2-L', choices=['v2-L', 'v2-S', 'v2-g'])
parser.add_argument("--droid-ckpt", type=str, help="checkpoint file", default='./weights/droid.pth')

args = parser.parse_args()
Expand Down Expand Up @@ -62,7 +62,6 @@
calib = np.loadtxt(str(calib_file))
intr = calib[:4]
distort = calib[4:] if len(calib) > 4 else None
fxy = (intr[0] + intr[1]) / 2

# video sample ###############################################################
sample_from_video(
Expand All @@ -83,7 +82,7 @@
if os.path.exists(str(depth_dir / f'{image.stem}.npy')) and not args.overwrite:
continue

depth = metric(image, fxy)
depth = metric(image, intr)
# save orignal depth
np.save(str(depth_dir / f'{image.stem}.npy'), depth)
# save colormap
Expand Down

0 comments on commit 70d7d04

Please sign in to comment.