Skip to content

Commit 6b9016c

Browse files
authored
Merge pull request RasaHQ#4696 from RasaHQ/export-nlu-as-json-from-interactive
Export nlu as json from interactive
2 parents (dc43187 + c8146d5), commit 6b9016c

File tree

3 files changed: 29 additions and 8 deletions.

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ Removed
2323

2424
Fixed
2525
-----
26+
- Fixed exporting NLU training data in ``json`` format from ``rasa interactive``
2627

2728
[1.4.3] - 2019-10-29
2829
^^^^^^^^^^^^^^^^^^^^

rasa/core/training/interactive.py

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212
from aiohttp import ClientError
1313
from colorclass import Color
14+
from rasa.nlu.training_data.loading import MARKDOWN, RASA
1415
from sanic import Sanic, response
1516
from sanic.exceptions import NotFound
1617
from terminaltables import AsciiTable, SingleTable
@@ -680,7 +681,7 @@ def _request_export_info() -> Tuple[Text, Text, Text]:
680681
"merge learned data with previous training examples)",
681682
default=PATHS["nlu"],
682683
validate=io_utils.file_type_validator(
683-
[".md"],
684+
[".md", ".json"],
684685
"Please provide a valid export path for the NLU data, e.g. 'nlu.md'.",
685686
),
686687
),
@@ -831,21 +832,29 @@ async def _write_nlu_to_file(
831832

832833
# need to guess the format of the file before opening it to avoid a read
833834
# in a write
834-
if loading.guess_format(export_nlu_path) in {"md", "unk"}:
835-
fformat = "md"
836-
else:
837-
fformat = "json"
838-
839-
if fformat == "md":
835+
nlu_format = _get_nlu_target_format(export_nlu_path)
836+
if nlu_format == MARKDOWN:
840837
stringified_training_data = nlu_data.nlu_as_markdown()
841838
else:
842839
stringified_training_data = nlu_data.nlu_as_json()
843840

844841
io_utils.write_text_file(stringified_training_data, export_nlu_path)
845842

846843

844+
def _get_nlu_target_format(export_path: Text) -> Text:
845+
guessed_format = loading.guess_format(export_path)
846+
847+
if guessed_format not in {MARKDOWN, RASA}:
848+
if export_path.endswith(".json"):
849+
guessed_format = RASA
850+
else:
851+
guessed_format = MARKDOWN
852+
853+
return guessed_format
854+
855+
847856
def _entities_from_messages(messages):
848-
"""Return all entities that occur in atleast one of the messages."""
857+
"""Return all entities that occur in at least one of the messages."""
849858
return list({e["entity"] for m in messages for e in m.data.get("entities", [])})
850859

851860

tests/core/test_interactive.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,14 @@
11
import json
2+
from typing import Text
3+
24
import pytest
35
import uuid
46
from aioresponses import aioresponses
57

68
import rasa.utils.io
79
from rasa.core.events import BotUttered
810
from rasa.core.training import interactive
11+
from rasa.nlu.training_data.loading import RASA, MARKDOWN
912
from rasa.utils.endpoints import EndpointConfig
1013
from rasa.core.actions.action import default_actions
1114
from rasa.core.domain import Domain
@@ -343,3 +346,11 @@ async def test_filter_intents_before_save_nlu_file():
343346
msgs.append(Message("/" + choice(intents), greet))
344347

345348
assert test_msgs == interactive._filter_messages(msgs)
349+
350+
351+
@pytest.mark.parametrize(
352+
"path, expected_format",
353+
[("bla.json", RASA), ("other.md", MARKDOWN), ("unknown", MARKDOWN)],
354+
)
355+
def test_get_nlu_target_format(path: Text, expected_format: Text):
356+
assert interactive._get_nlu_target_format(path) == expected_format

0 commit comments

Comments
 (0)