Skip to content

Commit

Permalink
fix infer
Browse files Browse the repository at this point in the history
  • Loading branch information
autumn-2-net committed Sep 20, 2023
1 parent 16bd307 commit f87d065
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 14 deletions.
7 changes: 3 additions & 4 deletions inference/me_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,10 @@ def preprocess(self, waveform: np.ndarray) -> Dict[str, torch.Tensor]:

@torch.no_grad()
def forward_model(self, sample: Dict[str, torch.Tensor]):
sig=False
if self.config['use_BCEWithLogitsLoss']:
sig=True

probs, bounds = self.model(x=sample['units'], f0=sample['pitch'], mask=sample['masks'],sig=sig)


probs, bounds = self.model(x=sample['units'], f0=sample['pitch'], mask=sample['masks'],sig=True)

return {
'probs': probs,
Expand Down
4 changes: 2 additions & 2 deletions modules/conform/Gconform.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,6 @@ def forward(self, x, pitch, mask=None):
midiout = self.outln(x)
cutprp = torch.sigmoid(cutprp)
cutprp = torch.squeeze(cutprp, -1)
if self.sig:
midiout = torch.sigmoid(midiout)
# if self.sig:
# midiout = torch.sigmoid(midiout)
return midiout, cutprp
14 changes: 6 additions & 8 deletions training/me_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,8 @@ def build_model(self):
return model

def build_losses_and_metrics(self):
if not self.config['use_BCEWithLogitsLoss']:
self.midi_loss = nn.BCELoss()
else:
self.midi_loss = nn.BCEWithLogitsLoss()

self.midi_loss = nn.BCEWithLogitsLoss()
self.bound_loss = modules.losses.BinaryEMDLoss()
# self.bound_loss = modules.losses.BinaryEMDLoss(bidirectional=True)
self.register_metric('midi_acc', modules.metrics.MIDIAccuracy(tolerance=0.5))
Expand All @@ -90,16 +88,16 @@ def run_model(self, sample, infer=False):
# mask=None

f0 = sample['pitch']
sig=False
if self.config['use_BCEWithLogitsLoss'] and infer:
sig=True

probs, bounds = self.model(x=spec, f0=f0, mask=mask,sig=sig)



if infer:
probs, bounds = self.model(x=spec, f0=f0, mask=mask, sig=True)
return probs, bounds
else:
losses = {}
probs, bounds = self.model(x=spec, f0=f0, mask=mask, sig=False)

if self.cfg['use_bound_loss']:
bound_loss = self.bound_loss(bounds, sample['bounds'])
Expand Down

0 comments on commit f87d065

Please sign in to comment.