forked from RasaHQ/rasa_core
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconftest.py
80 lines (60 loc) · 2.41 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import logging
import matplotlib
import pytest
from rasa_core.agent import Agent
from rasa_core.channels.console import ConsoleOutputChannel
from rasa_core.channels.direct import CollectingOutputChannel
from rasa_core.dispatcher import Dispatcher
from rasa_core.domain import TemplateDomain
from rasa_core.featurizers import BinaryFeaturizer
from rasa_core.interpreter import RegexInterpreter
from rasa_core.policies import PolicyTrainer
from rasa_core.policies.ensemble import SimplePolicyEnsemble
from rasa_core.policies.memoization import MemoizationPolicy
from rasa_core.policies.scoring_policy import ScoringPolicy
from rasa_core.processor import MessageProcessor
from rasa_core.slots import Slot
from rasa_core.tracker_store import InMemoryTrackerStore
# Select matplotlib's non-interactive "Agg" backend so any plotting done during
# tests does not need a display.
# NOTE(review): this call runs after the rasa_core imports above; if any of
# those modules import matplotlib.pyplot first, older matplotlib versions may
# warn about or ignore the backend switch — confirm, or move this call above
# those imports.
matplotlib.use('Agg')
# Configure root logging at DEBUG level for the whole test session.
logging.basicConfig(level="DEBUG")
# Register the pytest-twisted plugin. `unicode_literals` is enabled at the top
# of this file, so on Python 2 the bare literal would be `unicode`; `str(...)`
# converts it to the native string type (presumably what pytest expects here).
pytest_plugins = str("pytest_twisted")
# Test data used by the fixtures below. These are relative paths, so they are
# resolved against the current working directory (presumably the repository
# root when pytest is run — confirm).
DEFAULT_DOMAIN_PATH = "data/test_domains/default_with_slots.yml"
DEFAULT_STORIES_FILE = "data/test_stories/stories_defaultdomain.md"
class CustomSlot(Slot):
    """Slot used in tests whose feature vector is always the constant ``[0.5]``.

    The slot's actual value is ignored when featurizing.
    """

    def as_feature(self):
        """Return a one-element feature list containing ``0.5``."""
        constant_feature = 0.5
        return [constant_feature]
@pytest.fixture(scope="session")
def default_domain():
    """Session-wide domain loaded once from ``DEFAULT_DOMAIN_PATH``.

    Returns:
        The object produced by ``TemplateDomain.load`` for the default test
        domain file.
    """
    loaded_domain = TemplateDomain.load(DEFAULT_DOMAIN_PATH)
    return loaded_domain
@pytest.fixture(scope="session")
def default_agent(default_domain):
    """Session-wide ``Agent`` trained once on ``DEFAULT_STORIES_FILE``.

    The agent uses a single ``MemoizationPolicy``, a ``RegexInterpreter`` and
    an ``InMemoryTrackerStore`` built for ``default_domain``.

    Args:
        default_domain: the session-scoped domain fixture.

    Returns:
        The trained ``Agent`` instance.
    """
    memo_policies = [MemoizationPolicy()]
    regex_interpreter = RegexInterpreter()
    in_memory_store = InMemoryTrackerStore(default_domain)

    trained_agent = Agent(default_domain,
                          policies=memo_policies,
                          interpreter=regex_interpreter,
                          tracker_store=in_memory_store)
    trained_agent.train(DEFAULT_STORIES_FILE)
    return trained_agent
@pytest.fixture
def default_dispatcher_cmd(default_domain):
    """Per-test ``Dispatcher`` whose output channel is a ``ConsoleOutputChannel``.

    The first ``Dispatcher`` argument is the literal ``"my-sender"``
    (presumably the sender id — confirm against ``Dispatcher``).
    """
    return Dispatcher("my-sender", ConsoleOutputChannel(), default_domain)
@pytest.fixture
def default_dispatcher_collecting(default_domain):
    """Per-test ``Dispatcher`` whose output channel is a ``CollectingOutputChannel``.

    The first ``Dispatcher`` argument is the literal ``"my-sender"``
    (presumably the sender id — confirm against ``Dispatcher``).
    """
    return Dispatcher("my-sender", CollectingOutputChannel(), default_domain)
@pytest.fixture
def default_processor(default_domain):
    """Per-test ``MessageProcessor`` backed by a freshly trained ScoringPolicy.

    Every test using this fixture retrains the ensemble: a
    ``SimplePolicyEnsemble`` with one ``ScoringPolicy`` is trained on
    ``DEFAULT_STORIES_FILE`` using a ``BinaryFeaturizer`` and
    ``max_history=3``. The processor then gets a ``RegexInterpreter`` and a
    new ``InMemoryTrackerStore``.

    Args:
        default_domain: the session-scoped domain fixture.

    Returns:
        A ``MessageProcessor`` built from the trained ensemble.
    """
    scoring_ensemble = SimplePolicyEnsemble([ScoringPolicy()])
    regex_interpreter = RegexInterpreter()

    trainer = PolicyTrainer(scoring_ensemble, default_domain, BinaryFeaturizer())
    trainer.train(DEFAULT_STORIES_FILE, max_history=3)

    return MessageProcessor(regex_interpreter,
                            scoring_ensemble,
                            default_domain,
                            InMemoryTrackerStore(default_domain))