diff --git a/rasa_core/train.py b/rasa_core/train.py index f7f9e1c72d3..a18aaff7390 100644 --- a/rasa_core/train.py +++ b/rasa_core/train.py @@ -5,9 +5,12 @@ import argparse import logging + from builtins import str from rasa_core.agent import Agent +from rasa_core.channels.console import ConsoleInputChannel +from rasa_core.interpreter import RasaNLUInterpreter, RegexInterpreter from rasa_core.policies.keras_policy import KerasPolicy from rasa_core.policies.memoization import MemoizationPolicy @@ -54,8 +57,7 @@ def create_argument_parser(): parser.add_argument( '--online', default=False, - action='store_const', - const=True, + action='store_true', help="enable online training") parser.add_argument( '--augmentation', @@ -66,16 +68,13 @@ def create_argument_parser(): def train_dialogue_model(domain_file, stories_file, output_path, - online, nlu_model_path, kwargs): + use_online_learning, nlu_model_path, kwargs): agent = Agent(domain_file, policies=[MemoizationPolicy(), KerasPolicy()]) - if online: - from rasa_core.channels.console import ConsoleInputChannel - from rasa_core.interpreter import RasaNLUInterpreter + if use_online_learning: if nlu_model_path: agent.interpreter = RasaNLUInterpreter(nlu_model_path) else: - from rasa_core.interpreter import RegexInterpreter agent.interpreter = RegexInterpreter() agent.train_online( stories_file,