Skip to content

Commit

Permalink
[feature] Fix issue #8
Browse files Browse the repository at this point in the history
  • Loading branch information
theblackcat102 committed Jul 11, 2022
1 parent 7ad5ff2 commit 17775d3
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 5 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ In this example the value for first_value will remain 0 even though meta_pre2 al
We love contributions! Open an issue, or fork/create a pull
request.

## Develop Locally

Since extractnet relies on several C++ modules, you need to compile them before running it locally.

Usually, this command is all you need:
```
make
```

However, you can try to build it

# More details about the code structure

Coming soon
Expand Down
12 changes: 8 additions & 4 deletions extractnet/nn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def preprocess(self, html):
return feat, blocks


def predict(self, html):
def predict(self, html, top_rank=10):
'''
html: an HTML string or a list of HTML strings
top_rank: number of top-scoring blocks (top K) considered when predicting author, breadcrumbs (keywords), and date
'''
single = False
if isinstance(html, list):
x, css, blocks= [], [], []
Expand All @@ -56,17 +60,17 @@ def predict(self, html):
inputs_onnx = { 'input': x, 'css': css }

logits = self.ort_session.run(None, inputs_onnx)[0]
decoded = self.decode_output(logits, blocks)
decoded = self.decode_output(logits, blocks, top_rank=top_rank)
return decoded[0] if single else decoded

def decode_output(self, logits, doc_blocks):
def decode_output(self, logits, doc_blocks, top_rank=10):
outputs = []
for jdx, preds in enumerate(logits):
output = {}
blocks = doc_blocks[jdx]
for idx, label in enumerate(self.label_order):
if label in ['author', 'date', 'breadcrumbs']:
top_k = 10
top_k = min(top_rank, len(preds[:, idx]))
scores = softmax([preds[:, idx]])[0]
ind = np.argpartition(preds[:, idx], -top_k)[-top_k:]
result = [ (fix_encoding(str_cast(blocks[idx].text)), scores[idx]) for idx in ind if scores[idx] > self.cls_threshold]
Expand Down
2 changes: 1 addition & 1 deletion provision.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/bash

cython --cplus extractnet/*.pyx
cython --cplus extractnet/features//*.pyx
cython --cplus extractnet/features/*.pyx
3 changes: 3 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
Cython>=0.21.1
beautifulsoup4==4.9.3
htmldate==1.2.3
ftfy>=4.1.0,<5.0.0
lxml
numpy>=1.19.0
Expand All @@ -7,6 +9,7 @@ pytest-cov>=2.6.0
scikit-learn>=0.22.0
scipy>=0.17.0
sklearn-crfsuite==0.3.6
tld==0.12.6
dateparser==1.1.0
joblib==0.17.0
onnxruntime==1.9.0

0 comments on commit 17775d3

Please sign in to comment.