Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
bengouawu committed Oct 10, 2016
1 parent be162ba commit f74c0b5
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@ result
data
*.pyc
log.txt
*.json
*.params
.idea
4 changes: 2 additions & 2 deletions center_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def update(self, labels, preds):


# define some metric of center_loss
class CenterLoss(mx.metric.EvalMetric):
class CenterLossMetric(mx.metric.EvalMetric):
def __init__(self):
super(CenterLoss, self).__init__('center_loss')
super(CenterLossMetric, self).__init__('center_loss')

def update(self, labels, preds):
self.sum_metric = + preds[1].asnumpy()[0]
Expand Down
2 changes: 1 addition & 1 deletion data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# code to automatically download dataset
mxnet_root = ''
sys.path.append(os.path.join( mxnet_root, '/tests/python/common'))
sys.path.append(os.path.join( mxnet_root, 'tests/python/common'))
import get_data
import mxnet as mx

Expand Down
4 changes: 2 additions & 2 deletions train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def fit(args, network, data_loader, data_shape, batch_end_callback=None, pattern

# custom metric
eval_metrics = mx.metric.CompositeEvalMetric()
eval_metrics.add(Accuracy)
eval_metrics.add(CenterLoss())
eval_metrics.add(Accuracy())
eval_metrics.add(CenterLossMetric())

model.fit(
X = train,
Expand Down

0 comments on commit f74c0b5

Please sign in to comment.