Skip to content

Commit

Permalink
default weight bug fix 2
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxkib committed Jan 30, 2024
1 parent aafc291 commit 2028b85
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 46 deletions.
31 changes: 4 additions & 27 deletions cog/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@
SAFETY_URL = "https://weights.replicate.delivery/default/playgroundai/safety-cache.tar"

SDXL_NAME_TO_PATHLIKE = {
# `stable-diffusion-xl-base-1.0` is the default model, it's speical since it's always on disk (downloaded in setup)
# These are all huggingface models that we host via gcp + pget
"stable-diffusion-xl-base-1.0": {
"slug": "stabilityai/stable-diffusion-xl-base-1.0",
"url": "https://weights.replicate.delivery/default/InstantID/models--stabilityai--stable-diffusion-xl-base-1.0.tar",
"path": "checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0",
},
# These are all huggingface models that we host via gcp + pget
"afrodite-xl-v2": {
"slug": "stablediffusionapi/afrodite-xl-v2",
"url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--afrodite-xl-v2.tar",
Expand Down Expand Up @@ -209,18 +210,7 @@ def setup(self) -> None:
local_files_only=True,
)

self.base_weights = "stable-diffusion-xl-base-1.0"
weights_info = SDXL_NAME_TO_PATHLIKE[self.base_weights]
self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
weights_info["slug"],
controlnet=self.controlnet,
torch_dtype=torch.float16,
cache_dir=CHECKPOINTS_CACHE,
local_files_only=True,
)

self.pipe.cuda()
self.pipe.load_ip_adapter_instantid(self.face_adapter)
self.load_weights("stable-diffusion-xl-base-1.0")
self.setup_safety_checker()

def setup_safety_checker(self):
Expand Down Expand Up @@ -249,19 +239,6 @@ def load_weights(self, sdxl_weights):
self.base_weights = sdxl_weights
weights_info = SDXL_NAME_TO_PATHLIKE[self.base_weights]

if sdxl_weights == "stable-diffusion-xl-base-1.0": # Default, it's always there
self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
weights_info["slug"],
controlnet=self.controlnet,
torch_dtype=torch.float16,
cache_dir=CHECKPOINTS_CACHE,
local_files_only=True,
)
self.pipe.cuda()
self.pipe.load_ip_adapter_instantid(self.face_adapter)
self.setup_safety_checker()
return

download_url = weights_info["url"]
path_to_weights_dir = weights_info["path"]
if not os.path.exists(path_to_weights_dir):
Expand Down
41 changes: 22 additions & 19 deletions scripts/push_to_gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,23 @@
dirs = [
# "checkpoints/models--stablediffusionapi--juggernaut-xl-v8",
# "checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0",
"checkpoints/models--stablediffusionapi--afrodite-xl-v2",
"checkpoints/models--stablediffusionapi--albedobase-xl-20",
"checkpoints/models--stablediffusionapi--albedobase-xl-v13",
"checkpoints/models--stablediffusionapi--animagine-xl-30",
"checkpoints/models--stablediffusionapi--anime-art-diffusion-xl",
"checkpoints/models--stablediffusionapi--anime-illust-diffusion-xl",
"checkpoints/models--stablediffusionapi--dreamshaper-xl",
"checkpoints/models--stablediffusionapi--duchaiten-real3d-nsfw-xl",
"checkpoints/models--stablediffusionapi--dynavision-xl-v0610",
"checkpoints/models--stablediffusionapi--guofeng4-xl",
"checkpoints/models--stablediffusionapi--hentai-mix-xl",
"checkpoints/models--stablediffusionapi--juggernaut-xl-v8",
"checkpoints/models--stablediffusionapi--nightvision-xl-0791",
"checkpoints/models--stablediffusionapi--omnigen-xl",
"checkpoints/models--stablediffusionapi--pony-diffusion-v6-xl",
"checkpoints/models--stablediffusionapi--protovision-xl-high-fidel",
# "checkpoints/models--stablediffusionapi--afrodite-xl-v2",
# "checkpoints/models--stablediffusionapi--albedobase-xl-20",
# "checkpoints/models--stablediffusionapi--albedobase-xl-v13",
# "checkpoints/models--stablediffusionapi--animagine-xl-30",
# "checkpoints/models--stablediffusionapi--anime-art-diffusion-xl",
# "checkpoints/models--stablediffusionapi--anime-illust-diffusion-xl",
# "checkpoints/models--stablediffusionapi--dreamshaper-xl",
# "checkpoints/models--stablediffusionapi--duchaiten-real3d-nsfw-xl",
# "checkpoints/models--stablediffusionapi--dynavision-xl-v0610",
# "checkpoints/models--stablediffusionapi--guofeng4-xl",
# "checkpoints/models--stablediffusionapi--hentai-mix-xl",
# "checkpoints/models--stablediffusionapi--juggernaut-xl-v8",
# "checkpoints/models--stablediffusionapi--nightvision-xl-0791",
# "checkpoints/models--stablediffusionapi--omnigen-xl",
# "checkpoints/models--stablediffusionapi--pony-diffusion-v6-xl",
# "checkpoints/models--stablediffusionapi--protovision-xl-high-fidel",
"checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0",
]

# Iterate over each directory
Expand All @@ -40,7 +41,9 @@
print(f"[!] Step 1: Constructing tar file name as '{tar_file_name}'")

# Construct the full path to the tar file
full_tar_path = os.path.join(".", tar_file_name)
full_tar_path = os.path.join(
"..", tar_file_name
) # Adjusted to account for script's new location
print(f"[!] Step 2: The full path for the tar file is '{full_tar_path}'")

# Remove 'checkpoints/' from tar_file_name for gcloud destination
Expand All @@ -51,8 +54,8 @@
f"[!] Step 3: The destination path on GCloud is set to '{gcloud_destination}'"
)

# Construct the shell command string
cmd = f"cd ./{d} && tar -cvf ../../{tar_file_name} * && gcloud storage cp ../../{tar_file_name} {gcloud_destination}"
# Adjust the shell command string to account for the script's new location
cmd = f"cd ../{d} && tar -cvf ../../{tar_file_name} * && gcloud storage cp ../../{tar_file_name} {gcloud_destination}"
print(
f"[!] Step 4: The shell command constructed to perform the operations is: {cmd}"
)
Expand Down

0 comments on commit 2028b85

Please sign in to comment.