Skip to content

Commit 7276b08

Browse files
committed
refactor cli tests
1 parent da3c49e commit 7276b08

File tree

5 files changed

+100
-32
lines changed

5 files changed

+100
-32
lines changed

tests/cli/conftest.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from typing import Callable
22
import pytest
3+
import shutil
34
import os
45
from _pytest.pytester import Testdir, RunResult
56

7+
from rasa.utils.io import write_yaml_file
8+
69

710
@pytest.fixture
811
def run(testdir: Testdir) -> Callable[..., RunResult]:
@@ -22,13 +25,60 @@ def do_run(*args, stdin):
2225
return do_run
2326

2427

28+
@pytest.fixture
29+
def run_in_default_project_without_models(testdir: Testdir) -> Callable[..., RunResult]:
30+
os.environ["LOG_LEVEL"] = "ERROR"
31+
32+
_set_up_initial_project(testdir)
33+
34+
def do_run(*args):
35+
args = ["rasa"] + list(args)
36+
return testdir.run(*args)
37+
38+
return do_run
39+
40+
2541
@pytest.fixture
2642
def run_in_default_project(testdir: Testdir) -> Callable[..., RunResult]:
2743
os.environ["LOG_LEVEL"] = "ERROR"
28-
testdir.run("rasa", "init", "--no-prompt")
44+
45+
_set_up_initial_project(testdir)
46+
47+
testdir.run("rasa", "train")
2948

3049
def do_run(*args):
3150
args = ["rasa"] + list(args)
3251
return testdir.run(*args)
3352

3453
return do_run
54+
55+
56+
def _set_up_initial_project(testdir: Testdir):
57+
# copy initial project files
58+
testdir.copy_example("rasa/cli/initial_project/actions.py")
59+
testdir.copy_example("rasa/cli/initial_project/credentials.yml")
60+
testdir.copy_example("rasa/cli/initial_project/domain.yml")
61+
testdir.copy_example("rasa/cli/initial_project/endpoints.yml")
62+
testdir.mkdir("data")
63+
testdir.copy_example("rasa/cli/initial_project/data")
64+
testdir.run("mv", "nlu.md", "data/nlu.md")
65+
testdir.run("mv", "stories.md", "data/stories.md")
66+
67+
# create a config file
68+
# for the cli test the resulting model is not important, use components that are
69+
# fast to train
70+
write_yaml_file(
71+
{
72+
"language": "en",
73+
"pipeline": [
74+
{"name": "WhitespaceTokenizer"},
75+
{"name": "CountVectorsFeaturizer"},
76+
{"name": "KeywordIntentClassifier"},
77+
],
78+
"policies": [
79+
{"name": "MappingPolicy"},
80+
{"name": "MemoizationPolicy", "max_history": 5},
81+
],
82+
},
83+
"config.yml",
84+
)

tests/cli/test_rasa_data.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from rasa.cli import data
77

88

