@@ -88,6 +88,6 @@
 TEXT_FEATURES = f"{TEXT}_features"
 LABEL_FEATURES = f"{LABEL}_features"
+TEXT_MASK = f"{TEXT}_mask"
+LABEL_MASK = f"{LABEL}_mask"
 LABEL_IDS = f"{LABEL}_ids"
 TAG_IDS = "tag_ids"
-TEXT_SEQ_LENGTH = f"{TEXT}_lengths"
-LABEL_SEQ_LENGTH = f"{LABEL}_lengths"
class DIETClassifier(IntentClassifier, EntityExtractor):

@@ -484,7 +484,7 @@ def _create_label_data(
         # to track correctly dynamic sequences
         label_data.add_features(LABEL_IDS, [np.expand_dims(label_ids, -1)])

-        label_data.add_lengths(LABEL_SEQ_LENGTH, LABEL_FEATURES)
+        label_data.add_mask(LABEL_MASK, LABEL_FEATURES)

         return label_data
@@ -558,8 +558,8 @@ def _create_model_data(
         model_data.add_features(LABEL_IDS, [np.expand_dims(label_ids, -1)])
         model_data.add_features(TAG_IDS, [tag_ids])

-        model_data.add_lengths(TEXT_SEQ_LENGTH, TEXT_FEATURES)
-        model_data.add_lengths(LABEL_SEQ_LENGTH, LABEL_FEATURES)
+        model_data.add_mask(TEXT_MASK, TEXT_FEATURES)
+        model_data.add_mask(LABEL_MASK, LABEL_FEATURES)

         return model_data
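For reference, the mask stored under the new *_mask keys follows the convention of the removed _compute_mask helper: a float tensor of shape (batch, seq_len, 1), 1.0 for real tokens and 0.0 for padding. A minimal sketch of building such a mask from sequence lengths (names are illustrative; this is not the RasaModelData.add_mask implementation):

import tensorflow as tf

def build_sequence_mask(sequence_lengths: tf.Tensor) -> tf.Tensor:
    # 1.0 for real tokens, 0.0 for padding.
    mask = tf.sequence_mask(sequence_lengths, dtype=tf.float32)
    # Add a trailing dimension so the mask broadcasts against
    # (batch, seq_len, units) feature tensors: -> (batch, seq_len, 1).
    return tf.expand_dims(mask, -1)

# build_sequence_mask(tf.constant([2, 3])) ->
# [[[1.], [1.], [0.]],
#  [[1.], [1.], [1.]]]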
@@ -1165,6 +1165,10 @@ def _prepare_entity_recognition_layers(self) -> None:
             average="micro",
         )

+    @staticmethod
+    def _get_sequence_lengths(mask: tf.Tensor) -> tf.Tensor:
+        return tf.cast(tf.reduce_sum(mask[:, :, 0], axis=1), tf.int32)
+
     def _combine_sparse_dense_features(
         self,
         features: List[Union[np.ndarray, tf.Tensor, tf.SparseTensor]],
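The new _get_sequence_lengths inverts that mask: every real token contributes exactly 1.0 to the first channel, so summing over the time axis recovers the original lengths. A self-contained round-trip check, assuming the (batch, seq_len, 1) mask convention above:

import tensorflow as tf

lengths = tf.constant([2, 3])
mask = tf.expand_dims(tf.sequence_mask(lengths, dtype=tf.float32), -1)

# Sum of the mask over the time axis == number of real tokens.
recovered = tf.cast(tf.reduce_sum(mask[:, :, 0], axis=1), tf.int32)
assert recovered.numpy().tolist() == [2, 3]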
@@ -1246,23 +1250,16 @@ def _create_sequence(
         outputs = self._tf_layers[f"{name}_transformer"](
             transformer_inputs, 1 - mask, self._training
         )
-
-        if self.config[NUM_TRANSFORMER_LAYERS] > 0:
-            # apply activation
-            outputs = tfa.activations.gelu(outputs)
+        outputs = tfa.activations.gelu(outputs)

         return outputs, inputs, seq_ids, lm_mask_bool

     def _create_all_labels(self) -> Tuple[tf.Tensor, tf.Tensor]:
         all_label_ids = self.tf_label_data[LABEL_IDS][0]
-
-        label_lengths = self.sequence_lengths_for(
-            self.tf_label_data[LABEL_SEQ_LENGTH][0]
-        )
-        mask_label = self._compute_mask(label_lengths)
-
         x = self._create_bow(
-            self.tf_label_data[LABEL_FEATURES], mask_label, self.label_name,
+            self.tf_label_data[LABEL_FEATURES],
+            self.tf_label_data[LABEL_MASK][0],
+            self.label_name,
         )
         all_labels_embed = self._tf_layers[f"embed.{LABEL}"](x)
@@ -1356,23 +1353,13 @@ def _calculate_entity_loss(

         return loss, f1

-    @staticmethod
-    def _compute_mask(sequence_lengths: tf.Tensor) -> tf.Tensor:
-        mask = tf.sequence_mask(sequence_lengths, dtype=tf.float32)
-        # explicitly add last dimension to mask
-        # to track correctly dynamic sequences
-        return tf.expand_dims(mask, -1)
-
-    def sequence_lengths_for(self, sequence_lengths: tf.Tensor) -> tf.Tensor:
-        return tf.cast(sequence_lengths, dtype=tf.int32)
-
     def batch_loss(
         self, batch_in: Union[Tuple[tf.Tensor], Tuple[np.ndarray]]
     ) -> tf.Tensor:
         tf_batch_data = self.batch_to_model_data_format(batch_in, self.data_signature)

-        sequence_lengths = self.sequence_lengths_for(tf_batch_data[TEXT_SEQ_LENGTH][0])
-        mask_text = self._compute_mask(sequence_lengths)
+        mask_text = tf_batch_data[TEXT_MASK][0]
+        sequence_lengths = self._get_sequence_lengths(mask_text)

         (
             text_transformed,
@@ -1401,14 +1388,11 @@ def batch_loss(
         # get _cls_ vector for intent classification
         cls = self._last_token(text_transformed, sequence_lengths)

-        label_lengths = self.sequence_lengths_for(
-            tf_batch_data[LABEL_SEQ_LENGTH][0]
-        )
-        mask_label = self._compute_mask(label_lengths)
-
         label_ids = tf_batch_data[LABEL_IDS][0]
         label = self._create_bow(
-            tf_batch_data[LABEL_FEATURES], mask_label, self.label_name,
+            tf_batch_data[LABEL_FEATURES],
+            tf_batch_data[LABEL_MASK][0],
+            self.label_name,
         )
         loss, acc = self._calculate_label_loss(cls, label, label_ids)
         self.intent_loss.update_state(loss)
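The derived sequence_lengths feed _last_token, whose body is outside this diff. Assuming it gathers the final real token of each sequence (DIET appends the _cls_ token last), a hypothetical equivalent:

import tensorflow as tf

def last_token(x: tf.Tensor, sequence_lengths: tf.Tensor) -> tf.Tensor:
    # Index of the last real (non-padding) token per sequence.
    last_index = tf.maximum(0, sequence_lengths - 1)
    batch_index = tf.range(tf.shape(last_index)[0])
    # Gather x[i, last_index[i], :] for every sequence i.
    indices = tf.stack([batch_index, last_index], axis=1)
    return tf.gather_nd(x, indices)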
@@ -1434,8 +1418,8 @@ def batch_predict(
             batch_in, self.predict_data_signature
         )

-        sequence_lengths = self.sequence_lengths_for(tf_batch_data[TEXT_SEQ_LENGTH][0])
-        mask_text = self._compute_mask(sequence_lengths)
+        mask_text = tf_batch_data[TEXT_MASK][0]
+        sequence_lengths = self._get_sequence_lengths(mask_text)

         text_transformed, _, _, _ = self._create_sequence(
             tf_batch_data[TEXT_FEATURES], mask_text, self.text_name