Skip to content

Commit 795b605

Browse files
committed
add configurable model_dir param
1 parent 1d2160c commit 795b605

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

rasa_nlu/featurizers/bert_featurizer.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import typing
88
from typing import Any
9+
import os
910

1011
from rasa_nlu.featurizers import Featurizer
1112
from rasa_nlu.training_data import Message
@@ -49,15 +50,26 @@ def __init__(self, component_config=None):
4950
# makes sure the name of the configuration is part of the config
5051
# this is important for e.g. persistence
5152
component_config["name"] = self.name
52-
print("hi")
5353
self.component_config = config.override_defaults(
5454
self.defaults, component_config)
5555

5656
self.partial_processing_pipeline = None
5757
self.partial_processing_context = None
5858
self.layer_indexes = [-2]
59-
bert_config = modeling.BertConfig.from_json_file("/Users/oakela/Documents/RASA/bert/uncased_L-24_H-1024_A-16/bert_config.json")
60-
self.tokenizer = tokenization.FullTokenizer(vocab_file="/Users/oakela/Documents/RASA/bert/uncased_L-24_H-1024_A-16/vocab.txt", do_lower_case=True)
59+
60+
model_dir = component_config.get("model_dir")
61+
print("Loading model from", model_dir)
62+
63+
dir_files = os.listdir(model_dir)
64+
65+
if all(file not in dir_files for file in ('bert_config.json', 'vocab.txt')):
66+
raise Exception("To use BertFeaturizer you need to specify a "
67+
"directory path to a pre-trained model, i.e. "
68+
"containing the files 'bert_config.json', "
69+
"'vocab.txt' and model checkpoint")
70+
71+
bert_config = modeling.BertConfig.from_json_file(os.path.join(model_dir, "bert_config.json"))
72+
self.tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"), do_lower_case=True)
6173
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
6274
run_config = tf.contrib.tpu.RunConfig(
6375
master=None,
@@ -66,7 +78,7 @@ def __init__(self, component_config=None):
6678
per_host_input_for_training=is_per_host))
6779
model_fn = model_fn_builder(
6880
bert_config=bert_config,
69-
init_checkpoint="/Users/oakela/Documents/RASA/bert/uncased_L-24_H-1024_A-16/bert_model.ckpt",
81+
init_checkpoint=os.path.join(model_dir, "bert_model.ckpt"),
7082
layer_indexes=self.layer_indexes,
7183
use_tpu=False,
7284
use_one_hot_embeddings=False)

sample_configs/config_bert.yml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,6 @@
1+
language: "en"
2+
3+
pipeline:
4+
- name: "intent_featurizer_bert"
5+
model_dir: "/path/ending/with/uncased_L-24_H-1024_A-16/"
6+
- name: "intent_classifier_sklearn"

0 commit comments

Comments
 (0)