Skip to content

Commit

Permalink
More tests covering auto-batching Fibonacci.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 257690405
  • Loading branch information
axch authored and tensorflower-gardener committed Jul 11, 2019
1 parent 9e3470a commit 91bfb8a
Showing 1 changed file with 32 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,50 @@

# Eensy weensy test function
def fibonacci(n):
if n > 2:
return fibonacci(n - 1) + fibonacci(n - 2)
else:
if n <= 1:
return 1
else:
left = fibonacci(n - 2)
right = fibonacci(n - 1)
return left + right


@test_util.run_all_in_graph_and_eager_modes
class AutoGraphFrontendTest(tf.test.TestCase):

def testFibonacci(self):
self.assertEqual(1, fibonacci(0))
self.assertEqual(1, fibonacci(1))
self.assertEqual(2, fibonacci(2))
self.assertEqual(3, fibonacci(3))
self.assertEqual(5, fibonacci(4))
self.assertEqual(8, fibonacci(5))
self.assertEqual(13, fibonacci(6))
self.assertEqual(21, fibonacci(7))
self.assertEqual(34, fibonacci(8))
self.assertEqual(55, fibonacci(9))

def testFibonacciNumpy(self):
batch_fibo = frontend.Context().batch_uncurried(
fibonacci,
lambda *args: instructions.TensorType(np.int64, ()))
self.assertEqual(
[8, 13, 21],
list(batch_fibo(np.array([6, 7, 8], dtype=np.int64),
[13, 21, 34, 55],
list(batch_fibo(np.array([6, 7, 8, 9], dtype=np.int64),
max_stack_depth=15, backend=NP_BACKEND)))

def testFibonacciNumpyStackless(self):
if not tf.executing_eagerly():
return
batch_fibo = frontend.Context().batch_uncurried(
fibonacci,
lambda *args: instructions.TensorType(np.int64, ()))
self.assertEqual(
[3, 21, 5, 8],
list(batch_fibo(np.array([3, 7, 4, 5], dtype=np.int64),
max_stack_depth=15, backend=NP_BACKEND,
stackless=True)))

def testEvenOddWithContext(self):
def pred_type(_):
return instructions.TensorType(np.int32, ())
Expand Down Expand Up @@ -502,7 +528,7 @@ def testFibonacciTF(self):
input_2 = self._build_tensor(np.array([6, 7, 8], dtype=np.int64))
answer = batch_fibo(input_2, max_stack_depth=15, backend=TF_BACKEND)
self._check_batch_size(answer, 3)
self.assertAllEqual([8, 13, 21], self.evaluate(answer))
self.assertAllEqual([13, 21, 34], self.evaluate(answer))

def testOneArmedAndNestedIf(self):
def int_type(_):
Expand Down

0 comments on commit 91bfb8a

Please sign in to comment.