
Commit

cog using common load fxn
daanelson authored and hmartiro committed Jan 5, 2023
1 parent 19fea12 commit 2695af9
Showing 4 changed files with 20 additions and 35 deletions.
cog.yaml: 2 changes (1 addition & 1 deletion)
@@ -22,7 +22,7 @@ build:
     - "flask_cors==3.0.10"
     - "flask==1.1.2"
     - "numpy==1.19.4"
-    - "pillow==8.2.0"
+    - "pillow==9.1.0"
     - "pydub==0.25.1"
     - "scipy==1.6.3"
     - "torch==1.13.0"

integrations/README.md: 2 changes (1 addition & 1 deletion)
@@ -11,7 +11,7 @@ This package contains integrations of Riffusion into third party apps and deploy
 To run riffusion as a Cog model, first, [install Cog](https://github.com/replicate/cog) and
 download the model weights:
 
-cog run python -m riffusion.cog_riffusion --download_weights
+cog run python -m integrations.cog_riffusion --download_weights
 
 Then you can run predictions:
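As a next step, predictions can be invoked with Cog's predict command. A minimal sketch follows; the input name prompt_a is an assumption about the predictor's interface and is not part of this commit:

    cog predict -i prompt_a="funky synth solo"

Cog builds the image from cog.yaml, calls the predictor's setup(), and then runs predict() with the given inputs.
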
integrations/cog_riffusion.py: 43 changes (10 additions & 33 deletions)
@@ -12,7 +12,6 @@
 import PIL
 import torch
 from cog import BaseModel, BasePredictor, Input, Path
-from huggingface_hub import hf_hub_download
 
 from riffusion.datatypes import InferenceInput, PromptInput
 from riffusion.riffusion_pipeline import RiffusionPipeline
@@ -21,7 +20,6 @@
 
 MODEL_ID = "riffusion/riffusion-model-v1"
 MODEL_CACHE = "riffusion-cache"
-UNET_CACHE = "unet-cache"
 
 # Where built-in seed images are stored
 SEED_IMAGES_DIR = Path("./seed_images")
@@ -46,7 +44,7 @@ class RiffusionPredictor(BasePredictor):
     See README & https://github.com/replicate/cog for details
     """
 
-    def setup(self):
+    def setup(self, local_files_only=True):
         """
         Loads the model onto GPU from local cache.
         """
@@ -56,7 +54,8 @@ def setup(self):
             checkpoint=MODEL_ID,
             use_traced_unet=True,
             device=self.device,
-            local_files_only=True,
+            local_files_only=local_files_only,
+            cache_dir=MODEL_CACHE,
         )
 
     def predict(
@@ -137,38 +136,16 @@ def predict(
     # RiffusionPipeline.load_checkpoint?
 
 
-def download_weights(checkpoint: str):
+def download_weights():
     """
     Clears local cache & downloads riffusion weights
     """
-    for folder in [MODEL_CACHE, UNET_CACHE]:
-        if os.path.exists(folder):
-            shutil.rmtree(folder)
-        os.makedirs(folder)
-
-    model, unet_file = _load_model(checkpoint, local_only=False)
-    return model, unet_file
-
-
-def _load_model(checkpoint: str, local_only=False):
-    model = RiffusionPipeline.from_pretrained(
-        checkpoint,
-        revision="main",
-        torch_dtype=torch.float16,
-        # Disable the NSFW filter, causes incorrect false positives
-        safety_checker=lambda images, **kwargs: (images, False),
-        cache_dir=MODEL_CACHE,
-        local_files_only=local_only,
-    )
+    if os.path.exists(MODEL_CACHE):
+        shutil.rmtree(MODEL_CACHE)
+    os.makedirs(MODEL_CACHE)
 
-    unet_file = hf_hub_download(
-        "riffusion/riffusion-model-v1",
-        filename="unet_traced.pt",
-        subfolder="unet_traced",
-        cache_dir=UNET_CACHE,
-        local_files_only=local_only,
-    )
-    return model, unet_file
+    pred = RiffusionPredictor()
+    pred.setup(local_files_only=False)
 
 
 if __name__ == "__main__":
@@ -178,4 +155,4 @@ def _load_model(checkpoint: str, local_only=False):
     )
     args = parser.parse_args()
     if args.download_weights:
-        download_weights(MODEL_ID)
+        download_weights()

riffusion/riffusion_pipeline.py: 8 changes (8 additions & 0 deletions)
@@ -70,6 +70,7 @@ def load_checkpoint(
         device: str = "cuda",
         local_files_only: bool = False,
         low_cpu_mem_usage: bool = False,
+        cache_dir: T.Optional[str] = None,
     ) -> RiffusionPipeline:
         """
         Load the riffusion model pipeline.
@@ -97,6 +98,7 @@
             safety_checker=lambda images, **kwargs: (images, False),
             low_cpu_mem_usage=low_cpu_mem_usage,
             local_files_only=local_files_only,
+            cache_dir=cache_dir,
         ).to(device)
 
         if channels_last:
@@ -111,6 +113,8 @@
             in_channels=pipeline.unet.in_channels,
             dtype=dtype,
             device=device,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
         )
 
         if traced_unet is not None:
@@ -128,6 +132,8 @@ def load_traced_unet(
         in_channels: int,
         dtype: torch.dtype,
         device: str = "cuda",
+        local_files_only=False,
+        cache_dir: T.Optional[str] = None,
     ) -> T.Optional[torch.nn.Module]:
         """
         Load a traced unet from the huggingface hub. This can improve performance.
@@ -141,6 +147,8 @@
             checkpoint,
             subfolder=subfolder,
             filename=filename,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
         )
         unet_traced = torch.jit.load(unet_file)
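For context, here is a minimal sketch of how the extended load_checkpoint signature is exercised by the Cog predictor after this change; constants and keyword arguments mirror integrations/cog_riffusion.py above, and this is an illustration rather than code added by the commit:

    from riffusion.riffusion_pipeline import RiffusionPipeline

    MODEL_ID = "riffusion/riffusion-model-v1"
    MODEL_CACHE = "riffusion-cache"

    # download_weights(): populate the shared cache over the network.
    pipeline = RiffusionPipeline.load_checkpoint(
        checkpoint=MODEL_ID,
        use_traced_unet=True,
        device="cuda",
        local_files_only=False,
        cache_dir=MODEL_CACHE,
    )

    # setup() at prediction time makes the same call with
    # local_files_only=True, loading strictly from that cache.

Routing cache_dir through both from_pretrained and hf_hub_download keeps the pipeline weights and the traced UNet in a single cache, which is why the separate UNET_CACHE and the local _load_model helper could be removed.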
