Ability to split model components to different backend devices #461

Open · wants to merge 6 commits into base: master
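This change adds three optional backend-device indices so that the diffusion model, the CLIP text encoder, and the VAE can each be placed on a different device instead of all sharing one backend. On the CLI they are exposed as `--model-backend-index`, `--clip-backend-index`, and `--vae-backend-index`; the default of -1 keeps the existing single-device behavior. As a purely illustrative invocation (device numbers depend on the system), `--model-backend-index 0 --vae-backend-index 1` keeps the main model on device 0 and runs VAE decoding on a second device.
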
24 changes: 22 additions & 2 deletions examples/cli/main.cpp
@@ -119,6 +119,10 @@ struct SDParams {
bool canny_preprocess = false;
bool color = false;
int upscale_repeats = 1;

int model_backend_index = -1;
int clip_backend_index = -1;
int vae_backend_index = -1;
};

void print_params(SDParams params) {
@@ -164,6 +168,9 @@ void print_params(SDParams params) {
printf(" batch_count: %d\n", params.batch_count);
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
printf(" upscale_repeats: %d\n", params.upscale_repeats);
printf(" model_backend_index %d\n", params.model_backend_index);
printf(" clip_backend_index %d\n", params.clip_backend_index);
printf(" vae_backend_index %d\n", params.vae_backend_index);
}

void print_usage(int argc, const char* argv[]) {
@@ -219,6 +226,9 @@ void print_usage(int argc, const char* argv[]) {
printf(" --canny apply canny preprocessor (edge detection)\n");
printf(" --color Colors the logging tags according to level\n");
printf(" -v, --verbose print extra info\n");
printf(" --model-backend-index specify which device the model defaults to using\n");
printf(" --clip-backend-index specify which device the CLIP model uses\n");
printf(" --vae-backend-index specify which device the VAE model uses\n");
}

void parse_args(int argc, const char** argv, SDParams& params) {
@@ -534,7 +544,14 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.verbose = true;
} else if (arg == "--color") {
params.color = true;
} else {
} else if (arg == "--model-backend-index") {
params.model_backend_index = std::stoi(argv[++i]);
} else if (arg == "--clip-backend-index") {
params.clip_backend_index = std::stoi(argv[++i]);
} else if (arg == "--vae-backend-index") {
params.vae_backend_index = std::stoi(argv[++i]);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
print_usage(argc, argv);
exit(1);
@@ -791,7 +808,10 @@ int main(int argc, const char* argv[]) {
params.schedule,
params.clip_on_cpu,
params.control_net_cpu,
params.vae_on_cpu);
params.vae_on_cpu,
params.model_backend_index,
params.clip_backend_index,
params.vae_backend_index);

if (sd_ctx == NULL) {
printf("new_sd_ctx_t failed\n");
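The new options are read with `std::stoi(argv[++i])`, which walks past the end of `argv` when a flag is the last token and throws on non-numeric input. Below is a minimal, free-standing sketch of a bounds-checked alternative; the helper name `parse_int_arg` and its error handling are hypothetical and not part of this PR.

```cpp
#include <cstdio>
#include <cstdlib>

// Hypothetical helper: read the integer value that follows an option,
// with bounds and format checking instead of a bare std::stoi.
static bool parse_int_arg(int argc, const char** argv, int& i, int& out) {
    if (++i >= argc) {
        fprintf(stderr, "error: missing value for %s\n", argv[i - 1]);
        return false;
    }
    char* end = nullptr;
    long  val = strtol(argv[i], &end, 10);
    if (end == argv[i] || *end != '\0') {
        fprintf(stderr, "error: invalid integer '%s' for %s\n", argv[i], argv[i - 1]);
        return false;
    }
    out = (int)val;
    return true;
}

// Possible use inside the option loop (sketch):
//   } else if (arg == "--model-backend-index") {
//       if (!parse_int_arg(argc, argv, i, params.model_backend_index)) exit(1);
//   }
```
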
93 changes: 78 additions & 15 deletions stable-diffusion.cpp
@@ -129,10 +129,10 @@ class StableDiffusionGGML {
if (clip_backend != backend) {
ggml_backend_free(clip_backend);
}
if (control_net_backend != backend) {
if (control_net_backend != backend && control_net_backend != clip_backend) {
ggml_backend_free(control_net_backend);
}
if (vae_backend != backend) {
if (vae_backend != backend && vae_backend != clip_backend && vae_backend != control_net_backend) {
ggml_backend_free(vae_backend);
}
ggml_backend_free(backend);
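The destructor now compares each backend against every previously freed one to avoid double frees. As a possible simplification (a sketch that assumes the member names above and is not part of this PR), the distinct backends could be collected into a set and each freed exactly once:

```cpp
#include <set>

// Sketch of a member function: free each distinct backend exactly once, so the
// pairwise checks do not have to grow as more components become splittable.
void free_backends() {
    std::set<ggml_backend_t> unique = {backend, clip_backend, control_net_backend, vae_backend};
    unique.erase(nullptr);        // skip components that never got a backend
    for (ggml_backend_t b : unique) {
        ggml_backend_free(b);     // a backend shared by several components appears once
    }
}
```
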
@@ -153,11 +153,15 @@ class StableDiffusionGGML {
schedule_t schedule,
bool clip_on_cpu,
bool control_net_cpu,
bool vae_on_cpu) {
bool vae_on_cpu,
int model_backend_index,
int clip_backend_index,
int vae_backend_index) {
use_tiny_autoencoder = taesd_path.size() > 0;
#ifdef SD_USE_CUBLAS
LOG_DEBUG("Using CUDA backend");
backend = ggml_backend_cuda_init(0);
if (model_backend_index == -1) model_backend_index = 0;
backend = ggml_backend_cuda_init(model_backend_index);
#endif
#ifdef SD_USE_METAL
LOG_DEBUG("Using Metal backend");
@@ -166,16 +170,21 @@
#endif
#ifdef SD_USE_VULKAN
LOG_DEBUG("Using Vulkan backend");
for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
backend = ggml_backend_vk_init(device);
}
if (!backend) {
LOG_WARN("Failed to initialize Vulkan backend");
}
if (model_backend_index == -1) {
// default behavior: use the last enumerated device, matching the previous loop
int device = ggml_backend_vk_get_device_count() - 1;
backend = ggml_backend_vk_init(device);
if (!backend) {
LOG_WARN("Failed to initialize Vulkan backend");
}
} else {
backend = ggml_backend_vk_init(model_backend_index);
}
#endif
#ifdef SD_USE_SYCL
LOG_DEBUG("Using SYCL backend");
backend = ggml_backend_sycl_init(0);
if (model_backend_index == -1) model_backend_index = 0;
backend = ggml_backend_sycl_init(model_backend_index);
#endif

if (!backend) {
@@ -321,7 +330,29 @@ class StableDiffusionGGML {
if (clip_on_cpu && !ggml_backend_is_cpu(backend)) {
LOG_INFO("CLIP: Using CPU backend");
clip_backend = ggml_backend_cpu_init();
}
} else if (clip_backend_index > -1 && clip_backend_index != model_backend_index) {
#ifdef SD_USE_CUBLAS
LOG_DEBUG("CLIP: Using CUDA backend");
clip_backend = ggml_backend_cuda_init(clip_backend_index);
#endif
#ifdef SD_USE_VULKAN
LOG_DEBUG("CLIP: Using Vulkan backend");
clip_backend = ggml_backend_vk_init(clip_backend_index);
#endif
#ifdef SD_USE_METAL
LOG_DEBUG("CLIP: Using Metal backend");
// Metal exposes a single device, so reuse the main backend
clip_backend = backend;
#endif
#ifdef SD_USE_SYCL
LOG_DEBUG("CLIP: Using SYCL backend");
clip_backend = ggml_backend_sycl_init(clip_backend_index);
#endif
if (!clip_backend) {
LOG_WARN("No backend device found for CLIP, defaulting to model device.");
clip_backend = backend;
}
}
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
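Neither the model nor the CLIP path validates the requested index against the number of available devices; an out-of-range value is handed straight to ggml and the code relies on the init call returning NULL to reach the fallback warnings. A small sketch of an up-front check for the Vulkan case, reusing the `ggml_backend_vk_get_device_count()` and `ggml_backend_vk_init()` calls that already appear above (the helper name `init_vk_checked` is illustrative only):

```cpp
#ifdef SD_USE_VULKAN
// Sketch: clamp an out-of-range Vulkan device index to device 0 before init.
static ggml_backend_t init_vk_checked(int index) {
    int count = ggml_backend_vk_get_device_count();
    if (index < 0 || index >= count) {
        LOG_WARN("requested Vulkan device %d is out of range (%d device(s) available), using device 0", index, count);
        index = 0;
    }
    return ggml_backend_vk_init(index);
}
#endif
```
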
@@ -342,7 +373,33 @@
if (vae_on_cpu && !ggml_backend_is_cpu(backend)) {
LOG_INFO("VAE Autoencoder: Using CPU backend");
vae_backend = ggml_backend_cpu_init();
} else {
} else if (vae_backend_index == clip_backend_index) {
vae_backend = clip_backend;
} else if (vae_backend_index == model_backend_index) {
vae_backend = backend;
} else if (vae_backend_index > -1) {
#ifdef SD_USE_CUBLAS
LOG_DEBUG("VAE Autoencoder: Using CUDA backend");
vae_backend = ggml_backend_cuda_init(vae_backend_index);
#endif
#ifdef SD_USE_VULKAN
LOG_DEBUG("VAE Autoencoder: Using Vulkan backend");
vae_backend = ggml_backend_vk_init(vae_backend_index);
#endif
#ifdef SD_USE_METAL
LOG_DEBUG("CLIP: Using Metal backend");
// should be the same
vae_backend = backend;
#endif
#ifdef SD_USE_SYCL
LOG_DEBUG("VAE Autoencoder: Using SYCL backend");
vae_backend = ggml_backend_sycl_init(vae_backend_index);
#endif
if (!vae_backend) {
LOG_WARN("No backend device found for VAE, defaulting to model device.");
vae_backend = backend;
}
} else {
vae_backend = backend;
}
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend, vae_wtype, vae_decode_only, false, version);
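Taken together with the CLIP branch above, an invocation such as `--model-backend-index 0 --clip-backend-index 1 --vae-backend-index 1` (indices illustrative) initializes the diffusion model on device 0, gives CLIP its own backend on device 1, and lets the VAE take the `vae_backend_index == clip_backend_index` branch so it reuses the CLIP backend instead of initializing device 1 a second time; the destructor changes at the top of the file then ensure that shared backend is freed only once.
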
@@ -1035,7 +1092,10 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
enum schedule_t s,
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
bool keep_vae_on_cpu) {
bool keep_vae_on_cpu,
int model_backend_index,
int clip_backend_index,
int vae_backend_index) {
sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
if (sd_ctx == NULL) {
return NULL;
@@ -1076,7 +1136,10 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
s,
keep_clip_on_cpu,
keep_control_net_cpu,
keep_vae_on_cpu)) {
keep_vae_on_cpu,
model_backend_index,
clip_backend_index,
vae_backend_index)) {
delete sd_ctx->sd;
sd_ctx->sd = NULL;
free(sd_ctx);
5 changes: 4 additions & 1 deletion stable-diffusion.h
@@ -142,7 +142,10 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
enum schedule_t s,
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
bool keep_vae_on_cpu);
bool keep_vae_on_cpu,
int model_backend_index,
int clip_backend_index,
int vae_backend_index);

SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);

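For callers of the C API, the three indices are appended after `keep_vae_on_cpu`. The fragment below only illustrates the new tail of the argument list; the earlier `new_sd_ctx` parameters are unchanged and are elided rather than spelled out, so this is not a complete, compilable call.

```cpp
// Fragment (not a complete call): only the new trailing arguments are shown.
sd_ctx_t* sd_ctx = new_sd_ctx(/* ... unchanged leading arguments ..., */
                              /* keep_clip_on_cpu     */ false,
                              /* keep_control_net_cpu */ false,
                              /* keep_vae_on_cpu      */ false,
                              /* model_backend_index  */ 0,   // main model on device 0
                              /* clip_backend_index   */ 1,   // CLIP on device 1
                              /* vae_backend_index    */ 1);  // VAE shares the CLIP backend
if (sd_ctx == NULL) {
    printf("new_sd_ctx_t failed\n");
}
```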