Training: remove trivial loss function
semjon00 committed Jan 26, 2024
1 parent abaa436 commit f8ea487
Showing 1 changed file with 2 additions and 10 deletions.
training.py: 12 changes (2 additions & 10 deletions)
@@ -21,7 +21,6 @@
 parser.add_argument('--learning_rate', type=float, default=2e-5)
 parser.add_argument('--save_time', type=int, default=60 * 60)
 parser.add_argument('--baby_parameters', action='store_const', const=True, default=False)
-parser.add_argument('--use_dumb_loss_function', action='store_const', const=True, default=False)
 parser.add_argument('--no_random_degradation', action='store_const', const=True, default=False)
 args = parser.parse_args()
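
A side note on the flag style in this block: argparse's action='store_const' with const=True, default=False behaves the same as the shorter action='store_true'. A minimal standalone sketch of the equivalence (the --verbose flag is illustrative, not from the repo):

    import argparse

    parser = argparse.ArgumentParser()
    # Long form used in training.py: store the constant True when the flag is present.
    parser.add_argument('--baby_parameters', action='store_const', const=True, default=False)
    # Equivalent shorthand (illustrative flag name, not from the repo):
    parser.add_argument('--verbose', action='store_true')

    args = parser.parse_args(['--baby_parameters'])
    assert args.baby_parameters is True and args.verbose is False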

@@ -318,8 +317,7 @@ def eval_model(model, eval_datasets):
     for target_sample, dataloader in eval_datasets:
         for history, fragments in iter(dataloader):
             pred = model(target_sample, history, fragments, middle_repeats=middle_repeats)
-            lf = trivial_loss_function if args.use_dumb_loss_function else loss_function
-            loss: Tensor = lf(pred.float(), fragments.float()).to(dtype=precision)
+            loss: Tensor = loss_function(pred.float(), fragments.float()).to(dtype=precision)
             if loss.isnan():
                 raise LossNaNException()
             total_loss += loss.item()
@@ -342,11 +340,6 @@ def loss_function_freq_significance(width, device):
     return loss_function_freq_significance_cache[1]
 
 
-def trivial_loss_function(pred, truth):
-    """Very stupid, but 100% bug-free loss function"""
-    return torch.mean((truth - pred) ** 2)
-
-
 def loss_function(pred, truth):
     """Custom loss function, for comparing two spectrograms. Not the best one, but it should work."""
     # TODO: this can be infinitely improved:
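
The deleted helper is plain mean-squared error, so anyone who still wants that bug-free baseline after this commit can call PyTorch's built-in directly. A minimal sketch of the equivalence (standalone, not part of the diff):

    import torch
    import torch.nn.functional as F

    pred = torch.randn(4, 128)
    truth = torch.randn(4, 128)

    # The removed trivial_loss_function computed torch.mean((truth - pred) ** 2),
    # which matches F.mse_loss with its default reduction='mean'.
    manual = torch.mean((truth - pred) ** 2)
    builtin = F.mse_loss(pred, truth)
    assert torch.allclose(manual, builtin)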
@@ -417,8 +410,7 @@ def train_on_bite(model: EchoMorph, optimizer: torch.optim.Optimizer, train_spec
     for history, fragments in iter(dataloader):
         optimizer.zero_grad()
         pred = model(target_sample, history, fragments)
-        lf = trivial_loss_function if args.use_dumb_loss_function else loss_function
-        loss: Tensor = lf(pred.float(), fragments.float()).to(dtype=precision)
+        loss: Tensor = loss_function(pred.float(), fragments.float()).to(dtype=precision)
         if loss.isnan():
             raise LossNaNException()
         loss.backward()
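
With the args.use_dumb_loss_function dispatch gone, both call sites invoke loss_function directly. If loss-function A/B testing is ever needed again, one hypothetical alternative to a global CLI flag is to pass the loss callable in as a parameter (a sketch only; train_step and its signature are not from this repo):

    from typing import Callable
    import torch
    from torch import Tensor

    def mse_loss(pred: Tensor, truth: Tensor) -> Tensor:
        # Same math as the removed trivial_loss_function.
        return torch.mean((truth - pred) ** 2)

    def train_step(pred: Tensor, truth: Tensor,
                   loss_fn: Callable[[Tensor, Tensor], Tensor] = mse_loss) -> Tensor:
        # The loss in use is explicit at the call site instead of hidden behind a flag.
        loss = loss_fn(pred.float(), truth.float())
        if loss.isnan():
            raise ValueError('loss is NaN')
        return loss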
