Skip to content

Commit

Permalink
Move NCF estimator to R1.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 303897691
  • Loading branch information
saberkun authored and tensorflower-gardener committed Mar 31, 2020
1 parent ad34b62 commit 6d7030f
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 19 deletions.
1 change: 1 addition & 0 deletions official/r1/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ in the previous releases.
| ----- | ----------- | --------- |
| [Gradient Boosted Trees](boosted_trees) | A gradient boosted trees model to classify higgs boson process from HIGGS dataset | [Link](https://en.wikipedia.org/wiki/Gradient_boosting) |
| [MNIST](mnist) | A basic model to classify digits from the MNIST dataset | [Link](http://yann.lecun.com/exdb/mnist/) |
| [NCF](ncf) | NCF Estimator implementation | [arXiv:1708.05031](https://arxiv.org/abs/1708.05031) |
| [ResNet](resnet) | A deep residual network for image recognition | [arXiv:1512.03385](https://arxiv.org/abs/1512.03385) |
| [Transformer](transformer) | A transformer model to translate the WMT English to German dataset | [arXiv:1706.03762](https://arxiv.org/abs/1706.03762) |
| [Wide & Deep Learning](wide_deep) | A model that combines a wide linear model and deep neural network for recommender systems | [arXiv:1606.07792](https://arxiv.org/abs/1606.07792) |
File renamed without changes.
21 changes: 2 additions & 19 deletions official/recommendation/ncf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,15 @@

import numpy as np
import tensorflow as tf

from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
from official.recommendation import constants as rconst
from official.recommendation import data_pipeline
from official.recommendation import neumf_model
from official.recommendation import ncf_common
from official.recommendation import ncf_estimator_main
from official.recommendation import ncf_keras_main
from official.recommendation import neumf_model
from official.utils.misc import keras_utils
from official.utils.testing import integration

from tensorflow.python.eager import context # pylint: disable=ungrouped-imports


NUM_TRAIN_NEG = 4

Expand Down Expand Up @@ -190,20 +187,6 @@ def test_hit_rate_and_ndcg(self):

_BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1']

@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS)

@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator_mlperf(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True'])

@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_keras_no_dist_strat(self):
integration.run_synthetic(
Expand Down

0 comments on commit 6d7030f

Please sign in to comment.