import matplotlib.pyplot as plot
import os
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
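#################################################################################
## ir_evaluator.py
## Evaluates ranked retrieval results against relevance assessments: computes
## precision and recall per query, interpolates precision at the eleven standard
## recall levels, and plots the resulting precision/recall curves.
#################################################################################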
class IREvaluator(object):
"""description of class"""
#################################################################################
    ## @brief Constructor
    # @details This method initializes the class and evaluates every provided query.
    # @param relevance_docs Relevance assessments for each query in MED.QRY
    # @param ranking_query Ranking results (document id, score pairs) for each query
    # @param continue_eval Whether to run the evaluation or only display the final plots
    # @param only_query_id Query id to evaluate when a single query ranking is provided
#################################################################################
    def __init__(self, relevance_docs, ranking_query, continue_eval, only_query_id):
        self.relevance_docs = relevance_docs
        self.continue_eval = continue_eval
        query_id = 1
        if len(ranking_query) > 1:
            for q in ranking_query:  # iterate over the query keys of the ranking results
                print("\n-------------------------->Query = " + str(query_id))
                relevants_docs_query = self.get_total_relevant_docs(query_id)
                # keep only the retrieved documents with a positive similarity score
                ranking_query[q] = [doc for doc in ranking_query[q] if doc[1] > 0.0]
                self.evaluate_query(ranking_query[q], relevants_docs_query, query_id)
                query_id += 1
        else:
            print("\n-------------------------->Query = " + str(only_query_id))
            relevants_docs_query = self.get_total_relevant_docs(only_query_id)
            ranking_query[1] = [doc for doc in ranking_query[1] if doc[1] > 0.0]
            self.evaluate_query(ranking_query[1], relevants_docs_query, only_query_id)
#################################################################################
    ## @brief evaluate_query
    # @details This method computes the precision and recall for the provided query
    # @param ranking Ranking result for the query
    # @param relevants_docs_query Relevant documents for the query, keyed by query id
    # @param query_id Query id
#################################################################################
    def evaluate_query(self, ranking, relevants_docs_query, query_id):
        if self.continue_eval:
            # overall precision and recall for the full ranking
            [true_positives, false_positives] = self.relevant_doc_retrieved(query_id, ranking, relevants_docs_query)
            recall = self.get_recall(true_positives, len(relevants_docs_query[query_id]))
            precision = self.get_precision(true_positives, false_positives)
            print(" Precision: " + str(precision) + "\n")
            print(" Recall: " + str(recall) + "\n")
            # precision and recall after each retrieved document
            true_positives = 0
            false_positives = 0
            recall = []
            precision = []
            for doc in ranking:
                if str(doc[0]) in relevants_docs_query[query_id]:  # position 0 holds the document ID
                    true_positives += 1
                else:
                    false_positives += 1
                recall.append(self.get_recall(true_positives, len(relevants_docs_query[query_id])))
                precision.append(self.get_precision(true_positives, false_positives))
            recalls_levels = np.array([0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.])
            interpolated_precisions = self.interpolate_precisions(recall, precision, recalls_levels)
            self.plot_results(recalls_levels, interpolated_precisions)
        else:  # show the final results
            plot.show()
            plot.close()
        return
#################################################################################
    ## @brief relevant_doc_retrieved
    # @details This method counts how many retrieved documents are relevant (true positives)
    #          and how many are not (false positives) for a query.
    # @param query The id of the query
    # @param ranking Ranking result for the query
    # @param relevants_docs_query Relevant documents for the query, keyed by query id
#################################################################################
    def relevant_doc_retrieved(self, query, ranking, relevants_docs_query):
        true_positives = 0
        false_positives = 0
        for doc in ranking:
            if str(doc[0]) in relevants_docs_query[query]:  # position 0 holds the document ID
                true_positives += 1
            else:
                false_positives += 1
        return true_positives, false_positives
#################################################################################
    ## @brief get_total_relevant_docs
    # @details This method returns the relevant documents for a query.
    # @param query_id The id of the query
    # @return A dictionary mapping the query id to the list of relevant document IDs
#################################################################################
    def get_total_relevant_docs(self, query_id):
        relevants_query = dict()
        relevants_docs_query = []  # stores the relevant docs for a query
        for doc in self.relevance_docs:
            if doc[0] == query_id:  # position 0 contains the query ID
                relevants_docs_query.append(doc[2])  # position 2 contains the document ID
        relevants_query[query_id] = relevants_docs_query
        return relevants_query
#################################################################################
    ## @brief get_recall
    # @details A measure of the ability of a system to present all relevant items.
    # @param true_positives Number of relevant documents retrieved
    # @param real_true_positives Total number of documents that are actually relevant
#################################################################################
def get_recall(self,true_positives,real_true_positives):
recall=float(true_positives)/float(real_true_positives)
return recall
#################################################################################
## @brief get_precision
# @details A measure of the ability of a system to present only relevant items.
    # @param true_positives Number of relevant documents retrieved
    # @param false_positives Number of non-relevant documents retrieved
#################################################################################
    def get_precision(self, true_positives, false_positives):
        items_retrieved = true_positives + false_positives  # total number of documents retrieved
        precision = float(true_positives) / float(items_retrieved)
        return precision
#################################################################################
## @brief interpolate_precisions
# @details individual topic precision values are interpolated to
# a set of standard recall levels (0 to 1 in increments of .1)
    # @param recalls Recall values measured after each retrieved document
    # @param precisions Precision values measured after each retrieved document
    # @param recalls_levels The standard recall levels
#################################################################################
    def interpolate_precisions(self, recalls, precisions, recalls_levels):
        precisions_interpolated = np.zeros((len(recalls), len(recalls_levels)))
        i = 0
        while i < len(precisions):
            # use the max precision obtained for the topic at any actual recall level
            # greater than or equal to the standard recall level
            recalls_inter = np.where(recalls[i] >= recalls_levels)[0]
            for recall_id in recalls_inter:
                if precisions[i] > precisions_interpolated[i, recall_id]:
                    precisions_interpolated[i, recall_id] = precisions[i]
            i += 1
        mean_interpolated_precisions = np.mean(precisions_interpolated, axis=0)
        return mean_interpolated_precisions
#################################################################################
## @brief plot_results
# @details plot the result of evaluate each query
    # @param recall The standard recall levels
    # @param precision The interpolated precision at each recall level
#################################################################################
def plot_results(self,recall, precision):
plot.plot(recall, precision)
plot.xlabel('recall')
plot.ylabel('precision')
plot.draw()
plot.title('P/R curves')
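
#################################################################################
## Usage sketch (hypothetical; not part of the original module). The input
## formats below are assumptions inferred from how IREvaluator indexes its
## arguments: relevance_docs as (query_id, _, doc_id) tuples and ranking_query
## as a dict mapping a query key to (doc_id, score) pairs sorted by score.
#################################################################################
if __name__ == "__main__":
    # relevance assessments: query 1 considers documents 13 and 37 relevant
    relevance_docs = [(1, 0, "13"), (1, 0, "37"), (2, 0, "42")]
    # ranking produced by a retrieval model for query 1 (document id, similarity score)
    ranking_query = {1: [("13", 0.91), ("55", 0.40), ("37", 0.25), ("99", 0.0)]}
    # evaluate query 1 and draw its interpolated precision/recall curve
    IREvaluator(relevance_docs, ranking_query, True, 1)
    plot.show()  # display the accumulated precision/recall curves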