Skip to content

Commit

Permalink
Added pipeline merging both components
Browse files Browse the repository at this point in the history
  • Loading branch information
theblackcat102 committed Apr 11, 2022
1 parent 98044cd commit 07eb1de
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 280 deletions.
8 changes: 5 additions & 3 deletions extractnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from extractnet.blocks import Blockifier, PartialBlock, BlockifyError
from extractnet import features
from extractnet.hybrid_extractor import Extractor
from extractnet.pipeline import Extractor

_LOADED_MODELS = {}


def extract_content(html, encoding=None, as_blocks=False):
pass
if 'news_extraction' not in _LOADED_MODELS:
_LOADED_MODELS['news_extraction'] = Extractor()

return _LOADED_MODELS['news_extraction'].predict(html)

def extract_comments(html, encoding=None, as_blocks=False):
pass
Expand Down
273 changes: 0 additions & 273 deletions extractnet/hybrid_extractor.py

This file was deleted.

Binary file added extractnet/models/news_net.onnx
Binary file not shown.
10 changes: 6 additions & 4 deletions extractnet/nn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,20 @@

class NewsNet():
'''
Inputs
'''
# order must be fixed
label_order = ('content', 'author', 'headline', 'breadcrumbs', 'date')

BASE_FEAT_SIZE = 9

CSS_FEAT_SIZE = 43
feats = ('kohlschuetter', 'weninger', 'readability', 'css')

def __init__(self, cls_threshold=0.1, binary_threshold=0.5):
def __init__(self, model_weight=None, cls_threshold=0.1, binary_threshold=0.5):
self.feature_transform = get_and_union_features(self.feats)
self.ort_session = ort.InferenceSession(get_module_res('models/news_net.onnx'))
model_weight = get_module_res('models/news_net.onnx') if model_weight is None else model_weight
self.ort_session = ort.InferenceSession(model_weight)
self.binary_threshold = binary_threshold
self.cls_threshold = cls_threshold

Expand Down Expand Up @@ -68,7 +69,8 @@ def decode_output(self, logits, blocks):
scores = softmax([preds[:, idx]])[0]
ind = np.argpartition(preds[:, idx], -top_k)[-top_k:]
result = [ (fix_encoding(str_cast(blocks[idx].text), scores[idx])) for idx in ind if scores[idx] > self.cls_threshold]
output[label] = result
# sort values by confidence
output[label] = sorted(result, lambda x:x[1], reverse=True)
else:
mask = expit(preds[:, idx]) > self.binary_threshold
ctx = fix_encoding(str_cast(b'\n'.join([ b.text for b in blocks[mask]])))
Expand Down
Loading

0 comments on commit 07eb1de

Please sign in to comment.