|
| 1 | +import logging |
| 2 | +import typing |
| 3 | +from typing import Any, Dict, List, Optional, Text |
| 4 | + |
| 5 | +from rasa_nlu.components import Component |
| 6 | +from rasa_nlu.config import RasaNLUModelConfig |
| 7 | +from rasa_nlu.training_data import Message, TrainingData |
| 8 | + |
| 9 | +logger = logging.getLogger(__name__) |
| 10 | + |
| 11 | +if typing.TYPE_CHECKING: |
| 12 | + from spacy.language import Language |
| 13 | + from spacy.tokens.doc import Doc |
| 14 | + from rasa_nlu.model import Metadata |
| 15 | + |
| 16 | + |
| 17 | +class SpacyNLP(Component): |
| 18 | + name = "nlp_spacy" |
| 19 | + |
| 20 | + provides = ["spacy_doc", "spacy_nlp"] |
| 21 | + |
| 22 | + defaults = { |
| 23 | + # name of the language model to load - if it is not set |
| 24 | + # we will be looking for a language model that is named |
| 25 | + # after the language of the model, e.g. `en` |
| 26 | + "model": None, |
| 27 | + |
| 28 | + # when retrieving word vectors, this will decide if the casing |
| 29 | + # of the word is relevant. E.g. `hello` and `Hello` will |
| 30 | + # retrieve the same vector, if set to `False`. For some |
| 31 | + # applications and models it makes sense to differentiate |
| 32 | + # between these two words, therefore setting this to `True`. |
| 33 | + "case_sensitive": False, |
| 34 | + } |
| 35 | + |
| 36 | + def __init__(self, |
| 37 | + component_config: Dict[Text, Any] = None, |
| 38 | + nlp: 'Language' = None) -> None: |
| 39 | + |
| 40 | + self.nlp = nlp |
| 41 | + super(SpacyNLP, self).__init__(component_config) |
| 42 | + |
| 43 | + @classmethod |
| 44 | + def required_packages(cls) -> List[Text]: |
| 45 | + return ["spacy"] |
| 46 | + |
| 47 | + @classmethod |
| 48 | + def create(cls, cfg: RasaNLUModelConfig) -> 'SpacyNLP': |
| 49 | + import spacy |
| 50 | + |
| 51 | + component_conf = cfg.for_component(cls.name, cls.defaults) |
| 52 | + spacy_model_name = component_conf.get("model") |
| 53 | + |
| 54 | + # if no model is specified, we fall back to the language string |
| 55 | + if not spacy_model_name: |
| 56 | + spacy_model_name = cfg.language |
| 57 | + component_conf["model"] = cfg.language |
| 58 | + |
| 59 | + logger.info("Trying to load spacy model with " |
| 60 | + "name '{}'".format(spacy_model_name)) |
| 61 | + |
| 62 | + nlp = spacy.load(spacy_model_name, disable=['parser']) |
| 63 | + cls.ensure_proper_language_model(nlp) |
| 64 | + return SpacyNLP(component_conf, nlp) |
| 65 | + |
| 66 | + @classmethod |
| 67 | + def cache_key(cls, model_metadata: 'Metadata') -> Text: |
| 68 | + |
| 69 | + component_meta = model_metadata.for_component(cls.name) |
| 70 | + |
| 71 | + # Fallback, use the language name, e.g. "en", |
| 72 | + # as the model name if no explicit name is defined |
| 73 | + spacy_model_name = component_meta.get("model", model_metadata.language) |
| 74 | + |
| 75 | + return cls.name + "-" + spacy_model_name |
| 76 | + |
| 77 | + def provide_context(self) -> Dict[Text, Any]: |
| 78 | + return {"spacy_nlp": self.nlp} |
| 79 | + |
| 80 | + def doc_for_text(self, text: Text) -> 'Doc': |
| 81 | + if self.component_config.get("case_sensitive"): |
| 82 | + return self.nlp(text) |
| 83 | + else: |
| 84 | + return self.nlp(text.lower()) |
| 85 | + |
| 86 | + def train(self, |
| 87 | + training_data: TrainingData, |
| 88 | + config: RasaNLUModelConfig, |
| 89 | + **kwargs: Any) -> None: |
| 90 | + |
| 91 | + for example in training_data.training_examples: |
| 92 | + example.set("spacy_doc", self.doc_for_text(example.text)) |
| 93 | + |
| 94 | + def process(self, message: Message, **kwargs: Any) -> None: |
| 95 | + |
| 96 | + message.set("spacy_doc", self.doc_for_text(message.text)) |
| 97 | + |
| 98 | + @classmethod |
| 99 | + def load(cls, |
| 100 | + model_dir: Text = None, |
| 101 | + model_metadata: 'Metadata' = None, |
| 102 | + cached_component: Optional['SpacyNLP'] = None, |
| 103 | + **kwargs: Any) -> 'SpacyNLP': |
| 104 | + import spacy |
| 105 | + |
| 106 | + if cached_component: |
| 107 | + return cached_component |
| 108 | + |
| 109 | + component_meta = model_metadata.for_component(cls.name) |
| 110 | + model_name = component_meta.get("model") |
| 111 | + |
| 112 | + nlp = spacy.load(model_name, disable=['parser']) |
| 113 | + cls.ensure_proper_language_model(nlp) |
| 114 | + return cls(component_meta, nlp) |
| 115 | + |
| 116 | + @staticmethod |
| 117 | + def ensure_proper_language_model(nlp: Optional['Language']) -> None: |
| 118 | + """Checks if the spacy language model is properly loaded. |
| 119 | + Raises an exception if the model is invalid.""" |
| 120 | + |
| 121 | + if nlp is None: |
| 122 | + raise Exception("Failed to load spacy language model. " |
| 123 | + "Loading the model returned 'None'.") |
| 124 | + if nlp.path is None: |
| 125 | + # Spacy sets the path to `None` if |
| 126 | + # it did not load the model from disk. |
| 127 | + # In this case `nlp` is an unusable stub. |
| 128 | + raise Exception("Failed to load spacy language model for " |
| 129 | + "lang '{}'. Make sure you have downloaded the " |
| 130 | + "correct model (https://spacy.io/docs/usage/)." |
| 131 | + "".format(nlp.lang)) |
0 commit comments