Skip to content

Commit

Permalink
new change for metric
Browse files Browse the repository at this point in the history
  • Loading branch information
qbetterk committed Jan 28, 2019
1 parent 5ff6458 commit 1df5afd
Show file tree
Hide file tree
Showing 3 changed files with 354 additions and 44 deletions.
136 changes: 128 additions & 8 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
class _Config:
def __init__(self):
self._init_logging_handler()
self.cuda_device = 0
self.cuda_device = 7
self.eos_m_token = 'EOS_M'
self.beam_len_bonus = 0.5

Expand All @@ -21,7 +21,6 @@ def init_handler(self, m):
'tsdf-kvret':self._kvret_tsdf_init
}
init_method[m]()

def _camrest_tsdf_init(self):
self.beam_len_bonus = 0.5
self.prev_z_method = 'separate'
Expand All @@ -31,10 +30,94 @@ def _camrest_tsdf_init(self):
self.split = (3, 1, 1)
self.lr = 0.003
self.lr_decay = 0.5
self.vocab_path = './vocab/vocab-camrest.pkl'
self.data = './data/CamRest676/CamRest676.json'
self.entity = './data/CamRest676/CamRestOTGY.json'
self.db = './data/CamRest676/CamRestDB.json'
self.vocab_path = './vocab/vocab-simdial.pkl'
self.data = [
# camrest
'./data/CamRest676/CamRest676.json',
# restaurant
'../SimDial/data/restaurant-CleanSpec-1000.json',
# # restaurant noised
# '../SimDial/data/restaurant-MixSpec-1000.json',
# # bus
# '../SimDial/data/bus-MixSpec-1000.json',
# # weather
# '../SimDial/data/weather-MixSpec-1000.json',
# # movie
# '../SimDial/data/movie-MixSpec-1000.json',


# # restaurant Pitt
# '../SimDial/data/rest_pitt-MixSpec-1000.json',
# # restaurant style
# '../SimDial/data/restaurant_style-MixSpec-1000.json'
]
#
self.db = [
# camrest
'./data/CamRest676/CamRestDB.json',
# restaurant
'../SimDial/data/restaurant-CleanSpec-1000-DB.json',
# restaurant noised
'../SimDial/data/restaurant-MixSpec-1000-DB.json',
# bus
'../SimDial/data/bus-MixSpec-1000-DB.json',
# weather
'../SimDial/data/weather-MixSpec-1000-DB.json',
# movie
'../SimDial/data/movie-MixSpec-1000-DB.json',
# restaurant Pitt
'../SimDial/data/rest_pitt-MixSpec-1000-DB.json',
# restaurant style
'../SimDial/data/restaurant_style-MixSpec-1000-DB.json'
]

# self.entity = './data/CamRest676/CamRestOTGY.json'
# self.entity = '../SimDial/train/simdialOTGY.json'
self.entity = [
# camrest
'./data/CamRest676/CamRestOTGY.json',
# restaurant
'../SimDial/data/restaurant-CleanSpec-1000-OTGY.json',
# restaurant noised
'../SimDial/data/restaurant-MixSpec-1000-OTGY.json',
# bus
'../SimDial/data/bus-MixSpec-1000-OTGY.json',
# weather
'../SimDial/data/weather-MixSpec-1000-OTGY.json',
# movie
'../SimDial/data/movie-MixSpec-1000-OTGY.json',
# restaurant Pitt
'../SimDial/data/rest_pitt-MixSpec-1000-OTGY.json',
# restaurant style
'../SimDial/data/restaurant_style-MixSpec-1000-OTGY.json'
]

# # # added data for maml
# # restaurant
# self.data_sim1 = '../SimDial/data/restaurant-MixSpec-1000.json'
# self.db_sim1 = '../SimDial/data/restaurant-MixSpec-1000-DB.json'

# # restaurant style
# self.data_sim2 = '../SimDial/data/restaurant_style-MixSpec-1000.json'
# self.db_sim2 = '../SimDial/data/restaurant_style-MixSpec-1000-DB.json'

