1
1
import tensorflow as tf
2
- from ..layers .utils import combined_dnn_input
3
2
4
def input_fn_pandas(df, features, label=None, batch_size=256, num_epochs=1, shuffle=False, queue_capacity_factor=10,
                    num_threads=1):
    """Build a ``tf.estimator``-compatible input_fn that feeds batches from a pandas DataFrame.

    :param df: pandas DataFrame holding the feature (and, optionally, label) columns.
    :param features: list of column names in ``df`` to feed as features.
    :param label: name of the label column in ``df``; ``None`` means no labels are fed.
    :param batch_size: number of rows per batch.
    :param num_epochs: number of passes over ``df``.
    :param shuffle: whether to shuffle rows before batching.
    :param queue_capacity_factor: feeding-queue capacity expressed as a multiple of
        ``batch_size`` (the underlying API takes an absolute ``queue_capacity``).
    :param num_threads: number of threads used to enqueue rows.
    :return: a callable usable as the ``input_fn`` argument of a tf.estimator.
    """
    y = df[label] if label is not None else None

    # Compare the major version numerically: a plain string comparison on
    # tf.__version__ is lexicographic and would misorder e.g. "10.0.0" < "2.0.0".
    if int(tf.__version__.split(".", 1)[0]) >= 2:
        # TF 2.x removed tf.estimator.inputs; the TF1 API lives under compat.v1.
        return tf.compat.v1.estimator.inputs.pandas_input_fn(df[features], y, batch_size=batch_size,
                                                             num_epochs=num_epochs,
                                                             shuffle=shuffle,
                                                             queue_capacity=batch_size * queue_capacity_factor,
                                                             num_threads=num_threads)

    return tf.estimator.inputs.pandas_input_fn(df[features], y, batch_size=batch_size, num_epochs=num_epochs,
                                               shuffle=shuffle, queue_capacity=batch_size * queue_capacity_factor,
                                               num_threads=num_threads)
30
20
31
21
32
- def input_fn_tfrecord (filenames , feature_description , label = None , batch_size = 256 , num_epochs = 1 , shuffle = False ,
33
- num_parallel_calls = 10 ):
22
+ def input_fn_tfrecord (filenames , feature_description , label = None , batch_size = 256 , num_epochs = 1 , num_parallel_calls = 8 ,
23
+ shuffle_factor = 10 , prefetch_factor = 1 ,
24
+ ):
34
25
def _parse_examples (serial_exmp ):
35
26
features = tf .parse_single_example (serial_exmp , features = feature_description )
36
27
if label is not None :
@@ -40,16 +31,17 @@ def _parse_examples(serial_exmp):
40
31
41
32
def input_fn ():
42
33
dataset = tf .data .TFRecordDataset (filenames )
43
- dataset = dataset .map (_parse_examples , num_parallel_calls = num_parallel_calls ).prefetch (
44
- buffer_size = batch_size * 10 )
45
- if shuffle :
46
- dataset = dataset .shuffle (buffer_size = batch_size * 10 )
34
+ dataset = dataset .map (_parse_examples , num_parallel_calls = num_parallel_calls )
35
+ if shuffle_factor > 0 :
36
+ dataset = dataset .shuffle (buffer_size = batch_size * shuffle_factor )
47
37
48
38
dataset = dataset .repeat (num_epochs ).batch (batch_size )
39
+
40
+ if prefetch_factor > 0 :
41
+ dataset = dataset .prefetch (buffer_size = batch_size * prefetch_factor )
42
+
49
43
iterator = dataset .make_one_shot_iterator ()
50
44
51
45
return iterator .get_next ()
52
46
53
47
return input_fn
54
-
55
-
0 commit comments