Skip to content

Commit 55caf93

Browse files
authored
Merge pull request RasaHQ#5122 from RasaHQ/issue4896
for `rasa test core --evaluate-model-directory` set default for `--model` to directory
2 parents c58ae0b + 302cc95 commit 55caf93

File tree

3 files changed

+100
-1
lines changed

3 files changed

+100
-1
lines changed

changelog/4896.bugfix.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Fixed default behavior of ``rasa test core --evaluate-model-directory`` when called without ``--model``. Previously, the latest model file was used as ``--model``. Now the default model directory is used instead.
2+
3+
New behavior of ``rasa test core --evaluate-model-directory`` when given an existing file as argument for ``--model``: Previously, this led to an error. Now a warning is displayed and the directory containing the given file is used as ``--model``.

rasa/test.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
import os
44
from typing import Text, Dict, Optional, List, Any
5+
from pathlib import Path
56

67
import rasa.utils.io as io_utils
78
from rasa.constants import (
@@ -16,9 +17,13 @@
1617
logger = logging.getLogger(__name__)
1718

1819

19-
def test_core_models_in_directory(model_directory: Text, stories: Text, output: Text):
20+
def test_core_models_in_directory(
21+
model_directory: Text, stories: Text, output: Text
22+
) -> None:
2023
from rasa.core.test import compare_models_in_dir, plot_core_results
2124

25+
model_directory = _get_sanitized_model_directory(model_directory)
26+
2227
loop = asyncio.get_event_loop()
2328
loop.run_until_complete(compare_models_in_dir(model_directory, stories, output))
2429

@@ -27,6 +32,31 @@ def test_core_models_in_directory(model_directory: Text, stories: Text, output:
2732
plot_core_results(output, number_of_stories)
2833

2934

35+
def _get_sanitized_model_directory(model_directory: Text) -> Text:
36+
"""Adjusts the `--model` argument of `rasa test core` when called with `--evaluate-model-directory`.
37+
38+
By default rasa uses the latest model for the `--model` parameter. However, for `--evaluate-model-directory` we
39+
need a directory. This function checks if the passed parameter is a model or an individual model file.
40+
41+
Args:
42+
model_directory: The model_directory argument that was given to `test_core_models_in_directory`.
43+
44+
Returns:
45+
The adjusted model_directory that should be used in `test_core_models_in_directory`.
46+
"""
47+
import rasa.model
48+
49+
p = Path(model_directory)
50+
if p.is_file():
51+
if model_directory != rasa.model.get_latest_model():
52+
print_warning(
53+
"You passed a file as '--model'. Will use the directory containing this file instead."
54+
)
55+
model_directory = str(p.parent)
56+
57+
return model_directory
58+
59+
3060
def test_core_models(models: List[Text], stories: Text, output: Text):
3161
from rasa.core.test import compare_models
3262

tests/test_test.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from pathlib import Path
2+
from _pytest.capture import CaptureFixture
3+
from _pytest.monkeypatch import MonkeyPatch
4+
5+
import rasa.model
6+
import rasa.cli.utils
7+
8+
9+
def monkeypatch_get_latest_model(tmp_path: Path, monkeypatch: MonkeyPatch) -> None:
10+
latest_model = tmp_path / "my_test_model.tar.gz"
11+
monkeypatch.setattr(rasa.model, "get_latest_model", lambda: str(latest_model))
12+
13+
14+
def test_get_sanitized_model_directory_when_not_passing_model(
15+
capsys: CaptureFixture, tmp_path: Path, monkeypatch: MonkeyPatch
16+
):
17+
from rasa.test import _get_sanitized_model_directory
18+
19+
monkeypatch_get_latest_model(tmp_path, monkeypatch)
20+
21+
# Create a fake model on disk so that `is_file` returns `True`
22+
latest_model = Path(rasa.model.get_latest_model())
23+
latest_model.touch()
24+
25+
# Input: default model file
26+
# => Should return containing directory
27+
new_modeldir = _get_sanitized_model_directory(str(latest_model))
28+
captured = capsys.readouterr()
29+
assert not captured.out
30+
assert new_modeldir == str(latest_model.parent)
31+
32+
33+
def test_get_sanitized_model_directory_when_passing_model_file_explicitly(
34+
capsys: CaptureFixture, tmp_path: Path, monkeypatch: MonkeyPatch
35+
):
36+
from rasa.test import _get_sanitized_model_directory
37+
38+
monkeypatch_get_latest_model(tmp_path, monkeypatch)
39+
40+
other_model = tmp_path / "my_test_model1.tar.gz"
41+
assert str(other_model) != rasa.model.get_latest_model()
42+
other_model.touch()
43+
44+
# Input: some file
45+
# => Should return containing directory and print a warning
46+
new_modeldir = _get_sanitized_model_directory(str(other_model))
47+
captured = capsys.readouterr()
48+
assert captured.out
49+
assert new_modeldir == str(other_model.parent)
50+
51+
52+
def test_get_sanitized_model_directory_when_passing_other_input(
53+
capsys: CaptureFixture, tmp_path: Path, monkeypatch: MonkeyPatch
54+
):
55+
from rasa.test import _get_sanitized_model_directory
56+
57+
monkeypatch_get_latest_model(tmp_path, monkeypatch)
58+
59+
# Input: anything that is not an existing file
60+
# => Should return input
61+
modeldir = "random_dir"
62+
assert not Path(modeldir).is_file()
63+
new_modeldir = _get_sanitized_model_directory(modeldir)
64+
captured = capsys.readouterr()
65+
assert not captured.out
66+
assert new_modeldir == modeldir

0 commit comments

Comments
 (0)