Commit
Merge branch 'kangaroo'
# Conflicts:
#	easynlp/modelzoo/models/auto/configuration_auto.py
#	easynlp/modelzoo/models/auto/modeling_auto.py
#	easynlp/modelzoo/models/auto/tokenization_auto.py
Rhea committed Jul 31, 2022
2 parents b023b4d + 62b2205 commit 66a0ede
Showing 44 changed files with 7,094 additions and 17 deletions.
322 changes: 321 additions & 1 deletion easynlp/appzoo/language_modeling/data.py
@@ -20,6 +20,10 @@
from tracemalloc import start
import torch
import json
import pandas as pd
import numpy as np
import copy
import random  # used by kangaroo_create_mask below (assumed not already imported elsewhere in this file)

from ..dataset import BaseDataset
from ...modelzoo import AutoTokenizer
@@ -58,6 +61,8 @@ def __init__(self,

# DKPLM needs special tokens to recognize entity in input sentence
self.dkplm_model_prefix = True if 'dkplm' in pretrained_model_name_or_path else False
self.kangaroo_model_prefix = True if 'kangaroo' in pretrained_model_name_or_path else False

if self.dkplm_model_prefix:
entity_emb_file = user_defined_parameters.get('entity_emb_file', '')
rel_emb_file = user_defined_parameters.get('rel_emb_file', '')
@@ -83,6 +88,24 @@
self.entity_emb = entity_emb
self.rel_emb = rel_emb
self.tokenizer.add_special_tokens({'additional_special_tokens': ['[ENT]']})

if self.kangaroo_model_prefix:
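# KANGAROO setup (summary, inferred from the code below): a prefix trie over
# entity token-id sequences, tensors of per-entity contrastive samples, and a
# concept embedding table indexed by shifted entity id.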
entity_file = user_defined_parameters.get('entity_file', '')
rel_file = user_defined_parameters.get('rel_file', '')
CL_samples_file = user_defined_parameters.get('samples_file', '')
concept_emb_file = user_defined_parameters.get('concept_emb_file', '')
if entity_file == '' or rel_file == '':
raise ValueError('Kangaroo needs entity and relation knowledge files...')

rel_df = pd.read_csv(rel_file)
entity_df = pd.read_csv(entity_file)

# create entity tree
self.entity_tree, self.tokenid2entityid = self.kangaroo_create_entity_tree(entity_df)
self.tokenidVec, self.positionidVec = self.kangaroo_get_contrastive_samples(CL_samples_file)
self.conceptEmbVec = self.kangaroo_get_concept_emb(concept_emb_file)


def convert_single_row_to_example(self, row):
if self.dkplm_model_prefix:
@@ -95,6 +118,8 @@ def convert_single_row_to_example(self, row):
sentence_tokens, ent_pos = self.dkplm_row_data_process(text)
token_ids.extend(self.tokenizer.convert_tokens_to_ids(sentence_tokens))
ent_pos = [[item[0]+1, item[1]+1] for item in ent_pos]
elif self.kangaroo_model_prefix:
return self.kangaroo_row_data_process(row)
else:
text = json.loads(row.strip())['text']
token_ids = [self.cls_ids]
@@ -111,6 +136,38 @@
return token_ids, mask_labels, mask_span_indices

def batch_fn(self, batch):
if self.kangaroo_model_prefix:
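# Unzip the 9-tuple produced by kangaroo_row_data_process, then tensorize.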
input_ids = [t[0] for t in batch]
attention_mask = [t[1] for t in batch]
label_ids = [t[2] for t in batch]
entities_position = [t[3] for t in batch]
ent_mask = [t[4] for t in batch]
sample_token_id = [t[5] for t in batch]
sample_position_id = [t[6] for t in batch]
sample_mask = [t[7] for t in batch]
concept_emb = [t[8] for t in batch]

return {
'input_ids': torch.LongTensor(input_ids),
'attention_mask': torch.LongTensor(attention_mask),
'label_ids': torch.LongTensor(label_ids),
'entities_position': torch.LongTensor(entities_position),
'ent_mask': torch.LongTensor(ent_mask),
'sample_token_id': torch.LongTensor(sample_token_id),
'sample_position_id': torch.LongTensor(sample_position_id),
'sample_mask': torch.LongTensor(sample_mask),
'concept_emb': torch.FloatTensor(concept_emb)  # float-valued embeddings; LongTensor would truncate them
}
token_ids = [t[0] for t in batch]
mask_labels = [t[1] for t in batch]
lengths = [len(t[0]) for t in batch]
@@ -323,4 +380,267 @@ def align_dkplm_input(self, max_seq_len, token_ids, ent_pos, relation_id, replac
# replaced entity = entity + relation (TransE)
padded_replaced_entity_emb = padded_entity_emb + replaced_padded_rel_emb

return padded_insert_know_position_mask, padded_replaced_entity_emb, padded_rel_emb, padded_insert_know_labels
return padded_insert_know_position_mask, padded_replaced_entity_emb, padded_rel_emb, padded_insert_know_labels

def kangaroo_row_data_process(self, text, entity_num=3, entity_gap=5):
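# Tokenizes a raw row character by character, locates up to `entity_num`
# entities via the prefix trie (keeping at least `entity_gap` tokens between
# them), applies entity-aware masking, and pads everything to max_seq_length.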

tokens = [t for t in text]  # character-level tokenization
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)

if len(token_ids) > self.max_seq_length - 2:
token_ids = token_ids[:(self.max_seq_length - 2)]

# entity position
entity_pos = []
i = 0
while i < len(token_ids):
search_result = self.entity_tree.search(token_ids, i)
if len(search_result) == 0:
i = i + 1
continue
j = search_result[-1]  # end of the longest entity match starting at i
entity_pos.append((i, j))
i = j + 1

# entity id list
entities = [-100 for _ in range(len(token_ids))]
# entity index list, transform ids to 1, 2, 3, ..., n
entities_position = [0 for _ in range(len(token_ids))]
entity_index = 0
entity_pos_true = []

entity_id_list = []
# store head/tail indices of each entity: [ent1_head_index, ent1_tail_index, ent2_head_index, ...]
entity_head_tail = []

for pos in entity_pos:
close_flag = False
h_index = pos[0]
t_index = pos[1]
# ensure the gap between entities is >= entity_gap (assumes entity_pos is sorted in ascending order)
for i in range(1, entity_gap+1):
if h_index - i < 0:
continue
if entities[h_index - i] != -100:
close_flag = True
if close_flag:
continue
entity_id = self.tokenid2entityid[str(token_ids[h_index:t_index])]
entity_index += 1
entity_pos_true.append(pos)

entity_id_list.append(entity_id)
entity_head_tail.extend([h_index, t_index - 1])

for ent_index in range(h_index, t_index):
entities[ent_index] = entity_id
entities_position[ent_index] = entity_index

if entity_index == entity_num:
break

if entity_index < entity_num:
for j in range(entity_num - entity_index):
entity_id_list.append(-1)
entity_head_tail.extend([-1, -1])

masked_tokens_id, masked_lm_labels = self.kangaroo_create_mask(token_ids, entity_pos_true)

input_mask = list((np.array(token_ids) != -1) * 1)

# input_ids = [self.cls_ids] + token_ids + [self.sep_ids]
masked_tokens_id = [self.cls_ids] + masked_tokens_id + [self.sep_ids]
entities = [-100] + entities + [-100]
# TODO: check whether position 0 of entities_position should stay 0
entities_position = [0] + entities_position + [0]
masked_lm_labels = [-100] + masked_lm_labels + [-100]
input_mask = [1] + input_mask + [1]

if len(masked_tokens_id) < self.max_seq_length:
rest = self.max_seq_length - len(masked_tokens_id)
masked_tokens_id.extend([0] * rest)
entities.extend([-100] * rest)
entities_position.extend([0] * rest)
masked_lm_labels.extend([-100] * rest)
input_mask.extend([0] * rest)

# sanity-check lengths after padding
assert len(masked_tokens_id) == len(entities_position)
assert len(entities_position) == len(masked_lm_labels)

entities_position = torch.LongTensor(entities_position)

ent_mask = torch.LongTensor((entities_position != 0) * 1)
entity_id_index = torch.LongTensor(entity_id_list) + 1
sample_token_id = self.tokenidVec[entity_id_index]
sample_position_id = self.positionidVec[entity_id_index]
sample_mask = torch.LongTensor((np.array(sample_token_id) != 0) * 1)
concept_emb = self.conceptEmbVec[entity_id_index]  # [entity_num, concept_size] for this row; batched in batch_fn

return masked_tokens_id, input_mask, masked_lm_labels, entities_position.tolist(), ent_mask.tolist(), sample_token_id.tolist(), sample_position_id.tolist(), sample_mask.tolist(), concept_emb.tolist()

def kangaroo_create_mask(self, tokens_id, entity_pos_true, entity_gap=5):
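# Masks whole entity spans first (capped at ~10% of the input), then applies
# BERT-style masking (80% [MASK] / 10% keep / 10% random) to tokens outside an
# entity_gap neighborhood of any entity.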

entity_prop = 0.1
masked_lm_labels = [-100 for _ in range(len(tokens_id))]
masked_tokens_id = copy.deepcopy(tokens_id)

input_len = len(tokens_id)
entities_length = np.sum([j - i for (i, j) in entity_pos_true])

while entities_length / input_len > entity_prop:
del entity_pos_true[random.randint(0, len(entity_pos_true) - 1)]
entities_length = np.sum([j - i for (i, j) in entity_pos_true])

entity_probability = entities_length / input_len
# tokens within entity_gap of an entity are excluded from token-level MLM, hence the correction term
mlm_token_probability = (self.mlm_mask_prop - entity_probability) * input_len / (input_len - 7 * len(entity_pos_true))

# entity masking
token_mlm_flag = [1 for _ in range(len(tokens_id))]
for po in entity_pos_true:
masked_lm_labels[po[0]:po[1]] = tokens_id[po[0]:po[1]]
masked_tokens_id[po[0]:po[1]] = [self.mask_idx] * (po[1] - po[0])
s_index = max(po[0] - entity_gap, 0)
e_index = min(po[1] + entity_gap, len(tokens_id))
token_mlm_flag[s_index: e_index] = [0] * (e_index - s_index)

# token masking
for ind in range(len(token_mlm_flag)):
if token_mlm_flag[ind] == 0:
continue

if random.random() > mlm_token_probability:
continue

if random.random() < 0.8:
masked_tokens_id[ind] = self.mask_idx  # 80%: replace with [MASK]
else:
if random.random() < 0.5:
masked_tokens_id[ind] = tokens_id[ind]  # 10%: keep original token
else:
masked_tokens_id[ind] = random.randint(0, self.vocab_size - 1)  # 10%: random token

masked_lm_labels[ind] = tokens_id[ind]

return masked_tokens_id, masked_lm_labels

def kangaroo_create_entity_tree(self, entity_df):
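# Maps every surface form in `name_list` to its entity index, then inserts the
# token-id sequence of each name into a trie; tokenid2entityid recovers the
# entity id from a matched token-id span.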
full_name_to_id = {}
for i in range(len(entity_df)):
# main_name may be NaN; all surface forms live in name_list
name_list = entity_df.iloc[i]['name_list'].split('|')
entity_id = int(entity_df.iloc[i]['index'])
for name in name_list:
full_name_to_id[name] = entity_id

entities = list(full_name_to_id.keys())
entities_tokens_id = []
tokenid2entityid = {}
for entity in entities:
entity_token_id = self.tokenizer.convert_tokens_to_ids([k for k in entity])
entities_tokens_id.append(entity_token_id)
tokenid2entityid[str(entity_token_id)] = full_name_to_id[entity]
entity_tree = KangarooTrieTree()
for word in entities_tokens_id:
entity_tree.add_word(word)
return entity_tree, tokenid2entityid

def kangaroo_get_contrastive_samples(self, samples_file, max_level=4):
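# Loads a dict of per-entity positive samples at `max_level` granularities and
# packs them into dense [max_index + 2, max_level, max_seq_length] tensors;
# index 0 is reserved for "no entity" (entity ids are shifted by +1).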
samples = np.load(samples_file, allow_pickle=True).item()
max_index = np.max(list(samples.keys()))
token_id_vec = [[[0 for _ in range(self.max_seq_length)] for _ in range(max_level)] for _ in
range(max_index + 2)]
pos_id_vec = [[[0 for _ in range(self.max_seq_length)] for _ in range(max_level)] for _ in range(max_index + 2)]
for ind in samples.keys():
try:
token_id_list = []
pos_id_list = []
for le in range(1, max_level + 1):
level = "level_%d" % le
if len(samples[ind][level]) == 0:
level = "level_2"
tokens = samples[ind][level][0]['tokens']
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
pos_ids = samples[ind][level][0]['position_id']
# assert len(token_ids) == len(pos_ids)

if len(token_ids) < self.max_seq_length:
token_ids.extend([0] * (self.max_seq_length - len(token_ids)))
pos_ids.extend([0] * (self.max_seq_length - len(pos_ids)))

token_id_list.append(token_ids)
pos_id_list.append(pos_ids)

token_id_vec[ind + 1] = token_id_list
pos_id_vec[ind + 1] = pos_id_list
except Exception:
continue

return torch.LongTensor(token_id_vec), torch.LongTensor(pos_id_vec)

def kangaroo_get_concept_emb(self, emb_file, dim=100):
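# Builds a dense [max_index + 2, dim] float table from the saved
# entity-to-embedding dict, with row 0 left as zeros for the "no entity" index.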
entity2emb = np.load(emb_file, allow_pickle=True).item()
max_index = np.max(list(entity2emb.keys()))
concept_emb_vec = [[0 for _ in range(dim)] for _ in range(int(max_index) + 2)]
for ind in entity2emb.keys():
concept_emb_vec[int(ind) + 1] = entity2emb[ind]
return torch.FloatTensor(concept_emb_vec)



class KangarooTrieTree:
"""
Prefix tree (trie) over entity token-id sequences for KANGAROO entity matching
"""
def __init__(self):
self.node = [""]  # node payloads; index 0 is the root
self.edge = [{}]  # per-node dict mapping token id -> child node index
self.flag = [False]  # True if the path from the root to this node spells a full entity

def add_node(self, node):
self.node.append(node)
self.edge.append({})
self.flag.append(False)
return len(self.node) - 1

def add_word(self, word):
u = 0
for i in word:
if i not in self.edge[u]:
self.edge[u][i] = self.add_node(i)
u = self.edge[u][i]
self.flag[u] = True

def show(self):
for i in range(len(self.node)):
print(i)
print(self.node[i])
print(self.edge[i])
print(self.flag[i])
print()

def search(self, sentence, start_position):
i = start_position
u = 0
result = []
while i < len(sentence) and sentence[i] in self.edge[u]:
u = self.edge[u][sentence[i]]
i += 1
if self.flag[u]:
result.append(i)
return result
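A minimal usage sketch of KangarooTrieTree (illustrative only, not part of the commit; the token ids below are made up):

tree = KangarooTrieTree()
tree.add_word([101, 102])       # entity "AB" as a token-id sequence
tree.add_word([101, 102, 103])  # entity "ABC"
print(tree.search([101, 102, 103, 104], 0))  # [2, 3]: end positions of both entities matching at position 0
print(tree.search([101, 102, 103, 104], 1))  # []: no entity starts at position 1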