
Commit 11456b4

Prithivi Da authored and committed
Added support for RankT5
1 parent 474d9e2 commit 11456b4

File tree

4 files changed, +26 −13 lines changed

README.md

+8 −2

````diff
@@ -19,6 +19,7 @@ Ultra-lite & Super-fast Python library to add re-ranking to your existing se
 - Below is the list of models supported as of now.
 * ms-marco-TinyBERT-L-2-v2 (default)
 * ms-marco-MiniLM-L-12-v2
+* rank-T5-flan (Best non-cross-encoder reranker)
 * ms-marco-MultiBERT-L-12 (Multi-lingual, [supports 100+ languages](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages))
 
 - Why only sleeker models? Reranking is the final leg of larger retrieval pipelines; the idea is to avoid any extra overhead, especially in user-facing scenarios. To that end, we choose models with a really small footprint that don't need any specialised hardware yet offer competitive performance. Feel free to raise issues to add support for new models as you see fit.
@@ -42,7 +43,12 @@ ranker = Ranker(model_name="ms-marco-MiniLM-L-12-v2", cache_dir="/opt")
 
 or
 
-# Medium (~150MB), slower model with best performance (ranking precision) for 100+ languages including en.
+# Medium (~110MB), slower model with the best zero-shot performance (ranking precision) on out-of-domain data.
+ranker = Ranker(model_name="rank-T5-flan", cache_dir="/opt")
+
+or
+
+# Medium (~150MB), slower model with competitive performance (ranking precision) for 100+ languages (don't use it for English)
 ranker = Ranker(model_name="ms-marco-MultiBERT-L-12", cache_dir="/opt")
 ```
 
@@ -87,7 +93,7 @@ print(results)
 ## Deployment patterns
 #### How to use it in an AWS Lambda function?
 In AWS or other serverless environments the entire VM is read-only, so you might have to create your
-own custom dir and use it for loading the models (and eventually as a cache between warm calls). You can do it during init with the cache_dir parameter.
+own custom dir. You can do so in your Dockerfile and use it for loading the models (and eventually as a cache between warm calls). You can do it during init with the cache_dir parameter.
 
 ```python
 ranker = Ranker(model_name="ms-marco-MiniLM-L-12-v2", cache_dir="/opt")
````
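The Lambda hunk above is truncated right after the `Ranker` init. As a companion, here is a minimal sketch of the pattern it describes; the `handler` name, the event keys, and the writability of the cache dir are assumptions for illustration, not part of this commit:

```python
# Sketch of the serverless pattern described in the README hunk above, not code
# from this commit. Assumptions: the import path matches flashrank/Ranker.py, and
# /opt is writable because it was baked into the container image (otherwise use /tmp).
from flashrank.Ranker import Ranker

# Initialised at module load, so warm invocations reuse the downloaded model.
ranker = Ranker(model_name="ms-marco-MiniLM-L-12-v2", cache_dir="/opt")

def handler(event, context):
    # "query" and "passages" are hypothetical event keys for this example.
    return ranker.rerank(event["query"], event["passages"])
```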

flashrank/Config.py

+2 −1

```diff
@@ -4,5 +4,6 @@
 model_file_map = {
     "ms-marco-TinyBERT-L-2-v2": "flashrank-TinyBERT-L-2-v2.onnx",
     "ms-marco-MiniLM-L-12-v2": "flashrank-MiniLM-L-12-v2_Q.onnx",
-    "ms-marco-MultiBERT-L-12": "flashrank-MultiBERT-L12_Q.onnx"
+    "ms-marco-MultiBERT-L-12": "flashrank-MultiBERT-L12_Q.onnx",
+    "rank-T5-flan": "flashrank-rankt5_Q.onnx"
 }
```
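For context, a quick sketch of what the new entry resolves to, assuming `model_file_map` is a plain module-level dict importable from `flashrank.Config` as the file path suggests:

```python
# Assumption: model_file_map is a module-level dict in flashrank/Config.py.
from flashrank.Config import model_file_map

# The new key maps the RankT5 model name to its quantized ONNX export.
print(model_file_map["rank-T5-flan"])  # -> "flashrank-rankt5_Q.onnx"
```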

flashrank/Ranker.py

+15 −9

```diff
@@ -112,22 +112,29 @@ def _get_tokenizer(self, max_length = 512):
         return tokenizer
 
 
-
     def rerank(self, query, passages):
 
         query_passage_pairs = [[query, passage] for passage in passages]
         input_text = self.tokenizer.encode_batch(query_passage_pairs)
         input_ids = np.array([e.ids for e in input_text])
         token_type_ids = np.array([e.type_ids for e in input_text])
         attention_mask = np.array([e.attention_mask for e in input_text])
+
+        use_token_type_ids = token_type_ids is not None and not np.all(token_type_ids == 0)
+
+        if use_token_type_ids:
+            onnx_input = {
+                "input_ids": np.array(input_ids, dtype=np.int64),
+                "attention_mask": np.array(attention_mask, dtype=np.int64),
+                "token_type_ids": np.array(token_type_ids, dtype=np.int64),
+            }
+        else:
+            onnx_input = {
+                "input_ids": np.array(input_ids, dtype=np.int64),
+                "attention_mask": np.array(attention_mask, dtype=np.int64)
+            }
 
 
-        onnx_input = {
-            "input_ids": np.array(input_ids, dtype=np.int64),
-            "attention_mask": np.array(attention_mask, dtype=np.int64),
-            "token_type_ids": np.array(token_type_ids, dtype=np.int64),
-        }
-
         input_data = {k: v for k, v in onnx_input.items()}
 
         outputs = self.session.run(None, input_data)
@@ -149,5 +156,4 @@ def rerank(self, query, passages):
         })
 
 
-        return passage_info
-
+        return passage_info
```
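The new conditional is what lets RankT5 share the same code path as the BERT-style cross-encoders: a T5 ONNX graph declares no `token_type_ids` input (its tokenizer emits all-zero type ids), and onnxruntime raises an error when `session.run` is fed an input the graph does not define, so the all-zero case now drops that tensor while BERT-style models keep receiving all three. A minimal usage sketch follows; the query and passages are made up, and the exact shape of `passage_info` is whatever the library returns:

```python
from flashrank.Ranker import Ranker

# rank-T5-flan takes the new else-branch: only input_ids and attention_mask
# are passed to the ONNX session.
ranker = Ranker(model_name="rank-T5-flan", cache_dir="/opt")

query = "How to speed up LLM inference?"  # illustrative only
passages = [
    "Quantization shrinks models and reduces latency.",
    "The Eiffel Tower is in Paris.",
]

results = ranker.rerank(query, passages)
print(results)
```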

setup.py

+1 −1

```diff
@@ -2,7 +2,7 @@
 
 setup(
     name='FlashRank',
-    version='0.1.4',
+    version='0.1.5',
     packages=find_packages(),
     install_requires=[
         'tokenizers',
```
