Commit

feat: improve log printing
fumiama committed Sep 3, 2023
1 parent 5a6f824 commit 391dea8
Showing 3 changed files with 23 additions and 10 deletions.
9 changes: 5 additions & 4 deletions attentions.py
@@ -1,13 +1,14 @@
import copy
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

import commons
import modules
from torch.nn.utils import weight_norm, remove_weight_norm
import logging

logger = logging.getLogger(__name__)

class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
@@ -55,7 +56,7 @@ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_s
                self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
                # vits2 says 3rd block, so idx is 2 by default
                self.cond_layer_idx = kwargs['cond_layer_idx'] if 'cond_layer_idx' in kwargs else 2
                print(self.gin_channels, self.cond_layer_idx)
                logger.debug("%s %s", self.gin_channels, self.cond_layer_idx)
                assert self.cond_layer_idx < self.n_layers, 'cond_layer_idx should be less than n_layers'
        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
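A note on the new debug call: logger.debug, like every stdlib logging method, treats its first argument as a %-style format string and fills in the remaining arguments lazily, only when the record is actually emitted. A minimal, self-contained sketch of that pattern, using placeholder values in place of self.gin_channels and self.cond_layer_idx:

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Placeholder values standing in for self.gin_channels / self.cond_layer_idx.
gin_channels, cond_layer_idx = 256, 2

# The %s placeholders are interpolated only if DEBUG is enabled, so no string
# formatting work is done when the logger runs at a higher level.
logger.debug("gin_channels=%s cond_layer_idx=%s", gin_channels, cond_layer_idx)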
6 changes: 2 additions & 4 deletions utils.py
@@ -11,8 +11,7 @@

MATPLOTLIB_FLAG = False

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging
logger = logging.getLogger(__name__)


def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
@@ -42,13 +41,12 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
            new_state_dict[k] = saved_state_dict[k]
            assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
        except:
            print("error, %s is not in the checkpoint" % k)
            logger.error("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    if hasattr(model, 'module'):
        model.module.load_state_dict(new_state_dict, strict=False)
    else:
        model.load_state_dict(new_state_dict, strict=False)
    print("load ")
    logger.info("Loaded checkpoint '{}' (iteration {})".format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
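The utils.py change drops the import-time logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) call, which reconfigured the root logger for every program that imports the module, and replaces the logger = logging alias with a proper module-level named logger; handler and level setup is left to the entry point. A small runnable sketch of that split of responsibilities, with a hypothetical logger name and function used purely for illustration:

import logging

# Library-style code: obtain a named logger, never call basicConfig at import time.
logger = logging.getLogger("utils_demo")  # hypothetical name for illustration

def load_something(path):
    # Messages propagate to whatever handlers the application has configured.
    logger.info("Loaded checkpoint '%s'", path)

if __name__ == "__main__":
    # Only the entry point decides the format and level.
    logging.basicConfig(level=logging.INFO,
                        format="| %(name)s | %(levelname)s | %(message)s")
    load_something("./logs/as/G_8000.pth")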
18 changes: 16 additions & 2 deletions webui.py
@@ -3,6 +3,17 @@
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import logging

logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("markdown_it").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)

logging.basicConfig(level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s")

logger = logging.getLogger(__name__)

import torch
import argparse
import commons
@@ -67,8 +78,12 @@ def tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale):
parser.add_argument("-m", "--model", default="./logs/as/G_8000.pth", help="path of your model")
parser.add_argument("-c", "--config", default="./configs/config.json", help="path of your config file")
parser.add_argument("--share", default=False, help="make link public")
parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")

args = parser.parse_args()
if args.debug:
logger.info("Enable DEBUG-LEVEL log")
    logging.getLogger().setLevel(logging.DEBUG)
hps = utils.get_hparams_from_file(args.config)

device = (
@@ -92,8 +107,7 @@ def tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale):

speaker_ids = hps.data.spk2id
speakers = list(speaker_ids.keys())
app = gr.Blocks()
with app:
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            text = gr.TextArea(label="Text", placeholder="Input Text Here",
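The webui.py entry point now quiets chatty third-party loggers, installs one root format at INFO, and exposes a -d/--debug switch. A condensed, runnable sketch of that setup using only the standard library; note that a second logging.basicConfig() call is ignored once handlers exist, so the sketch raises the root level directly when --debug is passed:

import argparse
import logging

# Quiet noisy third-party loggers before configuring the root logger.
for name in ("numba", "markdown_it", "urllib3", "matplotlib"):
    logging.getLogger(name).setLevel(logging.WARNING)

logging.basicConfig(level=logging.INFO,
                    format="| %(name)s | %(levelname)s | %(message)s")
logger = logging.getLogger(__name__)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-level log")
    args = parser.parse_args()
    if args.debug:
        # basicConfig would be a no-op here because handlers already exist,
        # so raise the root logger's level explicitly.
        logging.getLogger().setLevel(logging.DEBUG)
        logger.debug("DEBUG-level log enabled")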
