Skip to content

Commit

Permalink
Skips tests using TF2-trained model.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 399579153
  • Loading branch information
jiyongjung authored and tfx-copybara committed Sep 29, 2021
1 parent 59ea5be commit 0c06171
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tfx/components/infra_validator/request_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ def _GetKerasModelSignature(self):
return request_builder._parse_saved_model_signatures(
model_path, tag_set={'serve'}, signature_names=['serving_default'])

@unittest.skipIf(
tf.__version__ < '2',
'The test uses testdata only compatible with TF2.')
def testBuildRequests_EstimatorModel_ServingDefault(self):
builder = request_builder._TFServingRpcRequestBuilder(
model_name='foo',
Expand All @@ -284,6 +287,9 @@ def testBuildRequests_EstimatorModel_ServingDefault(self):
self.assertEqual(result[0].model_spec.name, 'foo')
self.assertEqual(result[0].model_spec.signature_name, 'serving_default')

@unittest.skipIf(
tf.__version__ < '2',
'The test uses testdata only compatible with TF2.')
def testBuildRequests_EstimatorModel_Classification(self):
builder = request_builder._TFServingRpcRequestBuilder(
model_name='foo',
Expand All @@ -298,6 +304,9 @@ def testBuildRequests_EstimatorModel_Classification(self):
self.assertEqual(result[0].model_spec.name, 'foo')
self.assertEqual(result[0].model_spec.signature_name, 'classification')

@unittest.skipIf(
tf.__version__ < '2',
'The test uses testdata only compatible with TF2.')
def testBuildRequests_EstimatorModel_Regression(self):
builder = request_builder._TFServingRpcRequestBuilder(
model_name='foo',
Expand All @@ -312,6 +321,9 @@ def testBuildRequests_EstimatorModel_Regression(self):
self.assertEqual(result[0].model_spec.name, 'foo')
self.assertEqual(result[0].model_spec.signature_name, 'regression')

@unittest.skipIf(
tf.__version__ < '2',
'The test uses testdata only compatible with TF2.')
def testBuildRequests_EstimatorModel_Predict(self):
builder = request_builder._TFServingRpcRequestBuilder(
model_name='foo',
Expand Down

0 comments on commit 0c06171

Please sign in to comment.