
Commit b380552: fix globals
xibosun committed Nov 7, 2024
1 parent e5f1a77 commit b380552
Showing 2 changed files with 64 additions and 35 deletions.
demos/cli.py: 36 changes (17 additions, 19 deletions)
@@ -25,26 +25,14 @@
 cpu_offload = False
 dtype = None
 
-def configure_model(model_dir_path_, cpu_offload_,
-        dtype_, use_xdit_, ulysses_degree_, ring_degree_, use_fsdp_, t5_model_path_, max_t5_token_length_):
+def configure_model(model_dir_path_, cpu_offload_, dtype_):
     global model_dir_path, cpu_offload, dtype
     model_dir_path = model_dir_path_
     cpu_offload = cpu_offload_
     dtype = dtype_
-    use_xdit = use_xdit_
-    ulysses_degree = ulysses_degree_
-    ring_degree = ring_degree_
-    use_fsdp = use_fsdp_
-    t5_model_path = t5_model_path_
-    max_t5_token_length = max_t5_token_length_
-
-    set_use_xdit(use_xdit)
-    set_usp_config(ulysses_degree, ring_degree)
-    set_use_fsdp(use_fsdp)
-    set_t5_model(t5_model_path)
-    set_max_t5_token_length(max_t5_token_length)
 
-def load_model():
+
+def load_model(use_fsdp, t5_model_path, max_t5_token_length,
+               use_xdit, ulysses_degree, ring_degree):
     global num_gpus, pipeline, model_dir_path
     if pipeline is None:
         MOCHI_DIR = model_dir_path
@@ -67,8 +55,14 @@ def load_model():
         if num_gpus > 1:
             assert not cpu_offload, "CPU offload not supported in multi-GPU mode"
             kwargs["world_size"] = num_gpus
+            kwargs["use_xdit"] = use_xdit
+            kwargs["ulysses_degree"] = ulysses_degree
+            kwargs["ring_degree"] = ring_degree
         else:
             kwargs["cpu_offload"] = cpu_offload
             kwargs["decode_type"] = "tiled_full"
+        kwargs["use_fsdp"] = use_fsdp
+        kwargs["t5_model_path"] = t5_model_path
+        kwargs["max_t5_token_length"] = max_t5_token_length
         pipeline = klass(**kwargs)

@@ -82,8 +76,11 @@ def generate_video(
     seed,
     cfg_scale,
     num_inference_steps,
+    use_fsdp, t5_model_path, max_t5_token_length,
+    use_xdit, ulysses_degree, ring_degree,
 ):
-    load_model()
+    load_model(use_fsdp, t5_model_path, max_t5_token_length,
+               use_xdit, ulysses_degree, ring_degree)
 
     # sigma_schedule should be a list of floats of length (num_inference_steps + 1),
     # such that sigma_schedule[0] == 1.0 and sigma_schedule[-1] == 0.0 and monotonically decreasing.
@@ -161,8 +158,7 @@ def generate_cli(
     prompt, negative_prompt, width, height, num_frames, seed,
     cfg_scale, num_steps, model_dir, cpu_offload, use_xdit, ulysses_degree, ring_degree, use_fsdp, t5_model_path, max_t5_token_length
 ):
-    configure_model(model_dir, cpu_offload, torch.bfloat16, use_xdit, ulysses_degree,
-                    ring_degree, use_fsdp, t5_model_path, max_t5_token_length)
+    configure_model(model_dir, cpu_offload, torch.bfloat16)
     output = generate_video(
         prompt,
         negative_prompt,
@@ -172,6 +168,8 @@ def generate_cli(
         seed,
         cfg_scale,
         num_steps,
+        use_fsdp, t5_model_path, max_t5_token_length,
+        use_xdit, ulysses_degree, ring_degree,
     )
     click.echo(f"Video generated at: {output}")
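
Note on the rewrite above: the old configure_model assigned the xDiT/FSDP/T5 options to module-level names and called the set_* helpers in the driver process. Besides the missing global declarations (the assignments only created locals), per-process globals set in the driver never reach worker processes, which is presumably why the commit now threads the options through load_model into the pipeline constructors, where they are applied in whichever process actually uses them. A minimal, self-contained sketch of that pitfall, using multiprocessing in place of Ray and hypothetical names:

    import multiprocessing as mp

    FLAG = {"use_xdit": False}  # stands in for a module-level global

    def set_use_xdit(value):
        FLAG["use_xdit"] = value

    def child():
        # A spawned child re-imports this module, so it sees the default,
        # not the value the parent set after startup.
        print("child sees:", FLAG["use_xdit"])

    if __name__ == "__main__":
        set_use_xdit(True)  # mutates only the parent's copy
        p = mp.get_context("spawn").Process(target=child)
        p.start()
        p.join()  # prints: child sees: False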

src/genmo/mochi_preview/pipelines.py: 63 changes (47 additions, 16 deletions)
@@ -44,6 +44,8 @@
 from genmo.mochi_preview.dit.joint_model import is_use_xdit
 from genmo.mochi_preview.dit.joint_model import get_usp_config
 from genmo.mochi_preview.dit.joint_model.globals import T5_MODEL, MAX_T5_TOKEN_LENGTH, is_use_fsdp
+from genmo.mochi_preview.dit.joint_model.globals import set_t5_model, set_max_t5_token_length, set_use_fsdp, set_use_xdit, set_usp_config
 
 
 def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
     if linear_steps is None:
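
The new import pulls in the setter side of the globals module alongside the existing getters. That module is not part of this diff; a hypothetical minimal shape consistent with the names imported above (default values are assumptions):

    # hypothetical globals.py sketch: per-process configuration behind
    # setter/getter pairs
    T5_MODEL = "google/t5-v1_1-xxl"  # assumed default
    MAX_T5_TOKEN_LENGTH = 256        # assumed default
    _use_fsdp = False
    _use_xdit = False
    _ulysses_degree = None
    _ring_degree = None

    def set_t5_model(path):
        global T5_MODEL
        T5_MODEL = path

    def set_max_t5_token_length(length):
        global MAX_T5_TOKEN_LENGTH
        MAX_T5_TOKEN_LENGTH = length

    def set_use_fsdp(flag):
        global _use_fsdp
        _use_fsdp = flag

    def is_use_fsdp():
        return _use_fsdp

    def set_use_xdit(flag):
        global _use_xdit
        _use_xdit = flag

    def set_usp_config(ulysses_degree, ring_degree):
        global _ulysses_degree, _ring_degree
        _ulysses_degree, _ring_degree = ulysses_degree, ring_degree

    def get_usp_config():
        return _ulysses_degree, _ring_degree
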
@@ -412,7 +414,14 @@ def __init__(
         cpu_offload: Optional[bool] = False,
         decode_type: str = "full",
         decode_args: Optional[Dict[str, Any]] = None,
+        use_fsdp,
+        t5_model_path,
+        max_t5_token_length,
     ):
+        set_use_fsdp(use_fsdp)
+        set_t5_model(t5_model_path)
+        set_max_t5_token_length(max_t5_token_length)
+
         self.device = torch.device("cuda:0")
         self.tokenizer = t5_tokenizer()
         t = Timer()
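
Two things worth noting here. First, the set_* calls sit at the very top of __init__, before t5_tokenizer() and the factory loads run, presumably because those loads read T5_MODEL and MAX_T5_TOKEN_LENGTH. Second, the constructor now owns its configuration, so a caller just passes values. A hedged usage sketch (class and factory names follow the upstream file; the factory objects are elided):

    # hypothetical call site; factory objects elided
    pipeline = MochiSingleGPUPipeline(
        text_encoder_factory=text_encoder_factory,
        dit_factory=dit_factory,
        decoder_factory=decoder_factory,
        cpu_offload=True,
        decode_type="tiled_full",
        use_fsdp=False,                      # FSDP off for one GPU
        t5_model_path="google/t5-v1_1-xxl",  # assumed default
        max_t5_token_length=256,             # assumed default
    )
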
@@ -486,13 +495,25 @@ def __init__(
         world_size,
         decode_type,
         decode_args,
+        use_fsdp,
+        t5_model_path,
+        max_t5_token_length,
+        use_xdit,
+        ulysses_degree,
+        ring_degree,
     ):
+        set_use_fsdp(use_fsdp)
+        set_t5_model(t5_model_path)
+        set_max_t5_token_length(max_t5_token_length)
+        set_use_xdit(use_xdit)
+        set_usp_config(ulysses_degree, ring_degree)
+
         t = Timer()
         self.device = torch.device(f"cuda:{device_id}")
         print(f"Initializing rank {local_rank+1}/{world_size}")
         assert world_size > 1, f"Multi-GPU mode requires world_size > 1, got {world_size}"
         os.environ["MASTER_ADDR"] = "127.0.0.1"
-        os.environ["MASTER_PORT"] = "29500"
+        os.environ["MASTER_PORT"] = "29501"
         with t("init_process_group"):
             dist.init_process_group(
                 "nccl",
@@ -504,18 +525,10 @@ def __init__(
         cp.set_cp_group(pg, list(range(world_size)), local_rank)
         distributed_kwargs = dict(local_rank=local_rank, device_id=device_id, world_size=world_size)
         self.world_size = world_size
-        self.tokenizer = t5_tokenizer()
-        with t("load_text_encoder"):
-            self.text_encoder = text_encoder_factory.get_model(**distributed_kwargs)
-        with t("load_dit"):
-            self.dit = dit_factory.get_model(**distributed_kwargs)
-        with t("load_vae"):
-            self.decoder = decoder_factory.get_model(**distributed_kwargs)
         self.local_rank = local_rank
         self.decode_type = decode_type
         self.decode_args = decode_args or {}
-        t.print_stats()
 
         # TODO(jiaruifang) confuse local_rank and rank, not applied to multi-node
         if is_use_xdit():
             cp_rank, cp_size = cp.get_cp_rank_size()
@@ -525,6 +538,7 @@ def __init__(
             )
 
             ulysses_degree, ring_degree = get_usp_config()
+            init_distributed_environment(rank=cp_rank, world_size=cp_size)
             if ulysses_degree is None and ring_degree is None:
                 print(f"No usp config, use default config: ulysses_degree={cp_size}, ring_degree=1")
                 initialize_model_parallel(
@@ -543,14 +557,19 @@ def __init__(
                     ring_degree=ring_degree,
                     ulysses_degree=ulysses_degree,
                 )
-            init_distributed_environment(rank=cp_rank, world_size=cp_size)
-
             initialize_model_parallel(
                 sequence_parallel_degree=ulysses_degree,
                 ring_degree=ring_degree,
                 ulysses_degree=cp_size,
             )
             print(f"initialized model parallel with sequence_parallel_degree={cp_size}, ring_degree=1, ulysses_degree={cp_size}")
 
+        self.tokenizer = t5_tokenizer()
+        with t("load_text_encoder"):
+            self.text_encoder = text_encoder_factory.get_model(**distributed_kwargs)
+        with t("load_dit"):
+            self.dit = dit_factory.get_model(**distributed_kwargs)
+        with t("load_vae"):
+            self.decoder = decoder_factory.get_model(**distributed_kwargs)
+
+        t.print_stats()
 
     def run(self, *, fn, **kwargs):
         return fn(self, **kwargs)
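
The set_* calls now run inside MultiGPUContext.__init__, which matters because each context is instantiated as a Ray actor (see ray.remote(MultiGPUContext) below): every actor lives in its own worker process with its own copy of the module globals, so configuring them in the driver would have no effect there. A minimal, self-contained sketch of the pattern, with hypothetical names standing in for the real globals module:

    import ray

    _settings = {"use_fsdp": False}  # stands in for the real globals module

    def set_use_fsdp(flag):
        _settings["use_fsdp"] = flag

    @ray.remote
    class WorkerContext:
        def __init__(self, use_fsdp):
            # runs in the actor's own process, so this updates the
            # per-process copy of the settings
            set_use_fsdp(use_fsdp)

        def check(self):
            return _settings["use_fsdp"]

    ray.init()
    workers = [WorkerContext.remote(use_fsdp=True) for _ in range(2)]
    print(ray.get([w.check.remote() for w in workers]))  # [True, True]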

@@ -565,6 +584,12 @@ def __init__(
         world_size: int,
         decode_type: str = "full",
         decode_args: Optional[Dict[str, Any]] = None,
+        use_fsdp,
+        t5_model_path,
+        max_t5_token_length,
+        use_xdit,
+        ulysses_degree,
+        ring_degree,
     ):
         ray.init()
         RemoteClass = ray.remote(MultiGPUContext)
@@ -578,6 +603,12 @@ def __init__(
                 local_rank=i,
                 decode_type=decode_type,
                 decode_args=decode_args,
+                use_fsdp=use_fsdp,
+                t5_model_path=t5_model_path,
+                max_t5_token_length=max_t5_token_length,
+                use_xdit=use_xdit,
+                ulysses_degree=ulysses_degree,
+                ring_degree=ring_degree,
             )
             for i in range(world_size)
         ]
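
The launcher forwards the complete configuration to every actor rather than configuring anything in the driver, so all ranks apply identical settings before init_process_group and the xDiT setup run. In particular, ulysses_degree and ring_degree must agree across ranks, and, judging by the default printed above (ulysses_degree=cp_size, ring_degree=1), their product is expected to match the context-parallel world size.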
