Skip to content

Commit

Permalink
Correct embedder loading from pkl files
Browse files Browse the repository at this point in the history
Summary: Point 3 of D25393460

Reviewed By: vkhalidov

Differential Revision: D25616216

fbshipit-source-id: 666879c81bb809ffdfcb9e2926d6f1111c5741c2
  • Loading branch information
MarcSzafraniec authored and facebook-github-bot committed Dec 18, 2020
1 parent b7fbaa1 commit c614036
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion projects/DensePose/densepose/modeling/cse/embedder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
import numpy as np
import pickle
from enum import Enum
from typing import Optional
Expand Down Expand Up @@ -80,7 +82,9 @@ def __init__(self, cfg: CfgNode):
super(Embedder, self).__init__()
self.mesh_names = set()
embedder_dim = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
logger = logging.getLogger(__name__)
for mesh_name, embedder_spec in cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS.items():
logger.info(f"Adding embedder embedder_{mesh_name} with spec {embedder_spec}")
self.add_module(f"embedder_{mesh_name}", create_embedder(embedder_spec, embedder_dim))
self.mesh_names.add(mesh_name)
if cfg.MODEL.WEIGHTS != "":
Expand All @@ -100,7 +104,10 @@ def load_from_model_checkpoint(self, fpath: str, prefix: Optional[str] = None):
state_dict_local = {}
for key in state_dict["model"]:
if key.startswith(prefix):
state_dict_local[key[len(prefix) :]] = state_dict["model"][key]
v_key = state_dict["model"][key]
if isinstance(v_key, np.ndarray):
v_key = torch.from_numpy(v_key)
state_dict_local[key[len(prefix) :]] = v_key
# non-strict loading to finetune on different meshes
self.load_state_dict(state_dict_local, strict=False)

Expand Down

0 comments on commit c614036

Please sign in to comment.