Skip to content

Commit 2d8c5aa

Browse files
committed
Add temp tests for train_core and train_nlu.
1 parent 070a3f7 commit 2d8c5aa

File tree

1 file changed

+42
-1
lines changed

1 file changed

+42
-1
lines changed

tests/test_train.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import rasa.model
1111

12-
from rasa.train import train
12+
from rasa.train import train_core, train_nlu, train
1313
from tests.core.test_model import _fingerprint
1414

1515

@@ -72,11 +72,13 @@ def test_train_temp_files(
7272
default_nlu_data: Text,
7373
):
7474
monkeypatch.setattr(tempfile, "tempdir", tmp_path)
75+
output = "test_train_temp_files_models"
7576

7677
train(
7778
default_domain_path,
7879
default_stack_config,
7980
[default_stories_file, default_nlu_data],
81+
output=output,
8082
force_training=True,
8183
)
8284

@@ -89,6 +91,45 @@ def test_train_temp_files(
8991
default_domain_path,
9092
default_stack_config,
9193
[default_stories_file, default_nlu_data],
94+
output=output,
95+
)
96+
97+
assert count_temp_rasa_files(tempfile.tempdir) == 0
98+
99+
100+
def test_train_core_temp_files(
101+
tmp_path: Text,
102+
monkeypatch: MonkeyPatch,
103+
default_domain_path: Text,
104+
default_stories_file: Text,
105+
default_stack_config: Text,
106+
):
107+
monkeypatch.setattr(tempfile, "tempdir", tmp_path)
108+
109+
train_core(
110+
default_domain_path,
111+
default_stack_config,
112+
default_stories_file,
113+
output="test_train_core_temp_files_models",
114+
)
115+
116+
assert count_temp_rasa_files(tempfile.tempdir) == 0
117+
118+
119+
def test_train_nlu_temp_files(
120+
tmp_path: Text,
121+
monkeypatch: MonkeyPatch,
122+
default_domain_path: Text,
123+
default_stories_file: Text,
124+
default_stack_config: Text,
125+
default_nlu_data: Text,
126+
):
127+
monkeypatch.setattr(tempfile, "tempdir", tmp_path)
128+
129+
train_nlu(
130+
default_stack_config,
131+
default_nlu_data,
132+
output="test_train_nlu_temp_files_models",
92133
)
93134

94135
assert count_temp_rasa_files(tempfile.tempdir) == 0

0 commit comments

Comments
 (0)