9-
def test_data_split_nlu(run_in_default_project: Callable[..., RunResult]):
10-
run_in_default_project(
9+
def test_data_split_nlu(
10+
run_in_default_project_without_models: Callable[..., RunResult]
11+
):
12+
run_in_default_project_without_models(
1113
"data", "split", "nlu", "-u", "data/nlu.md", "--training-fraction", "0.75"
1214
)
1315

@@ -16,8 +18,10 @@ def test_data_split_nlu(run_in_default_project: Callable[..., RunResult]):
1618
assert os.path.exists(os.path.join("train_test_split", "training_data.md"))
1719

1820

19-
def test_data_convert_nlu(run_in_default_project: Callable[..., RunResult]):
20-
run_in_default_project(
21+
def test_data_convert_nlu(
22+
run_in_default_project_without_models: Callable[..., RunResult]
23+
):
24+
run_in_default_project_without_models(
2125
"data",
2226
"convert",
2327
"nlu",

tests/cli/test_rasa_run.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
from _pytest.pytester import RunResult
55

66

7-
def test_run_does_not_start(run_in_default_project: Callable[..., RunResult]):
7+
def test_run_does_not_start(
8+
run_in_default_project_without_models: Callable[..., RunResult]
9+
):
810
os.remove("domain.yml")
9-
shutil.rmtree("models")
1011

1112
# the server should not start as no model is configured
12-
output = run_in_default_project("run")
13+
output = run_in_default_project_without_models("run")
1314

1415
assert "No model found." in output.outlines[0]
1516

tests/cli/test_rasa_test.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,18 @@ def test_test_nlu_cross_validation(run_in_default_project: Callable[..., RunResu
5959

6060

6161
def test_test_nlu_comparison(run_in_default_project: Callable[..., RunResult]):
62-
copyfile("config.yml", "nlu-config.yml")
62+
copyfile("config.yml", "config-1.yml")
6363

6464
run_in_default_project(
65-
"test", "nlu", "-c", "config.yml", "nlu-config.yml", "--run", "2"
65+
"test",
66+
"nlu",
67+
"-config",
68+
"config.yml",
69+
"config-1.yml",
70+
"--run",
71+
"2",
72+
"-percentages",
73+
"75",
6674
)
6775

6876
assert os.path.exists("results/run_1")
@@ -106,6 +114,7 @@ def test_test_core_comparison_after_train(
106114
},
107115
"config_2.yml",
108116
)
117+
109118
run_in_default_project(
110119
"train",
111120
"core",

tests/cli/test_rasa_train.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
import rasa.utils.io as io_utils
1919

2020

21-
def test_train(run_in_default_project: Callable[..., RunResult]):
21+
def test_train(run_in_default_project_without_models: Callable[..., RunResult]):
2222
temp_dir = os.getcwd()
2323

24-
run_in_default_project(
24+
run_in_default_project_without_models(
2525
"train",
2626
"-c",
2727
"config.yml",
@@ -48,10 +48,12 @@ def test_train(run_in_default_project: Callable[..., RunResult]):
4848
)
4949

5050

51-
def test_train_persist_nlu_data(run_in_default_project: Callable[..., RunResult]):
51+
def test_train_persist_nlu_data(
52+
run_in_default_project_without_models: Callable[..., RunResult]
53+
):
5254
temp_dir = os.getcwd()
5355

54-
run_in_default_project(
56+
run_in_default_project_without_models(
5557
"train",
5658
"-c",
5759
"config.yml",
@@ -79,7 +81,9 @@ def test_train_persist_nlu_data(run_in_default_project: Callable[..., RunResult]
7981
)
8082

8183

82-
def test_train_core_compare(run_in_default_project: Callable[..., RunResult]):
84+
def test_train_core_compare(
85+
run_in_default_project_without_models: Callable[..., RunResult]
86+
):
8387
temp_dir = os.getcwd()
8488

8589
io_utils.write_yaml_file(
@@ -100,7 +104,7 @@ def test_train_core_compare(run_in_default_project: Callable[..., RunResult]):
100104
"config_2.yml",
101105
)
102106

103-
run_in_default_project(
107+
run_in_default_project_without_models(
104108
"train",
105109
"core",
106110
"-c",
@@ -132,11 +136,11 @@ def test_train_core_compare(run_in_default_project: Callable[..., RunResult]):
132136

133137

134138
def test_train_no_domain_exists(
135-
run_in_default_project: Callable[..., RunResult]
139+
run_in_default_project_without_models: Callable[..., RunResult]
136140
) -> None:
137141

138142
os.remove("domain.yml")
139-
run_in_default_project(
143+
run_in_default_project_without_models(
140144
"train",
141145
"-c",
142146
"config.yml",
@@ -191,38 +195,36 @@ def test_train_force(run_in_default_project):
191195
assert len(files) == 2
192196

193197

194-
def test_train_with_only_nlu_data(run_in_default_project):
198+
def test_train_with_only_nlu_data(run_in_default_project_without_models):
195199
temp_dir = os.getcwd()
196200

197201
assert os.path.exists(os.path.join(temp_dir, "data/stories.md"))
198202
os.remove(os.path.join(temp_dir, "data/stories.md"))
199-
shutil.rmtree(os.path.join(temp_dir, "models"))
200203

201-
run_in_default_project("train", "--fixed-model-name", "test-model")
204+
run_in_default_project_without_models("train", "--fixed-model-name", "test-model")
202205

203206
assert os.path.exists(os.path.join(temp_dir, "models"))
204207
files = io_utils.list_files(os.path.join(temp_dir, "models"))
205208
assert len(files) == 1
206209
assert os.path.basename(files[0]) == "test-model.tar.gz"
207210

208211

209-
def test_train_with_only_core_data(run_in_default_project):
212+
def test_train_with_only_core_data(run_in_default_project_without_models):
210213
temp_dir = os.getcwd()
211214

212215
assert os.path.exists(os.path.join(temp_dir, "data/nlu.md"))
213216
os.remove(os.path.join(temp_dir, "data/nlu.md"))
214-
shutil.rmtree(os.path.join(temp_dir, "models"))
215217

216-
run_in_default_project("train", "--fixed-model-name", "test-model")
218+
run_in_default_project_without_models("train", "--fixed-model-name", "test-model")
217219

218220
assert os.path.exists(os.path.join(temp_dir, "models"))
219221
files = io_utils.list_files(os.path.join(temp_dir, "models"))
220222
assert len(files) == 1
221223
assert os.path.basename(files[0]) == "test-model.tar.gz"
222224

223225

224-
def test_train_core(run_in_default_project: Callable[..., RunResult]):
225-
run_in_default_project(
226+
def test_train_core(run_in_default_project_without_models: Callable[..., RunResult]):
227+
run_in_default_project_without_models(
226228
"train",
227229
"core",
228230
"-c",
@@ -241,10 +243,12 @@ def test_train_core(run_in_default_project: Callable[..., RunResult]):
241243
assert os.path.isfile("train_rasa_models/rasa-model.tar.gz")
242244

243245

244-
def test_train_core_no_domain_exists(run_in_default_project: Callable[..., RunResult]):
246+
def test_train_core_no_domain_exists(
247+
run_in_default_project_without_models: Callable[..., RunResult]
248+
):
245249

246250
os.remove("domain.yml")
247-
run_in_default_project(
251+
run_in_default_project_without_models(
248252
"train",
249253
"core",
250254
"--config",
@@ -263,8 +267,8 @@ def test_train_core_no_domain_exists(run_in_default_project: Callable[..., RunRe
263267
assert not os.path.isfile("train_rasa_models_no_domain/rasa-model.tar.gz")
264268

265269

266-
def test_train_nlu(run_in_default_project: Callable[..., RunResult]):
267-
run_in_default_project(
270+
def test_train_nlu(run_in_default_project_without_models: Callable[..., RunResult]):
271+
run_in_default_project_without_models(
268272
"train",
269273
"nlu",
270274
"-c",
@@ -289,9 +293,9 @@ def test_train_nlu(run_in_default_project: Callable[..., RunResult]):
289293

290294

291295
def test_train_nlu_persist_nlu_data(
292-
run_in_default_project: Callable[..., RunResult]
296+
run_in_default_project_without_models: Callable[..., RunResult]
293297
) -> None:
294-
run_in_default_project(
298+
run_in_default_project_without_models(
295299
"train",
296300
"nlu",
297301
"-c",

0 commit comments

Comments
 (0)