# # bus
# self.data_sim3 = '../SimDial/data/bus-MixSpec-1000.json'
# self.db_sim3 = '../SimDial/data/bus-MixSpec-1000-DB.json'

# # weather
# self.data_sim4 = '../SimDial/data/weather-MixSpec-1000.json'
# self.db_sim4 = '../SimDial/data/weather-MixSpec-1000-DB.json'

# # movie
# self.data_sim5 = '../SimDial/data/movie-MixSpec-1000.json'
# self.db_sim5 = '../SimDial/data/movie-MixSpec-1000-DB.json'

# # restaurant Pitt
# self.data_sim6 = '../SimDial/data/rest_pitt-MixSpec-1000.json'
# self.db_sim6 = '../SimDial/data/rest_pitt-MixSpec-1000-DB.json'


self.glove_path = './data/glove/glove.6B.50d.txt'
self.batch_size = 32
self.z_length = 8
Expand All @@ -48,8 +131,8 @@ def _camrest_tsdf_init(self):
self.max_ts = 40
self.early_stop_count = 3
self.new_vocab = True
self.model_path = './models/camrest.pkl'
self.result_path = './results/camrest-rl.csv'
self.model_path = './models/simdial1.pkl'
self.result_path = './results/simdial1.csv'
self.teacher_force = 100
self.beam_search = False
self.beam_size = 10
Expand All @@ -59,6 +142,43 @@ def _camrest_tsdf_init(self):
self.truncated = False
self.pretrain = False

# def _camrest_tsdf_init(self):
# self.beam_len_bonus = 0.5
# self.prev_z_method = 'separate'
# self.vocab_size = 800
# self.embedding_size = 50
# self.hidden_size = 50
# self.split = (3, 1, 1)
# self.lr = 0.003
# self.lr_decay = 0.5
# self.vocab_path = './vocab/vocab-camrest.pkl'
# self.data = './data/CamRest676/CamRest676.json'
# self.entity = './data/CamRest676/CamRestOTGY.json'
# self.db = './data/CamRest676/CamRestDB.json'
# self.glove_path = './data/glove/glove.6B.50d.txt'
# self.batch_size = 32
# self.z_length = 8
# self.degree_size = 5
# self.layer_num = 1
# self.dropout_rate = 0.5
# self.epoch_num = 100 # triggered by early stop
# self.rl_epoch_num = 2
# self.cuda = True
# self.spv_proportion = 100
# self.max_ts = 40
# self.early_stop_count = 3
# self.new_vocab = True
# self.model_path = './models/camrest.pkl'
# self.result_path = './results/camrest-rl.csv'
# self.teacher_force = 100
# self.beam_search = False
# self.beam_size = 10
# self.sampling = False
# self.unfrz_attn_epoch = 0
# self.skip_unsup = False
# self.truncated = False
# self.pretrain = False

def _kvret_tsdf_init(self):
self.prev_z_method = 'separate'
self.intent = 'all'
Expand Down
66 changes: 59 additions & 7 deletions metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import functools
import pickle
from reader import clean_replace
from config import global_config as cfg
import pdb

en_sws = set(stopwords.words())
wn = WordNetLemmatizer()
Expand Down Expand Up @@ -202,6 +204,7 @@ def __init__(self, result_path):
self.entity_dict = {}

def run_metrics(self):
# # to go over all the OTGY files.
raw_json = open('./data/CamRest676/CamRest676.json')
raw_entities = open('./data/CamRest676/CamRestOTGY.json')
raw_data = json.loads(raw_json.read().lower())
Expand All @@ -214,14 +217,54 @@ def run_metrics(self):
bleu_score = self.bleu_metric(data,'bleu')
success_f1 = self.success_f1_metric(data, 'success')
match = self.match_metric(data, 'match', raw_data=raw_data)
# #############################
# pdb.set_trace()
# #############################
self._print_dict(self.metric_dict)
return -success_f1[0]

def get_entities(self, entity_data):
    """Register the informable values of one ontology on this instance.

    ``entity_data`` is a dict whose ``'informable'`` entry maps a slot name
    to a list of values. Every value is appended to ``self.entities`` and
    recorded in ``self.entity_dict`` as ``value -> slot name``.
    """
    informable = entity_data['informable']
    for slot, values in informable.items():
        self.entities.extend(values)
        # Reverse lookup: which slot does a given value belong to.
        self.entity_dict.update({value: slot for value in values})
