|
| 1 | +import logging |
| 2 | +from rasa.nlu.featurizers import Featurizer |
| 3 | +from typing import Any, Dict, List, Optional, Text, Tuple |
| 4 | +from rasa.nlu.config import RasaNLUModelConfig |
| 5 | +from rasa.nlu.training_data import Message, TrainingData |
| 6 | +from rasa.nlu.constants import ( |
| 7 | + MESSAGE_TEXT_ATTRIBUTE, |
| 8 | + MESSAGE_VECTOR_FEATURE_NAMES, |
| 9 | + SPACY_FEATURIZABLE_ATTRIBUTES, |
| 10 | +) |
| 11 | +import numpy as np |
| 12 | +import tensorflow as tf |
| 13 | + |
| 14 | +logger = logging.getLogger(__name__) |
| 15 | + |
| 16 | + |
| 17 | +class ConveRTFeaturizer(Featurizer): |
| 18 | + |
| 19 | + provides = [ |
| 20 | + MESSAGE_VECTOR_FEATURE_NAMES[attribute] |
| 21 | + for attribute in SPACY_FEATURIZABLE_ATTRIBUTES |
| 22 | + ] |
| 23 | + |
| 24 | + def _load_model(self) -> None: |
| 25 | + |
| 26 | + import tensorflow_text |
| 27 | + import tensorflow_hub as tfhub |
| 28 | + |
| 29 | + self.graph = tf.Graph() |
| 30 | + model_url = "http://models.poly-ai.com/convert/v1/model.tar.gz" |
| 31 | + |
| 32 | + with self.graph.as_default(): |
| 33 | + self.session = tf.Session() |
| 34 | + self.module = tfhub.Module(model_url) |
| 35 | + |
| 36 | + self.text_placeholder = tf.placeholder(dtype=tf.string, shape=[None]) |
| 37 | + self.encoding_tensor = self.module(self.text_placeholder) |
| 38 | + self.session.run(tf.tables_initializer()) |
| 39 | + self.session.run(tf.global_variables_initializer()) |
| 40 | + |
| 41 | + def __init__(self, component_config: Dict[Text, Any] = None) -> None: |
| 42 | + |
| 43 | + super(ConveRTFeaturizer, self).__init__(component_config) |
| 44 | + |
| 45 | + self._load_model() |
| 46 | + |
| 47 | + @classmethod |
| 48 | + def required_packages(cls) -> List[Text]: |
| 49 | + return ["tensorflow_text", "tensorflow_hub"] |
| 50 | + |
| 51 | + def _compute_features( |
| 52 | + self, batch_examples: List[Message], attribute: Text = MESSAGE_TEXT_ATTRIBUTE |
| 53 | + ) -> np.ndarray: |
| 54 | + |
| 55 | + # Get text for attribute of each example |
| 56 | + batch_attribute_text = [ex.get(attribute) for ex in batch_examples] |
| 57 | + |
| 58 | + batch_features = self._run_model_on_text(batch_attribute_text) |
| 59 | + |
| 60 | + return batch_features |
| 61 | + |
| 62 | + def _run_model_on_text(self, batch: List[Text]) -> np.ndarray: |
| 63 | + |
| 64 | + return self.session.run( |
| 65 | + self.encoding_tensor, feed_dict={self.text_placeholder: batch} |
| 66 | + ) |
| 67 | + |
| 68 | + def train( |
| 69 | + self, |
| 70 | + training_data: TrainingData, |
| 71 | + config: Optional[RasaNLUModelConfig], |
| 72 | + **kwargs: Any, |
| 73 | + ) -> None: |
| 74 | + |
| 75 | + batch_size = 64 |
| 76 | + |
| 77 | + for attribute in SPACY_FEATURIZABLE_ATTRIBUTES: |
| 78 | + |
| 79 | + non_empty_examples = list( |
| 80 | + filter(lambda x: x.get(attribute), training_data.training_examples) |
| 81 | + ) |
| 82 | + |
| 83 | + batch_start_index = 0 |
| 84 | + |
| 85 | + while batch_start_index < len(non_empty_examples): |
| 86 | + |
| 87 | + batch_end_index = min( |
| 88 | + batch_start_index + batch_size, len(non_empty_examples) |
| 89 | + ) |
| 90 | + |
| 91 | + # Collect batch examples |
| 92 | + batch_examples = non_empty_examples[batch_start_index:batch_end_index] |
| 93 | + |
| 94 | + batch_features = self._compute_features(batch_examples, attribute) |
| 95 | + |
| 96 | + for index, ex in enumerate(batch_examples): |
| 97 | + |
| 98 | + ex.set( |
| 99 | + MESSAGE_VECTOR_FEATURE_NAMES[attribute], |
| 100 | + self._combine_with_existing_features( |
| 101 | + ex, |
| 102 | + batch_features[index], |
| 103 | + MESSAGE_VECTOR_FEATURE_NAMES[attribute], |
| 104 | + ), |
| 105 | + ) |
| 106 | + |
| 107 | + batch_start_index += batch_size |
| 108 | + |
| 109 | + def process(self, message: Message, **kwargs: Any) -> None: |
| 110 | + |
| 111 | + feats = self._compute_features([message])[0] |
| 112 | + message.set( |
| 113 | + MESSAGE_VECTOR_FEATURE_NAMES[MESSAGE_TEXT_ATTRIBUTE], |
| 114 | + self._combine_with_existing_features( |
| 115 | + message, feats, MESSAGE_VECTOR_FEATURE_NAMES[MESSAGE_TEXT_ATTRIBUTE] |
| 116 | + ), |
| 117 | + ) |
0 commit comments