
Commit 5c6a6e9

Merge pull request RasaHQ#4743 from RasaHQ/fix_intent_featurizer

use only word level featurizer for intents

2 parents 6f920ba + cdac9cb

3 files changed: +123 -90 lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ Fixed
 - ``MultiProjectImporter`` now imports files in the order of the import statements
 - Fixed server hanging forever on leaving ``rasa shell`` before first message
 - Fixed rasa init showing traceback error when user does Keyboard Interrupt before choosing a project path
+- ``CountVectorsFeaturizer`` featurizes intents only if its analyzer is set to ``word``
 
 [1.4.2] - 2019-10-28
 ^^^^^^^^^^^^^^^^^^^^
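
In practice this means a pipeline that uses a char-level CountVectorsFeaturizer for text features needs a word-level instance alongside it if intents should be featurized as well. A minimal sketch of such a two-instance setup (the char_wb parameters are illustrative, not taken from this commit):

    from rasa.nlu.featurizers.count_vectors_featurizer import CountVectorsFeaturizer

    # word-level instance: featurizes text and, after this fix, still intents
    word_featurizer = CountVectorsFeaturizer({"analyzer": "word"})

    # char-level instance: now contributes text features only
    char_featurizer = CountVectorsFeaturizer(
        {"analyzer": "char_wb", "min_ngram": 1, "max_ngram": 4}
    )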

rasa/nlu/featurizers/count_vectors_featurizer.py

Lines changed: 104 additions & 90 deletions
@@ -153,16 +153,6 @@ def _get_attribute_vocabulary(self, attribute: Text) -> Optional[Dict[Text, int]
         except (AttributeError, TypeError):
             return None
 
-    def _collect_vectorizer_vocabularies(self):
-        """Get vocabulary for all attributes"""
-
-        attribute_vocabularies = {}
-        for attribute in MESSAGE_ATTRIBUTES:
-            attribute_vocabularies[attribute] = self._get_attribute_vocabulary(
-                attribute
-            )
-        return attribute_vocabularies
-
     def _get_attribute_vocabulary_tokens(self, attribute: Text) -> Optional[List[Text]]:
         """Get all keys of vocabulary of an attribute"""

@@ -192,6 +182,15 @@ def _check_analyzer(self):
                 "contain single letters only."
             )
 
+    @staticmethod
+    def _attributes(analyzer):
+        """Create a list of attributes that should be featurized."""
+
+        # intents should be featurized only by word level count vectorizer
+        return (
+            MESSAGE_ATTRIBUTES if analyzer == "word" else SPACY_FEATURIZABLE_ATTRIBUTES
+        )
+
     def __init__(
         self,
         component_config: Dict[Text, Any] = None,
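
The new _attributes helper is the heart of the fix. A standalone sketch of its selection logic, assuming the Rasa 1.4 constants (MESSAGE_ATTRIBUTES covering text, intent and response; SPACY_FEATURIZABLE_ATTRIBUTES covering only text and response):

    # assumed values of the constants from rasa.nlu.constants
    MESSAGE_ATTRIBUTES = ["text", "intent", "response"]
    SPACY_FEATURIZABLE_ATTRIBUTES = ["text", "response"]

    def attributes_to_featurize(analyzer):
        # char-level analyzers skip the intent attribute entirely
        return MESSAGE_ATTRIBUTES if analyzer == "word" else SPACY_FEATURIZABLE_ATTRIBUTES

    assert attributes_to_featurize("word") == ["text", "intent", "response"]
    assert attributes_to_featurize("char") == ["text", "response"]
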
@@ -210,6 +209,9 @@ def __init__(
         # warn that some of config parameters might be ignored
         self._check_analyzer()
 
+        # set which attributes to featurize
+        self._attributes = self._attributes(self.analyzer)
+
         # declare class instance for CountVectorizer
         self.vectorizers = vectorizers
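
Note the deliberate name reuse: once __init__ has run, the instance attribute _attributes (a list) shadows the _attributes static method on the instance, while the method stays reachable on the class; that is why the factory methods further down call cls._attributes(analyzer). A toy sketch of the mechanics (not Rasa code):

    class Component:
        @staticmethod
        def _attributes(analyzer):
            return ["text", "intent"] if analyzer == "word" else ["text"]

        def __init__(self, analyzer):
            # the right-hand side still resolves to the static method;
            # the assignment then shadows it with the computed list
            self._attributes = self._attributes(analyzer)

    c = Component("char")
    assert c._attributes == ["text"]  # instance: the computed list
    assert Component._attributes("word") == ["text", "intent"]  # class: the method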

@@ -335,7 +337,7 @@ def _get_all_attributes_processed_texts(
         """Get processed text for all attributes of examples in training data"""
 
         processed_attribute_texts = {}
-        for attribute in MESSAGE_ATTRIBUTES:
+        for attribute in self._attributes:
             attribute_texts = [
                 self._get_message_text_by_attribute(example, attribute)
                 for example in training_data.intent_examples
@@ -344,82 +346,10 @@ def _get_all_attributes_processed_texts(
             processed_attribute_texts[attribute] = attribute_texts
         return processed_attribute_texts
 
-    @staticmethod
-    def create_shared_vocab_vectorizers(
-        token_pattern,
-        strip_accents,
-        lowercase,
-        stop_words,
-        ngram_range,
-        max_df,
-        min_df,
-        max_features,
-        analyzer,
-        vocabulary=None,
-    ) -> Dict[Text, "CountVectorizer"]:
-        """Create vectorizers for all attributes with shared vocabulary"""
-
-        shared_vectorizer = CountVectorizer(
-            token_pattern=token_pattern,
-            strip_accents=strip_accents,
-            lowercase=lowercase,
-            stop_words=stop_words,
-            ngram_range=ngram_range,
-            max_df=max_df,
-            min_df=min_df,
-            max_features=max_features,
-            analyzer=analyzer,
-            vocabulary=vocabulary,
-        )
-
-        attribute_vectorizers = {}
-
-        for attribute in MESSAGE_ATTRIBUTES:
-            attribute_vectorizers[attribute] = shared_vectorizer
-
-        return attribute_vectorizers
-
-    @staticmethod
-    def create_independent_vocab_vectorizers(
-        token_pattern,
-        strip_accents,
-        lowercase,
-        stop_words,
-        ngram_range,
-        max_df,
-        min_df,
-        max_features,
-        analyzer,
-        vocabulary=None,
-    ) -> Dict[Text, "CountVectorizer"]:
-        """Create vectorizers for all attributes with independent vocabulary"""
-
-        attribute_vectorizers = {}
-
-        for attribute in MESSAGE_ATTRIBUTES:
-
-            attribute_vocabulary = vocabulary[attribute] if vocabulary else None
-
-            attribute_vectorizer = CountVectorizer(
-                token_pattern=token_pattern,
-                strip_accents=strip_accents,
-                lowercase=lowercase,
-                stop_words=stop_words,
-                ngram_range=ngram_range,
-                max_df=max_df,
-                min_df=min_df,
-                max_features=max_features,
-                analyzer=analyzer,
-                vocabulary=attribute_vocabulary,
-            )
-            attribute_vectorizers[attribute] = attribute_vectorizer
-
-        return attribute_vectorizers
-
     def _train_with_shared_vocab(self, attribute_texts: Dict[Text, List[Text]]):
         """Construct the vectorizers and train them with a shared vocab"""
 
-        self.vectorizers = self.create_shared_vocab_vectorizers(
+        self.vectorizers = self._create_shared_vocab_vectorizers(
             self.token_pattern,
             self.strip_accents,
             self.lowercase,
@@ -432,7 +362,7 @@ def _train_with_shared_vocab(self, attribute_texts: Dict[Text, List[Text]]):
         )
 
         combined_cleaned_texts = []
-        for attribute in MESSAGE_ATTRIBUTES:
+        for attribute in self._attributes:
             combined_cleaned_texts += attribute_texts[attribute]
 
         try:
@@ -449,7 +379,7 @@ def _attribute_texts_is_non_empty(attribute_texts):
     def _train_with_independent_vocab(self, attribute_texts: Dict[Text, List[Text]]):
         """Construct the vectorizers and train them with an independent vocab"""
 
-        self.vectorizers = self.create_independent_vocab_vectorizers(
+        self.vectorizers = self._create_independent_vocab_vectorizers(
             self.token_pattern,
             self.strip_accents,
             self.lowercase,
@@ -461,7 +391,7 @@ def _train_with_independent_vocab(self, attribute_texts: Dict[Text, List[Text]])
             self.analyzer,
         )
 
-        for attribute in MESSAGE_ATTRIBUTES:
+        for attribute in self._attributes:
             if self._attribute_texts_is_non_empty(attribute_texts[attribute]):
                 try:
                     self.vectorizers[attribute].fit(attribute_texts[attribute])
@@ -516,7 +446,7 @@ def train(
             self._train_with_independent_vocab(processed_attribute_texts)
 
         # transform for all attributes
-        for attribute in MESSAGE_ATTRIBUTES:
+        for attribute in self._attributes:
 
             attribute_features = self._get_featurized_attribute(
                 attribute, processed_attribute_texts[attribute]
@@ -556,6 +486,16 @@ def process(self, message: Message, **kwargs: Any) -> None:
             ),
         )
 
+    def _collect_vectorizer_vocabularies(self):
+        """Get vocabulary for all attributes"""
+
+        attribute_vocabularies = {}
+        for attribute in self._attributes:
+            attribute_vocabularies[attribute] = self._get_attribute_vocabulary(
+                attribute
+            )
+        return attribute_vocabularies
+
     @staticmethod
     def _is_any_model_trained(attribute_vocabularies) -> bool:
         """Check if any model got trained"""
@@ -586,6 +526,80 @@ def persist(self, file_name: Text, model_dir: Text) -> Optional[Dict[Text, Any]]
             utils.json_pickle(featurizer_file, attribute_vocabularies)
         return {"file": file_name}
 
+    @classmethod
+    def _create_shared_vocab_vectorizers(
+        cls,
+        token_pattern,
+        strip_accents,
+        lowercase,
+        stop_words,
+        ngram_range,
+        max_df,
+        min_df,
+        max_features,
+        analyzer,
+        vocabulary=None,
+    ) -> Dict[Text, "CountVectorizer"]:
+        """Create vectorizers for all attributes with shared vocabulary"""
+
+        shared_vectorizer = CountVectorizer(
+            token_pattern=token_pattern,
+            strip_accents=strip_accents,
+            lowercase=lowercase,
+            stop_words=stop_words,
+            ngram_range=ngram_range,
+            max_df=max_df,
+            min_df=min_df,
+            max_features=max_features,
+            analyzer=analyzer,
+            vocabulary=vocabulary,
+        )
+
+        attribute_vectorizers = {}
+
+        for attribute in cls._attributes(analyzer):
+            attribute_vectorizers[attribute] = shared_vectorizer
+
+        return attribute_vectorizers
+
+    @classmethod
+    def _create_independent_vocab_vectorizers(
+        cls,
+        token_pattern,
+        strip_accents,
+        lowercase,
+        stop_words,
+        ngram_range,
+        max_df,
+        min_df,
+        max_features,
+        analyzer,
+        vocabulary=None,
+    ) -> Dict[Text, "CountVectorizer"]:
+        """Create vectorizers for all attributes with independent vocabulary"""
+
+        attribute_vectorizers = {}
+
+        for attribute in cls._attributes(analyzer):
+
+            attribute_vocabulary = vocabulary[attribute] if vocabulary else None
+
+            attribute_vectorizer = CountVectorizer(
+                token_pattern=token_pattern,
+                strip_accents=strip_accents,
+                lowercase=lowercase,
+                stop_words=stop_words,
+                ngram_range=ngram_range,
+                max_df=max_df,
+                min_df=min_df,
+                max_features=max_features,
+                analyzer=analyzer,
+                vocabulary=attribute_vocabulary,
+            )
+            attribute_vectorizers[attribute] = attribute_vectorizer
+
+        return attribute_vectorizers
+
     @classmethod
     def load(
         cls,
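
The two factories, now private classmethods, differ only in how vocabulary is shared across attributes. A rough sketch of that distinction using scikit-learn directly (attribute names and texts are illustrative):

    from sklearn.feature_extraction.text import CountVectorizer

    texts = {"text": ["hello there"], "response": ["general kenobi"]}

    # shared vocabulary: one CountVectorizer instance, fit once on the
    # combined texts of all attributes, then mapped to every attribute
    shared = CountVectorizer(analyzer="word")
    shared.fit(texts["text"] + texts["response"])
    shared_vectorizers = {attribute: shared for attribute in texts}

    # independent vocabularies: a separate CountVectorizer per attribute
    independent_vectorizers = {
        attribute: CountVectorizer(analyzer="word").fit(attribute_texts)
        for attribute, attribute_texts in texts.items()
    }

Because both factories consult cls._attributes(analyzer), a char-level featurizer never allocates a vectorizer for the intent attribute in the first place.
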
@@ -605,7 +619,7 @@ def load(
         share_vocabulary = meta["use_shared_vocab"]
 
         if share_vocabulary:
-            vectorizers = cls.create_shared_vocab_vectorizers(
+            vectorizers = cls._create_shared_vocab_vectorizers(
                 token_pattern=meta["token_pattern"],
                 strip_accents=meta["strip_accents"],
                 lowercase=meta["lowercase"],
@@ -618,7 +632,7 @@
                 vocabulary=vocabulary,
             )
         else:
-            vectorizers = cls.create_independent_vocab_vectorizers(
+            vectorizers = cls._create_independent_vocab_vectorizers(
                 token_pattern=meta["token_pattern"],
                 strip_accents=meta["strip_accents"],
                 lowercase=meta["lowercase"],

tests/nlu/base/test_featurizers.py

Lines changed: 18 additions & 0 deletions
@@ -461,6 +461,24 @@ def test_count_vector_featurizer_char(sentence, expected):
     assert np.all(test_message.get("text_features") == expected)
 
 
+def test_count_vector_featurizer_char_intent_featurizer():
+    from rasa.nlu.featurizers.count_vectors_featurizer import CountVectorsFeaturizer
+
+    ftr = CountVectorsFeaturizer({"min_ngram": 1, "max_ngram": 2, "analyzer": "char"})
+    td = training_data.load_data("data/examples/rasa/demo-rasa.json")
+    ftr.train(td, config=None)
+
+    intent_features_exist = np.array(
+        [
+            True if example.get("intent_features") is not None else False
+            for example in td.intent_examples
+        ]
+    )
+
+    # no intent features should have been set
+    assert not any(intent_features_exist)
+
+
 def test_count_vector_featurizer_persist_load(tmpdir):
     from rasa.nlu.featurizers.count_vectors_featurizer import CountVectorsFeaturizer
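
The test pins down the user-visible behaviour. The same check can be run as a standalone script, assuming a Rasa 1.4 checkout (so the demo data path resolves) and the from rasa.nlu import training_data import that the test module presumably uses:

    from rasa.nlu import training_data
    from rasa.nlu.featurizers.count_vectors_featurizer import CountVectorsFeaturizer

    ftr = CountVectorsFeaturizer({"min_ngram": 1, "max_ngram": 2, "analyzer": "char"})
    td = training_data.load_data("data/examples/rasa/demo-rasa.json")
    ftr.train(td, config=None)

    # the char-level featurizer must leave every intent unfeaturized
    assert all(ex.get("intent_features") is None for ex in td.intent_examples)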
