forked from artificial-nikhita/starter-pack-rasa-stack
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from RasaHQ/add-tests
- Decided against testing the makefile, since we are not likely to make dramatic changes to the way we parse arguments from the command line. It would also involve training another model which would increase build time. - Since this starter pack is made for newcomers, we would like to keep obfuscation low and a `dev-requirements.txt` would hinder this, so no `pycodestyle` testing either
- Loading branch information
Showing
4 changed files
with
79 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
language: python | ||
sudo: enabled | ||
cache: | ||
directories: | ||
- $HOME/.cache/pip | ||
- /tmp/cached/ | ||
python: | ||
- '3.5' | ||
- '3.6' | ||
install: | ||
- pip install -r requirements.txt | ||
- pip install pytest==3.5.1 | ||
- pip install rasa_nlu --upgrade | ||
- pip install rasa_core --upgrade | ||
- pip install rasa_core_sdk --upgrade | ||
- python -m spacy download en_core_web_md | ||
- python -m spacy link en_core_web_md en | ||
script: | ||
- py.test test_stack.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
language: "en" | ||
|
||
pipeline: spacy_sklearn | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from rasa_core.policies import FallbackPolicy, KerasPolicy, MemoizationPolicy | ||
from rasa_core.agent import Agent | ||
|
||
from rasa_nlu.training_data import load_data | ||
from rasa_nlu.model import Trainer | ||
from rasa_nlu import config as nlu_config | ||
|
||
from rasa_core import config as core_config | ||
from rasa_core.trackers import DialogueStateTracker | ||
from rasa_core.domain import Domain | ||
from rasa_core.dispatcher import Dispatcher | ||
from rasa_core.channels import CollectingOutputChannel | ||
from rasa_core.nlg import TemplatedNaturalLanguageGenerator | ||
from actions import ActionJoke | ||
import uuid | ||
|
||
|
||
def test_nlu_interpreter(): | ||
training_data = load_data("data/nlu_data.md") | ||
trainer = Trainer(nlu_config.load("nlu_config.yml")) | ||
interpreter = trainer.train(training_data) | ||
test_interpreter_dir = trainer.persist("./models/nlu", fixed_model_name="test") | ||
parsing = interpreter.parse('hello') | ||
|
||
assert parsing['intent']['name'] == 'greet' | ||
assert test_interpreter_dir | ||
|
||
|
||
def test_agent_and_persist(): | ||
policies = core_config.load('policies.yml') | ||
policies[0] = KerasPolicy(epochs=2) # Keep training times low | ||
|
||
agent = Agent('domain.yml', policies=policies) | ||
training_data = agent.load_data('data/stories.md') | ||
agent.train(training_data, validation_split=0.0) | ||
agent.persist('models/dialogue') | ||
|
||
loaded = Agent.load('models/dialogue') | ||
|
||
assert agent.handle_text('/greet') is not None | ||
assert loaded.domain.action_names == agent.domain.action_names | ||
assert loaded.domain.intents == agent.domain.intents | ||
assert loaded.domain.entities == agent.domain.entities | ||
assert loaded.domain.templates == agent.domain.templates | ||
|
||
|
||
def test_action(): | ||
domain = Domain.load('domain.yml') | ||
nlg = TemplatedNaturalLanguageGenerator(domain.templates) | ||
dispatcher = Dispatcher("my-sender", CollectingOutputChannel(), nlg) | ||
uid = str(uuid.uuid1()) | ||
tracker = DialogueStateTracker(uid, domain.slots) | ||
|
||
action = ActionJoke() | ||
action.run(dispatcher, tracker, domain) | ||
|
||
assert dispatcher.output_channel.latest_output() is not None |