Skip to content

Commit

Permalink
fix some typo
Browse files Browse the repository at this point in the history
  • Loading branch information
shizhediao committed Jul 18, 2018
1 parent cbb57e3 commit e565527
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,15 @@ def match_metric(self, data, sub='match',raw_data=None):
for dial_id in dials:
truth_req, gen_req = [], []
dial = dials[dial_id]
gen_bpsan, truth_cons, gen_cons = None, None, set()
gen_bspan, truth_cons, gen_cons = None, None, set()
truth_turn_num = -1
truth_response_req = []
for turn_num,turn in enumerate(dial):
if 'SLOT' in turn['generated_response']:
gen_bpsan = turn['generated_bpsan']
gen_cons = self._extract_constraint(gen_bpsan)
gen_bspan = turn['generated_bspan']
gen_cons = self._extract_constraint(gen_bspan)
if 'SLOT' in turn['response']:
truth_cons = self._extract_constraint(turn['bpsan'])
truth_cons = self._extract_constraint(turn['bspan'])
gen_response_token = turn['generated_response'].split()
response_token = turn['response'].split()
for idx, w in enumerate(gen_response_token):
Expand All @@ -272,8 +272,8 @@ def match_metric(self, data, sub='match',raw_data=None):
if w.endswith('SLOT') and w != 'SLOT':
truth_response_req.append(w.split('_')[0])
if not gen_cons:
gen_bpsan = dial[-1]['generated_bpsan']
gen_cons = self._extract_constraint(gen_bpsan)
gen_bspan = dial[-1]['generated_bspan']
gen_cons = self._extract_constraint(gen_bspan)
if truth_cons:
if gen_cons == truth_cons:
match += 1
Expand Down Expand Up @@ -413,30 +413,30 @@ def _get_entity_dict(self, entity_data):
self.entity_dict = entity_dict

@report
def match_rate_metric(self, data, sub='match',bpsans='./data/kvret/test.bpsan.pkl'):
def match_rate_metric(self, data, sub='match',bspans='./data/kvret/test.bspan.pkl'):
dials = self.pack_dial(data)
match,total = 0,1e-8
bpsan_data = pickle.load(open(bpsans,'rb'))
bspan_data = pickle.load(open(bspans,'rb'))
# find out the last placeholder and see whether that is correct
# if no such placeholder, see the final turn, because it can be a yes/no question or scheduling conversation
for dial_id in dials:
dial = dials[dial_id]
gen_bpsan, truth_cons, gen_cons = None, None, set()
gen_bspan, truth_cons, gen_cons = None, None, set()
truth_turn_num = -1
for turn_num,turn in enumerate(dial):
if 'SLOT' in turn['generated_response']:
gen_bpsan = turn['generated_bpsan']
gen_cons = self._extract_constraint(gen_bpsan)
gen_bspan = turn['generated_bspan']
gen_cons = self._extract_constraint(gen_bspan)
if 'SLOT' in turn['response']:
truth_cons = self._extract_constraint(turn['bpsan'])
truth_cons = self._extract_constraint(turn['bspan'])

# KVRET dataset includes "scheduling" (so often no SLOT decoded in ground truth)
if not truth_cons:
truth_bpsan = dial[-1]['bpsan']
truth_cons = self._extract_constraint(truth_bpsan)
truth_bspan = dial[-1]['bspan']
truth_cons = self._extract_constraint(truth_bspan)
if not gen_cons:
gen_bpsan = dial[-1]['generated_bpsan']
gen_cons = self._extract_constraint(gen_bpsan)
gen_bspan = dial[-1]['generated_bspan']
gen_cons = self._extract_constraint(gen_bspan)

if truth_cons:
if self.constraint_same(gen_cons, truth_cons):
Expand Down

0 comments on commit e565527

Please sign in to comment.