Skip to content

Commit

Permalink
Merge pull request huggingface#76 from patil-suraj/zoe-nk
Browse files Browse the repository at this point in the history
Add support for zoedepth_nk
patrickvonplaten authored Sep 4, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
2 parents c8ee2a7 + 27e64dc commit f1dc84f
Showing 2 changed files with 16 additions and 12 deletions.
13 changes: 9 additions & 4 deletions src/controlnet_aux/zoe/__init__.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@

from ..util import HWC3, resize_image
from .zoedepth.models.zoedepth.zoedepth_v1 import ZoeDepth
from .zoedepth.models.zoedepth_nk.zoedepth_nk_v1 import ZoeDepthNK
from .zoedepth.utils.config import get_config


@@ -17,16 +18,17 @@ def __init__(self, model):
self.model = model

@classmethod
def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None):
def from_pretrained(cls, pretrained_model_or_path, model_type="zoedepth", filename=None, cache_dir=None):
filename = filename or "ZoeD_M12_N.pt"

if os.path.isdir(pretrained_model_or_path):
model_path = os.path.join(pretrained_model_or_path, filename)
else:
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)

conf = get_config("zoedepth", "infer")
model = ZoeDepth.build_from_config(conf)
conf = get_config(model_type, "infer")
model_cls = ZoeDepth if model_type == "zoedepth" else ZoeDepthNK
model = model_cls.build_from_config(conf)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model'])
model.eval()

@@ -36,7 +38,7 @@ def to(self, device):
self.model.to(device)
return self

def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type=None):
def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type=None, gamma_corrected=False):
device = next(iter(self.model.parameters())).device
if not isinstance(input_image, np.ndarray):
input_image = np.array(input_image, dtype=np.uint8)
@@ -63,6 +65,9 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, out
depth -= vmin
depth /= vmax - vmin
depth = 1.0 - depth

if gamma_corrected:
depth = np.power(depth, 2.2)
depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)

detected_map = depth_image
Original file line number Diff line number Diff line change
@@ -27,15 +27,14 @@
import torch
import torch.nn as nn

from zoedepth.models.depth_model import DepthModel
from zoedepth.models.base_models.midas import MidasCore
from zoedepth.models.layers.attractor import AttractorLayer, AttractorLayerUnnormed
from zoedepth.models.layers.dist_layers import ConditionalLogBinomial
from zoedepth.models.layers.localbins_layers import (Projector, SeedBinRegressor,
from ..depth_model import DepthModel
from ..base_models.midas import MidasCore
from ..layers.attractor import AttractorLayer, AttractorLayerUnnormed
from ..layers.dist_layers import ConditionalLogBinomial
from ..layers.localbins_layers import (Projector, SeedBinRegressor,
SeedBinRegressorUnnormed)
from zoedepth.models.layers.patch_transformer import PatchTransformerEncoder
from zoedepth.models.model_io import load_state_from_resource

from ..layers.patch_transformer import PatchTransformerEncoder
from ..model_io import load_state_from_resource

class ZoeDepthNK(DepthModel):
def __init__(self, core, bin_conf, bin_centers_type="softplus", bin_embedding_dim=128,

0 comments on commit f1dc84f

Please sign in to comment.