6
6
import numpy as np
7
7
import typing
8
8
from typing import Any
9
+ import os
9
10
10
11
from rasa_nlu .featurizers import Featurizer
11
12
from rasa_nlu .training_data import Message
@@ -49,15 +50,26 @@ def __init__(self, component_config=None):
49
50
# makes sure the name of the configuration is part of the config
50
51
# this is important for e.g. persistence
51
52
component_config ["name" ] = self .name
52
- print ("hi" )
53
53
self .component_config = config .override_defaults (
54
54
self .defaults , component_config )
55
55
56
56
self .partial_processing_pipeline = None
57
57
self .partial_processing_context = None
58
58
self .layer_indexes = [- 2 ]
59
- bert_config = modeling .BertConfig .from_json_file ("/Users/oakela/Documents/RASA/bert/uncased_L-24_H-1024_A-16/bert_config.json" )
60
- self .tokenizer = tokenization .FullTokenizer (vocab_file = "/Users/oakela/Documents/RASA/bert/uncased_L-24_H-1024_A-16/vocab.txt" , do_lower_case = True )
59
+
60
+ model_dir = component_config .get ("model_dir" )
61
+ print ("Loading model from" , model_dir )
62
+
63
+ dir_files = os .listdir (model_dir )
64
+
65
+ if all (file not in dir_files for file in ('bert_config.json' , 'vocab.txt' )):
66
+ raise Exception ("To use BertFeaturizer you need to specify a "
67
+ "directory path to a pre-trained model, i.e. "
68
+ "containing the files 'bert_config.json', "
69
+ "'vocab.txt' and model checkpoint" )
70
+
71
+ bert_config = modeling .BertConfig .from_json_file (os .path .join (model_dir , "bert_config.json" ))
72
+ self .tokenizer = tokenization .FullTokenizer (vocab_file = os .path .join (model_dir , "vocab.txt" ), do_lower_case = True )
61
73
is_per_host = tf .contrib .tpu .InputPipelineConfig .PER_HOST_V2
62
74
run_config = tf .contrib .tpu .RunConfig (
63
75
master = None ,
@@ -66,7 +78,7 @@ def __init__(self, component_config=None):
66
78
per_host_input_for_training = is_per_host ))
67
79
model_fn = model_fn_builder (
68
80
bert_config = bert_config ,
69
- init_checkpoint = "/Users/oakela/Documents/RASA/bert/uncased_L-24_H-1024_A-16/ bert_model.ckpt" ,
81
+ init_checkpoint = os . path . join ( model_dir , " bert_model.ckpt") ,
70
82
layer_indexes = self .layer_indexes ,
71
83
use_tpu = False ,
72
84
use_one_hot_embeddings = False )
0 commit comments