Added some annotation scripts
AndrewJGaut committed Mar 18, 2023
1 parent 96813fe commit 5e10174
Showing 11 changed files with 542 additions and 54 deletions.
Binary file added .DS_Store
Binary file not shown.
6 changes: 2 additions & 4 deletions new-contrast/data_provider/pretrain_datamodule.py
@@ -14,19 +14,17 @@ def __init__(
num_workers: int = 0,
batch_size: int = 256,
root: str = 'data/',
- text_max_len: int = 128,
+ text_max_len: int = 512,
graph_aug1: str = 'dnodes',
graph_aug2: str = 'subgraph',
sampling_type: str = 'random',
- sampling_temp: float = .1,
- sampling_eps: float = .5,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.batch_size = batch_size
self.num_workers = num_workers
- self.dataset = GINPretrainDataset(root, text_max_len, graph_aug1, graph_aug2, sampling_type, sampling_temp, sampling_eps)
+ self.dataset = GINPretrainDataset(root, text_max_len, graph_aug1, graph_aug2, sampling_type)

def setup(self, stage: str = None):
self.train_dataset = self.dataset
56 changes: 7 additions & 49 deletions new-contrast/data_provider/pretrain_dataset.py
@@ -78,7 +78,7 @@ def __getitem__(self, index):
break
# print(text_list)
if len(text_list) < 2:
- two_text_list = [text_list[0], text_list[0][-self.text_max_len:]]
+ two_text_list = [text_list[0], text_list[0][self.text_max_len:2*self.text_max_len]]
else:
if self.sampling_type == SamplingType.Random:
two_text_list = random.sample(text_list, 2)
@@ -96,23 +96,11 @@

text_list.clear()

- # # load and process text
- # text_path = os.path.join(self.root, 'text', text_name)
- # with open(text_path, 'r', encoding='utf-8') as f:
- #     text_list = f.readlines()
- # f.close()
- # # print(text_list)
- # if len(text_list) < 2:
- #     two_text_list = [text_list[0], text_list[0][-self.text_max_len:]]
- # else:
- #     two_text_list = random.sample(text_list, 2)
- # text_list.clear()
-
- # print(random.sample([1,2,3,4,5,6,7,8,9,0,11,12,13,14,15,18],2))
- if len(two_text_list[0]) > 256:
-     two_text_list[0] = two_text_list[0][:256]
- if len(two_text_list[1]) > 256:
-     two_text_list[1] = two_text_list[1][:256]
+ # Don't truncate the text to 256 tokens anymore!
+ # if len(two_text_list[0]) > 256:
+ #     two_text_list[0] = two_text_list[0][:256]
+ # if len(two_text_list[1]) > 256:
+ #     two_text_list[1] = two_text_list[1][:256]
text1, mask1 = self.tokenizer_text(two_text_list[0])
text2, mask2 = self.tokenizer_text(two_text_list[1])

@@ -199,37 +187,7 @@ def tokenizer_text(self, text):


if __name__ == '__main__':
- # mydataset = GraphTextDataset()
- # train_loader = torch_geometric.loader.DataLoader(
- #     mydataset,
- #     batch_size=16,
- #     shuffle=True,
- #     num_workers=4
- # )
- # for i, (aug1, aug2, text1, mask1, text2, mask2) in enumerate(train_loader):
- #     print(aug1.edge_index.shape)
- #     print(aug1.x.shape)
- #     print(aug1.ptr.size(0))
- #     print(aug2.edge_index.dtype)
- #     print(aug2.x.dtype)
- #     print(text1.shape)
- #     print(mask1.shape)
- #     print(text2.shape)
- #     print(mask2.shape)
- # mydataset = GraphormerPretrainDataset(root='data/', text_max_len=128, graph_aug1='dnodes', graph_aug2='subgraph')
- # from functools import partial
- # from data_provider.collator import collator_text
- # train_loader = torch.utils.data.DataLoader(
- #     mydataset,
- #     batch_size=8,
- #     num_workers=4,
- #     collate_fn=partial(collator_text,
- #                        max_node=128,
- #                        multi_hop_max_dist=5,
- #                        spatial_pos_max=1024),
- # )
- # aug1, aug2, text1, mask1, text2, mask2 = mydataset[0]
- mydataset = GINPretrainDataset(root='data/', text_max_len=128, graph_aug1='dnodes', graph_aug2='subgraph')
+ mydataset = GINPretrainDataset(root='data/', text_max_len=512, graph_aug1='dnodes', graph_aug2='subgraph')
train_loader = torch_geometric.loader.DataLoader(
mydataset,
batch_size=32,
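
Note on the __getitem__ fallback above: when a molecule has only one description paragraph, the second text view used to be the last text_max_len characters of the paragraph, which duplicates the first view whenever the paragraph is short; it is now the disjoint second text_max_len-character chunk. A minimal sketch of the two slicings, with a made-up paragraph and length purely for illustration:

# Illustrative sketch only; `text` and `text_max_len` are hypothetical stand-ins.
text = "A" * 300 + "B" * 300   # one description paragraph of 600 characters
text_max_len = 256

old_second_view = text[-text_max_len:]                 # old: tail of the paragraph
new_second_view = text[text_max_len:2 * text_max_len]  # new: next disjoint chunk

# Note: the new slice is empty when len(text) <= text_max_len,
# so the pair degenerates to (text, "") for short paragraphs.
print(old_second_view[:8], new_second_view[:8])
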
Binary file added text-preprocess/.DS_Store
Binary file not shown.
58 changes: 58 additions & 0 deletions text-preprocess/annotations.txt
@@ -0,0 +1,58 @@
0 1
1 1
0 1
0 0
0 1
1 1
0 1
1 1
0 1
0 0
1 1
0 0
1 1
0 0
1 0
1 1
1 0
0 0
1 1
1 1
1 1
1 1
1 1
1 1
1 0
1 1
1 1
1 1
0 0
1 1
1 1
0 0
1 1
1 1
0 0
0 1
1 1
0 1
0 0
1 1
1 0
1 1
0 0
1 1
0 0
0 0
1 0
1 1
1 1
1 1
0 0
0 0
1 1
1 1
1 1
1 0
0 0
1 0
23 changes: 23 additions & 0 deletions text-preprocess/annotator_agreement.py
@@ -0,0 +1,23 @@
import numpy as np

def cohens_kappa(annotation1, annotation2):
    # Observed agreement: fraction of items on which the two annotators agree.
    p_o = np.sum(annotation1 == annotation2) / annotation1.shape[0]
    # Chance agreement: probability both say yes plus probability both say no,
    # assuming independent annotators with the observed label rates.
    p_e_yes = np.sum(annotation1 == 1) / annotation1.shape[0] * np.sum(annotation2 == 1) / annotation2.shape[0]
    p_e_no = np.sum(annotation1 == 0) / annotation1.shape[0] * np.sum(annotation2 == 0) / annotation2.shape[0]
    p_e = p_e_yes + p_e_no

    return (p_o - p_e) / (1 - p_e)

def preprocess(file_path):
    # Each line holds two whitespace-separated binary labels, one per annotator.
    annotation1 = list()
    annotation2 = list()
    with open(file_path, 'r') as f:  # context manager so the file is closed
        for line in f:
            a1, a2 = line.split()
            annotation1.append(int(a1))
            annotation2.append(int(a2))
    return np.array(annotation1), np.array(annotation2)

if __name__ == '__main__':
    annotation1, annotation2 = preprocess('./annotations.txt')
    print(cohens_kappa(annotation1, annotation2))
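
As a sanity check on annotator_agreement.py, kappa = (p_o - p_e) / (1 - p_e) can be compared against scikit-learn's cohen_kappa_score on a toy pair of annotations (a hypothetical check, not part of this commit; assumes scikit-learn is installed and cohens_kappa from the script above is in scope):

import numpy as np
from sklearn.metrics import cohen_kappa_score  # reference implementation

a1 = np.array([1, 0, 1, 1, 0, 1, 0, 0])
a2 = np.array([1, 0, 0, 1, 0, 1, 1, 0])

# Here p_o = 6/8 and p_e = 0.5, so both lines should print 0.5.
print(cohens_kappa(a1, a2))
print(cohen_kappa_score(a1, a2))
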
59 changes: 59 additions & 0 deletions text-preprocess/cosine_sim/eval_data.csv
@@ -0,0 +1,59 @@
1=relevant,Paragraph,Romain,Andrew,AND,OR,agreement,Ground truth label,,cosine_mean,cosine_max,cosine_sent,m,temp,exp,sum exp,softmax,epsilon,filter,odds,class, ,TP,FP,TN,FN
0=irrelevant,0,0,1,0,1,0,0,,0.711683452,0.691772163,0.53474313,,0.05,1519161.419,21380662.64,0.071053056,0.1,0.071053056,4.121077245,1,,0,1,0,0
,1,1,1,1,1,1,1,,0.652119219,0.661283016,0.49145186,,,461567.87,,0.021588099,,0.021588099,1.252109764,1,,1,0,0,0
,2,0,1,0,1,0,0,,0.541547537,0.704120457,0.376356512,,,50561.75488,,0.002364836,,0.002364836,0.137160472,1,,0,1,0,0
,3,0,0,0,0,1,0,,0.563260973,0.615489125,0.396793485,,,78058.93911,,0.003650913,,0.003650913,0.211752954,1,,0,1,0,0
,4,0,1,0,1,0,0,,0.54852593,0.726957381,0.354626387,,,58134.73362,,0.002719033,,0.002719033,0.157703931,1,,0,1,0,0
,5,1,1,1,1,1,1,,0.556598067,0.704044521,0.374056995,,,68320.2403,,0.003195422,,0.003195422,0.185334477,1,,1,0,0,0
,6,0,1,0,1,0,0,,0.661053121,0.660466313,0.42252481,,,551867.0373,,0.025811503,,0.025811503,1.497067172,1,,0,1,0,0
,7,1,1,1,1,1,1,,0.635977686,0.651206374,0.452100158,,,334219.6562,,0.015631866,,0.015631866,0.906648236,1,,1,0,0,0
,8,0,1,0,1,0,0,,0.598599017,0.757220805,0.387032747,,,158257.7534,,0.007401911,,0.007401911,0.429310815,1,,0,1,0,0
,9,0,0,0,0,1,0,,0.402815789,0.660552681,0.270457596,,,3153.649964,,0.0001475,,FALSE,0,0,,0,0,1,0
,10,1,1,1,1,1,1,,0.635977805,0.651206493,0.452100456,,,334220.453,,0.015631903,,0.015631903,0.906650397,1,,1,0,0,0
,11,0,0,0,0,1,0,,0.601019383,0.643923938,0.71905607,,,166107.0351,,0.007769031,,0.007769031,0.45060381,1,,0,1,0,0
,12,1,1,1,1,1,1,,0.667933941,0.667449594,0.438607037,,,633286.8878,,0.02961961,,0.02961961,1.717937377,1,,1,0,0,0
,13,0,0,0,0,1,0,,0.707706749,0.642222703,0.549970448,,,1403016.317,,0.065620806,,0.065620806,3.806006751,1,,0,1,0,0
,14,1,0,0,1,0,0,,0.714639187,0.719264984,0.508985817,,,1611673.66,,0.075379968,,0.075379968,4.372038127,1,,0,1,0,0
,15,1,1,1,1,1,1,,0.640178919,0.677092791,0.390590191,,,363515.9219,,0.017002089,,0.017002089,0.986121143,1,,1,0,0,0
,16,1,0,0,1,0,0,,0.517401457,0.584454536,0.364392459,,,31195.50013,,0.001459052,,FALSE,0,0,,0,0,1,0
,17,0,0,0,0,1,0,,0.461332709,0.613009691,0.303853691,,,10164.47597,,0.000475405,,FALSE,0,0,,0,0,1,0
,18,1,1,1,1,1,1,,0.629296422,0.748599708,0.405848563,,,292414.7471,,0.013676599,,0.013676599,0.793242736,1,,1,0,0,0
,19,1,1,1,1,1,1,,0.481073678,0.62320292,0.309012026,,,15085.26247,,0.000705556,,FALSE,0,0,,0,0,0,1
,20,1,1,1,1,1,1,,0.489850611,0.687767685,0.321470231,,,17979.9443,,0.000840944,,FALSE,0,0,,0,0,0,1
,21,1,1,1,1,1,1,,0.559232175,0.619754374,0.380360395,,,72015.99362,,0.003368277,,0.003368277,0.195360064,1,,1,0,0,0
,22,1,1,1,1,1,1,,0.559231639,0.619754195,0.380359352,,,72015.22098,,0.003368241,,0.003368241,0.195357968,1,,1,0,0,0
,23,1,1,1,1,1,1,,0.668402195,0.699777365,0.514539421,,,639245.5293,,0.029898303,,0.029898303,1.734101572,1,,1,0,0,0
,24,1,0,0,1,0,0,,0.649952114,0.678212225,0.460702866,,,441989.8834,,0.020672413,,0.020672413,1.198999941,1,,0,1,0,0
,25,1,1,1,1,1,1,,0.547468901,0.591249287,0.76734829,,,56918.63174,,0.002662155,,0.002662155,0.154404973,1,,1,0,0,0
,26,1,1,1,1,1,1,,0.556804955,0.730939925,0.368707657,,,68603.51834,,0.003208671,,0.003208671,0.186102935,1,,1,0,0,0
,27,1,1,1,1,1,1,,0.678954005,0.761044383,0.438413709,,,789440.7835,,0.036923121,,0.036923121,2.14154099,1,,1,0,0,0
,28,0,0,0,0,1,0,,0.593252063,0.640731156,0.391485512,,,142207.3177,,0.006651212,,0.006651212,0.38577029,1,,0,1,0,0
,29,1,1,1,1,1,1,,0.644770622,0.654364169,0.411493927,,,398479.9431,,0.018637399,,0.018637399,1.080969149,1,,1,0,0,0
,30,1,1,1,1,1,1,,0.643391848,0.646993518,0.414214224,,,387641.7829,,0.018130485,,0.018130485,1.051568129,1,,1,0,0,0
,31,0,0,0,0,1,0,,0.61413604,0.637862206,0.468578666,,,215932.428,,0.010099426,,0.010099426,0.585766729,1,,0,1,0,0
,32,1,1,1,1,1,1,,0.555880368,0.623122633,0.352636159,,,67346.57691,,0.003149883,,0.003149883,0.18269319,1,,1,0,0,0
,33,1,1,1,1,1,1,,0.53133595,0.603443861,0.292649657,,,41221.65475,,0.001927988,,0.001927988,0.111823287,1,,1,0,0,0
,34,0,0,0,0,1,0,,0.514105856,0.583254576,0.35490787,,,29205.64024,,0.001365984,,FALSE,0,0,,0,0,1,0
,35,0,1,0,1,0,0,,0.635977805,0.651206493,0.452100456,,,334220.453,,0.015631903,,0.015631903,0.906650397,1,,0,1,0,0
,36,1,1,1,1,1,1,,0.624220669,0.643128872,0.432253361,,,264187.2432,,0.012356364,,0.012356364,0.716669093,1,,1,0,0,0
,37,0,1,0,1,0,0,,0.635977924,0.651206493,0.452100456,,,334221.2499,,0.015631941,,0.015631941,0.906652559,1,,0,1,0,0
,38,0,0,0,0,1,0,,0.624511123,0.744432449,0.404354692,,,265726.3913,,0.012428352,,0.012428352,0.720844389,1,,0,1,0,0
,39,1,1,1,1,1,1,,0.670833111,0.643685579,0.532360137,,,671092.4703,,0.031387824,,0.031387824,1.820493776,1,,1,0,0,0
,40,1,0,0,1,0,0,,0.613161564,0.635882437,0.382267803,,,211764.7521,,0.009904499,,0.009904499,0.574460943,1,,0,1,0,0
,41,1,1,1,1,1,1,,0.736541033,0.665651023,0.583013415,,,2497549.617,,0.11681348,,0.11681348,6.775181863,1,,1,0,0,0
,42,0,0,0,0,1,0,,0.694615483,0.694565594,0.426671594,,,1079825.126,,0.050504755,,0.050504755,2.929275783,1,,0,1,0,0
,43,1,1,1,1,1,1,,0.53415668,0.697547913,0.312480688,,,43614.00538,,0.002039881,,0.002039881,0.118313092,1,,1,0,0,0
,44,0,0,0,0,1,0,,0.602936447,0.645063579,0.473472744,,,172599.4607,,0.00807269,,0.00807269,0.468216018,1,,0,1,0,0
,45,0,0,0,0,1,0,,0.499577224,0.70531559,0.34621048,,,21841.00587,,0.001021531,,FALSE,0,0,,0,0,1,0
,46,1,0,0,1,0,0,,0.539637208,0.611505687,0.370161563,,,48666.40132,,0.002276188,,0.002276188,0.132018887,1,,0,1,0,0
,47,1,1,1,1,1,1,,0.571132243,0.617261529,0.411967069,,,91367.47618,,0.00427337,,0.00427337,0.247855443,1,,1,0,0,0
,48,1,1,1,1,1,1,,0.499337047,0.574595571,0.30984509,,,21736.34334,,0.001016636,,FALSE,0,0,,0,0,0,1
,49,1,1,1,1,1,1,,0.647160649,0.702236056,0.506495297,,,417990.0829,,0.019549912,,0.019549912,1.133894922,1,,1,0,0,0
,50,0,0,0,0,1,0,,0.725583017,0.776126862,0.469772399,,,2006014.222,,0.093823763,,0.093823763,5.441778248,1,,0,1,0,0
,51,0,0,0,0,1,0,,0.622156799,0.648310721,0.405611992,,,253504.2778,,0.011856708,,0.011856708,0.687689075,1,,0,1,0,0
,52,1,1,1,1,1,1,,0.58341217,0.715707243,0.411937505,,,116802.9269,,0.005463017,,0.005463017,0.316854995,1,,1,0,0,0
,53,1,1,1,1,1,1,,0.577594161,0.639548183,0.418831676,,,103972.656,,0.00486293,,0.00486293,0.282049914,1,,1,0,0,0
,54,1,1,1,1,1,1,,0.616327167,0.651423156,0.373638332,,,225605.5346,,0.01055185,,0.01055185,0.612007272,1,,1,0,0,0
,55,1,0,0,1,0,0,,0.642798305,0.671691299,0.426578611,,,383067.3461,,0.017916533,,0.017916533,1.039158909,1,,0,1,0,0
,56,0,0,0,0,1,0,,0.553247511,0.607220411,0.368436098,,,63892.05053,,0.00298831,,0.00298831,0.173321987,1,,0,1,0,0
,57,1,0,0,1,0,0,,0.60133934,0.625541329,0.413838059,,,167173.388,,0.007818906,,0.007818906,0.453496539,1,,0,1,0,0
79 changes: 79 additions & 0 deletions text-preprocess/cosine_sim/intrinsic_eval.py
@@ -0,0 +1,79 @@
import pandas as pd
import numpy as np
import json

def recall(ground_truth, preds):
    return np.sum(np.logical_and(ground_truth == 1, preds == 1)) / np.sum(ground_truth == 1)

def precision(ground_truth, preds):
    return np.sum(np.logical_and(ground_truth == 1, preds == 1)) / np.sum(preds == 1)

def top_p(distribution, p=0.80):
    """
    Return predictions: the highest-probability items whose cumulative
    mass stays within p.
    """
    sorted_indices = np.argsort(distribution)[::-1]  # descending by probability
    cumsum_distrib = np.cumsum(distribution[sorted_indices])
    indices = sorted_indices[cumsum_distrib <= p]

    preds = np.zeros(distribution.shape[0])
    preds[indices] = 1
    return preds

def top_k(distribution, k=20):
    """
    Return predictions: the k highest-probability items.
    """
    sorted_indices = np.argsort(distribution)[::-1]  # descending by probability
    indices = sorted_indices[:k]

    preds = np.zeros(distribution.shape[0])
    preds[indices] = 1
    return preds

def eps(distribution, eps=0.1):
    """
    Return predictions: every item whose probability exceeds eps.
    """
    preds = np.zeros(distribution.shape[0])
    preds[distribution > eps] = 1
    return preds


data = pd.read_csv('./eval_data.csv')
distribution_types = ["cosine_mean", "cosine_max", "cosine_sent"]
sampling_functions = [top_p, top_k, eps]
values = [
    ((np.arange(9) + 1) * 0.1).tolist(),
    [10, 20, 30, 40, 50],
    [0.1, 0.05, 0.02, 0.01, 0.001]
]
ground_truth = np.array(data['Ground truth label'])

results = dict()

for distrib_type in distribution_types:
    scores = np.array(data[distrib_type])
    # Softmax over the cosine scores to get a probability distribution.
    distribution = np.exp(scores) / np.sum(np.exp(scores))

    if distrib_type not in results:
        results[distrib_type] = dict()
    for func, vals in zip(sampling_functions, values):
        if func.__name__ not in results[distrib_type]:
            results[distrib_type][func.__name__] = dict()

        for val in vals:
            preds = func(distribution, val)
            rec = recall(ground_truth, preds)
            prec = precision(ground_truth, preds)
            results[distrib_type][func.__name__][val] = {
                'precision': prec,
                'recall': rec
            }
print(json.dumps(results, indent=2))
with open('results.json', 'w') as f:
    f.write(json.dumps(results, indent=2))
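
A quick way to exercise the three selection functions above on a toy distribution (hypothetical numbers, not taken from eval_data.csv):

import numpy as np

# Toy softmax distribution over four paragraphs.
distribution = np.array([0.5, 0.3, 0.15, 0.05])

print(top_p(distribution, p=0.8))   # [1. 1. 0. 0.] -- cumulative mass 0.5, then 0.8
print(top_k(distribution, k=2))     # [1. 1. 0. 0.]
print(eps(distribution, eps=0.1))   # [1. 1. 1. 0.]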