Skip to content

Commit

Permalink
Merge pull request #19 from RasaHQ/add-tests
Browse files Browse the repository at this point in the history
- Decided against testing the makefile, since we are not likely to make dramatic changes to the way we parse arguments from the command line.  It would also involve training another model which would increase build time.
- Since this starter pack is made for newcomers, we would like to keep obfuscation low and a `dev-requirements.txt` would hinder this, so no `pycodestyle` testing either
  • Loading branch information
MetcalfeTom authored Jan 3, 2019
2 parents 275e861 + 2427627 commit c106a32
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 4 deletions.
19 changes: 19 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
language: python
sudo: enabled
cache:
directories:
- $HOME/.cache/pip
- /tmp/cached/
python:
- '3.5'
- '3.6'
install:
- pip install -r requirements.txt
- pip install pytest==3.5.1
- pip install rasa_nlu --upgrade
- pip install rasa_core --upgrade
- pip install rasa_core_sdk --upgrade
- python -m spacy download en_core_web_md
- python -m spacy link en_core_web_md en
script:
- py.test test_stack.py
6 changes: 3 additions & 3 deletions actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def name(self):

def run(self, dispatcher, tracker, domain):
# what your action should do
request = json.loads(requests.get('https://api.chucknorris.io/jokes/random').text) #make an apie call
joke = request['value'] #extract a joke from returned json response
dispatcher.utter_message(joke) #send the message back to the user
request = json.loads(requests.get('https://api.chucknorris.io/jokes/random').text) # make an api call
joke = request['value'] # extract a joke from returned json response
dispatcher.utter_message(joke) # send the message back to the user
return []
1 change: 0 additions & 1 deletion nlu_config.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
language: "en"

pipeline: spacy_sklearn

57 changes: 57 additions & 0 deletions test_stack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from rasa_core.policies import FallbackPolicy, KerasPolicy, MemoizationPolicy
from rasa_core.agent import Agent

from rasa_nlu.training_data import load_data
from rasa_nlu.model import Trainer
from rasa_nlu import config as nlu_config

from rasa_core import config as core_config
from rasa_core.trackers import DialogueStateTracker
from rasa_core.domain import Domain
from rasa_core.dispatcher import Dispatcher
from rasa_core.channels import CollectingOutputChannel
from rasa_core.nlg import TemplatedNaturalLanguageGenerator
from actions import ActionJoke
import uuid


def test_nlu_interpreter():
training_data = load_data("data/nlu_data.md")
trainer = Trainer(nlu_config.load("nlu_config.yml"))
interpreter = trainer.train(training_data)
test_interpreter_dir = trainer.persist("./models/nlu", fixed_model_name="test")
parsing = interpreter.parse('hello')

assert parsing['intent']['name'] == 'greet'
assert test_interpreter_dir


def test_agent_and_persist():
policies = core_config.load('policies.yml')
policies[0] = KerasPolicy(epochs=2) # Keep training times low

agent = Agent('domain.yml', policies=policies)
training_data = agent.load_data('data/stories.md')
agent.train(training_data, validation_split=0.0)
agent.persist('models/dialogue')

loaded = Agent.load('models/dialogue')

assert agent.handle_text('/greet') is not None
assert loaded.domain.action_names == agent.domain.action_names
assert loaded.domain.intents == agent.domain.intents
assert loaded.domain.entities == agent.domain.entities
assert loaded.domain.templates == agent.domain.templates


def test_action():
domain = Domain.load('domain.yml')
nlg = TemplatedNaturalLanguageGenerator(domain.templates)
dispatcher = Dispatcher("my-sender", CollectingOutputChannel(), nlg)
uid = str(uuid.uuid1())
tracker = DialogueStateTracker(uid, domain.slots)

action = ActionJoke()
action.run(dispatcher, tracker, domain)

assert dispatcher.output_channel.latest_output() is not None

0 comments on commit c106a32

Please sign in to comment.