@@ -432,6 +432,42 @@ def test_train_core_success(
432
432
assert os .path .exists (os .path .join (model_path , "fingerprint.json" ))
433
433
434
434
435
+ def test_train_with_retrieval_events_success (rasa_app , default_stack_config ):
436
+ with ExitStack () as stack :
437
+ domain_file = stack .enter_context (
438
+ open ("data/test_domains/default_retrieval_intents.yml" )
439
+ )
440
+ config_file = stack .enter_context (open (default_stack_config ))
441
+ core_file = stack .enter_context (
442
+ open ("data/test_stories/stories_retrieval_intents.md" )
443
+ )
444
+ responses_file = stack .enter_context (open ("data/test_responses/default.md" ))
445
+ nlu_file = stack .enter_context (
446
+ open ("data/test_nlu/default_retrieval_intents.md" )
447
+ )
448
+
449
+ payload = dict (
450
+ domain = domain_file .read (),
451
+ config = config_file .read (),
452
+ stories = core_file .read (),
453
+ responses = responses_file .read (),
454
+ nlu = nlu_file .read (),
455
+ )
456
+
457
+ _ , response = rasa_app .post ("/model/train" , json = payload )
458
+ assert response .status == 200
459
+
460
+ # save model to temporary file
461
+ tempdir = tempfile .mkdtemp ()
462
+ model_path = os .path .join (tempdir , "model.tar.gz" )
463
+ with open (model_path , "wb" ) as f :
464
+ f .write (response .body )
465
+
466
+ # unpack model and ensure fingerprint is present
467
+ model_path = unpack_model (model_path )
468
+ assert os .path .exists (os .path .join (model_path , "fingerprint.json" ))
469
+
470
+
435
471
def test_train_missing_config (rasa_app : SanicTestClient ):
436
472
payload = dict (domain = "domain data" , config = None )
437
473
0 commit comments