|
11 | 11 | import numpy as np
|
12 | 12 | from aiohttp import ClientError
|
13 | 13 | from colorclass import Color
|
| 14 | +from rasa.nlu.training_data.loading import MARKDOWN, RASA |
14 | 15 | from sanic import Sanic, response
|
15 | 16 | from sanic.exceptions import NotFound
|
16 | 17 | from terminaltables import AsciiTable, SingleTable
|
@@ -680,7 +681,7 @@ def _request_export_info() -> Tuple[Text, Text, Text]:
|
680 | 681 | "merge learned data with previous training examples)",
|
681 | 682 | default=PATHS["nlu"],
|
682 | 683 | validate=io_utils.file_type_validator(
|
683 |
| - [".md"], |
| 684 | + [".md", ".json"], |
684 | 685 | "Please provide a valid export path for the NLU data, e.g. 'nlu.md'.",
|
685 | 686 | ),
|
686 | 687 | ),
|
@@ -831,21 +832,29 @@ async def _write_nlu_to_file(
|
831 | 832 |
|
832 | 833 | # need to guess the format of the file before opening it to avoid a read
|
833 | 834 | # in a write
|
834 |
| - if loading.guess_format(export_nlu_path) in {"md", "unk"}: |
835 |
| - fformat = "md" |
836 |
| - else: |
837 |
| - fformat = "json" |
838 |
| - |
839 |
| - if fformat == "md": |
| 835 | + nlu_format = _get_nlu_target_format(export_nlu_path) |
| 836 | + if nlu_format == MARKDOWN: |
840 | 837 | stringified_training_data = nlu_data.nlu_as_markdown()
|
841 | 838 | else:
|
842 | 839 | stringified_training_data = nlu_data.nlu_as_json()
|
843 | 840 |
|
844 | 841 | io_utils.write_text_file(stringified_training_data, export_nlu_path)
|
845 | 842 |
|
846 | 843 |
|
| 844 | +def _get_nlu_target_format(export_path: Text) -> Text: |
| 845 | + guessed_format = loading.guess_format(export_path) |
| 846 | + |
| 847 | + if guessed_format not in {MARKDOWN, RASA}: |
| 848 | + if export_path.endswith(".json"): |
| 849 | + guessed_format = RASA |
| 850 | + else: |
| 851 | + guessed_format = MARKDOWN |
| 852 | + |
| 853 | + return guessed_format |
| 854 | + |
| 855 | + |
847 | 856 | def _entities_from_messages(messages):
|
848 |
| - """Return all entities that occur in atleast one of the messages.""" |
| 857 | + """Return all entities that occur in at least one of the messages.""" |
849 | 858 | return list({e["entity"] for m in messages for e in m.data.get("entities", [])})
|
850 | 859 |
|
851 | 860 |
|
|
0 commit comments