def run_metrics_maml(self, data_index=1):
    """Compute BLEU, success-F1 and match using every configured ontology.

    All ontology files listed in ``cfg.entity`` are loaded (lower-cased) and
    handed to ``self.get_entities`` as one list. The raw dialogues used by
    the match metric are read from ``cfg.data[data_index]``.

    Args:
        data_index: position in ``cfg.data`` of the raw dialogue file passed
            to ``match_metric``. Defaults to 1, the index this method used
            before it became a parameter.

    Returns:
        The negated first element of the success-F1 result.
    """
    # Load every ontology; `with` guarantees each handle is closed.
    raw_entities = []
    for entity_path in cfg.entity:
        with open(entity_path) as entity_file:
            raw_entities.append(json.loads(entity_file.read().lower()))
    self.get_entities(raw_entities)

    # Raw dialogues are lower-cased as a whole before parsing, like the
    # ontology files above.
    with open(cfg.data[data_index]) as raw_json:
        raw_data = json.loads(raw_json.read().lower())

    # Normalise both the reference and the generated response in place.
    data = self.read_result_data()
    for row in data:
        row['response'] = self.clean(row['response'])
        row['generated_response'] = self.clean(row['generated_response'])

    # The BLEU and match results are not used here; the calls are kept so the
    # metrics still run (presumably recording into self.metric_dict, which is
    # printed below -- confirm against the metric implementations).
    self.bleu_metric(data, 'bleu')
    success_f1 = self.success_f1_metric(data, 'success')
    self.match_metric(data, 'match', raw_data=raw_data)

    self._print_dict(self.metric_dict)
    return -success_f1[0]

def get_entities(self, entities_data):
    """Register the informable values of one or more ontologies.

    Args:
        entities_data: either a list of ontology dicts or a single ontology
            dict. Each ontology has an ``'informable'`` entry mapping a slot
            name to a list of values. A single dict is accepted so callers
            that pass one ontology (the single-dict form of this method)
            keep working.

    Every value is appended to ``self.entities`` and recorded in
    ``self.entity_dict`` as ``value -> slot name``. When the same value
    appears in several ontologies, the slot seen last wins.
    """
    # Iterating a bare dict would yield its keys (strings) and fail on the
    # 'informable' lookup below, so wrap a single ontology in a list.
    if isinstance(entities_data, dict):
        entities_data = [entities_data]

    for entity_data in entities_data:
        for slot, values in entity_data['informable'].items():
            self.entities.extend(values)
            for value in values:
                self.entity_dict[value] = slot

def _extract_constraint(self, z):
z = z.split()
Expand All @@ -240,7 +283,11 @@ def _extract_constraint(self, z):

def _extract_request(self, z):
z = z.split()
return set(z).intersection(['address', 'postcode', 'phone', 'area', 'pricerange','food'])
return set(z).intersection(['address', 'postcode', 'phone', 'area', 'pricerange','food',
'open', 'price', 'parking',
'duration', 'arrive_in',
'temperature', 'weather_type',
'rating', 'company', 'director'])

@report
def match_metric(self, data, sub='match',raw_data=None):
Expand Down Expand Up @@ -271,16 +318,21 @@ def match_metric(self, data, sub='match',raw_data=None):
for idx, w in enumerate(response_token):
if w.endswith('SLOT') and w != 'SLOT':
truth_response_req.append(w.split('_')[0])
# #############################
# pdb.set_trace()
# #############################
if not gen_cons:
gen_bspan = dial[-1]['generated_bspan']
gen_cons = self._extract_constraint(gen_bspan)
if truth_cons:
# #############################
# pdb.set_trace()
# #############################
if gen_cons == truth_cons:
match += 1
else:
print(gen_cons, truth_cons)
total += 1

return match / total, success / total

@report
Expand Down
Loading

0 comments on commit 1df5afd

Please sign in to comment.