Skip to content

Commit

Permalink
switch some prints to logging
Browse files — browse the repository at this point in the history
Authored and committed by gkucsko on Apr 13, 2023
1 parent 2c03817 commit 76966a8
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions bark/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@ def _load_model(ckpt_path, device, model_type="text"):
os.path.exists(ckpt_path) and
_md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
):
print(f"found outdated {model_type} model, removing...")
logger.warning(f"found outdated {model_type} model, removing...")
os.remove(ckpt_path)
if not os.path.exists(ckpt_path):
print(f"{model_type} model not found, downloading...")
logger.info(f"{model_type} model not found, downloading...")
_download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
checkpoint = torch.load(ckpt_path, map_location=device)
# this is a hack
Expand Down Expand Up @@ -215,7 +215,7 @@ def _load_model(ckpt_path, device, model_type="text"):
model.load_state_dict(state_dict, strict=False)
n_params = model.get_num_params()
val_loss = checkpoint["best_val_loss"].item()
print(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
model.eval()
model.to(device)
del checkpoint, state_dict
Expand Down Expand Up @@ -346,7 +346,7 @@ def generate_text_semantic(
device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
if len(encoded_text) > 256:
p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
print(f"warning, text too long, lopping of last {p}%")
logger.warning(f"warning, text too long, lopping of last {p}%")
encoded_text = encoded_text[:256]
encoded_text = np.pad(
encoded_text,
Expand Down

0 comments on commit 76966a8

Please sign in to comment.