Skip to content

Commit

Permalink
nested func -> partial
Browse files Browse the repository at this point in the history
  • Loading branch information
Ending2015a committed Jul 13, 2021
1 parent e53e43e commit 9bcb02b
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions keras/layers/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import tensorflow.compat.v2 as tf

import collections
import functools
import warnings

import numpy as np
Expand Down Expand Up @@ -153,11 +154,9 @@ def build(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]

def get_batch_input_shape(batch_size):
def _get_input_shape(dim):
shape = tf.TensorShape(dim).as_list()
return tuple([batch_size] + shape)
return _get_input_shape
def get_batch_input_shape(batch_size, dim):
shape = tf.TensorShape(dim).as_list()
return tuple([batch_size] + shape)

for cell in self.cells:
if isinstance(cell, Layer) and not cell.built:
Expand All @@ -173,7 +172,8 @@ def _get_input_shape(dim):
batch_size = tf.nest.flatten(input_shape)[0]
if tf.nest.is_nested(output_dim):
input_shape = tf.nest.map_structure(
get_batch_input_shape(batch_size), output_dim)
functools.partial(get_batch_input_shape, batch_size),
output_dim)
input_shape = tuple(input_shape)
else:
input_shape = tuple([batch_size] +
Expand Down

0 comments on commit 9bcb02b

Please sign in to comment.