Skip to content

Commit

Permalink
add user defined post and pre processing components
Browse files Browse the repository at this point in the history
  • Loading branch information
theblackcat102 committed Apr 11, 2022
1 parent 07eb1de commit f6f458f
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 16 deletions.
7 changes: 1 addition & 6 deletions extractnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,9 @@

_LOADED_MODELS = {}

def extract_content(html, encoding=None, as_blocks=False):
def extract_news(html, encoding=None, as_blocks=False):
    """Run the shared news extractor on ``html`` and return its prediction.

    The ``Extractor`` is built on first use and kept in ``_LOADED_MODELS``
    under the ``'news_extraction'`` key, so later calls reuse it.

    ``encoding`` and ``as_blocks`` are accepted but not used here.
    """
    try:
        extractor = _LOADED_MODELS['news_extraction']
    except KeyError:
        # First call: construct the extractor once and cache it.
        extractor = _LOADED_MODELS['news_extraction'] = Extractor()

    return extractor.predict(html)

def extract_comments(html, encoding=None, as_blocks=False):
    """Placeholder with no implementation: ignores its arguments, returns None."""
    return None

def extract_content_and_comments(html, encoding=None, as_blocks=False):
    """Placeholder with no implementation: ignores its arguments, returns None."""
    return None
59 changes: 53 additions & 6 deletions extractnet/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,29 @@
import numpy as np
import dateparser
from sklearn.base import BaseEstimator, ClassifierMixin
from .util import get_and_union_features, get_module_res, fix_encoding
from .metadata_extraction.metadata import extract_metadata

from .compat import unicode_
from .util import priority_merge, get_module_res, remove_empty_keys
from .nn_models import NewsNet
from .name_crf import AuthorExtraction
from .nn_models import NewsNet


class Extractor(BaseEstimator, ClassifierMixin):

def __init__(self, author_extractor=None, content_extractor=None):
def __init__(self, author_extractor=None, content_extractor=None, postprocess=None,
             meta_postprocess=None):
    """Set up the extractor and its optional user-defined processing hooks.

    Args:
        author_extractor: Author extraction component; a new
            ``AuthorExtraction()`` is created when omitted.
        content_extractor: Content model; a new ``NewsNet()`` is created
            when omitted. Its ``label_order`` becomes
            ``self.output_attributes``.
        postprocess: Optional list of callables invoked as
            ``pipeline(html, results)`` during ``postprocess``.
        meta_postprocess: Optional list of callables invoked as
            ``pipeline(document)`` during ``extract``.
    """
    if author_extractor is None:
        author_extractor = AuthorExtraction()
    if content_extractor is None:
        content_extractor = NewsNet()

    # ``None`` sentinels replace the former ``[]`` defaults: a list literal
    # in the signature is created once and shared by every instance.
    # A list supplied by the caller is stored as-is (not copied).
    self.meta_postprocess_pipelines = [] if meta_postprocess is None else meta_postprocess
    self.has_meta_pos = len(self.meta_postprocess_pipelines) > 0

    self.postprocess_pipelines = [] if postprocess is None else postprocess
    self.has_post = len(self.postprocess_pipelines) > 0

    self.author_extractor = author_extractor
    self.content_extractor = content_extractor
    self.output_attributes = self.content_extractor.label_order
Expand All @@ -34,14 +43,44 @@ def from_pretrained(directory=None):
NewsNet(model_weight=nn_weight_path)
)

def extract(self, html, encoding=None, as_blocks=False, extract_target=None, debug=True):
@staticmethod
def extract_one_meta(document):
    """Mine metadata from a single document and drop its empty entries.

    Args:
        document: One raw document, passed unchanged to
            ``extract_metadata``.

    Returns:
        The result of ``remove_empty_keys`` applied to the mined metadata.
        This is a single value, not a tuple.
    """
    return remove_empty_keys(extract_metadata(document))

def extract(self, html, encoding=None, as_blocks=False, extract_target=None, debug=True, metadata_mining=True):
    """Extract fields from one document or from a list of documents.

    Args:
        html: A single document (``str``/``bytes``) or a list of documents.
        encoding: Accepted for API compatibility; not used by this method.
        as_blocks: Accepted for API compatibility; not used by this method.
        extract_target: Accepted for API compatibility; not used by this
            method.
        debug: Accepted for API compatibility; not used by this method.
        metadata_mining: When true, metadata is mined from each document
            with ``extract_one_meta`` and passed through the user-defined
            ``meta_postprocess`` callables. When false, empty metadata is
            used and those callables are not run.

    Returns:
        The result of ``self.postprocess`` for a single prediction (when
        the content extractor returns a dict), otherwise a list with one
        ``self.postprocess`` result per document.
    """
    def mine_metadata(document):
        # Mined metadata for one document, then each user-defined metadata
        # post-processor merged in through ``priority_merge``.
        # NOTE(review): the post-processors are assumed to run only when
        # metadata mining is enabled, as in the list branch of the original.
        meta = self.extract_one_meta(document)
        if self.has_meta_pos:
            for pipeline in self.meta_postprocess_pipelines:
                meta = priority_merge(pipeline(document), meta)
        return meta

    # ``np.str_`` is the same type the old ``np.unicode_`` alias pointed to;
    # the alias itself was removed in NumPy 2.0.
    if isinstance(html, (str, bytes, unicode_, np.str_)):
        documents_meta_data = mine_metadata(html) if metadata_mining else {}
    elif metadata_mining:  # must be a list
        # ``extract_one_meta`` returns a single value, so it must not be
        # unpacked into two names.
        documents_meta_data = [mine_metadata(document) for document in html]
    else:
        # One independent dict per document; ``[{}] * n`` would make every
        # document share the same dict object.
        documents_meta_data = [{} for _ in range(len(html))]

    output = self.content_extractor.predict(html)
    if isinstance(output, dict):
        return self.postprocess(html, output, documents_meta_data)

    return [
        self.postprocess(h, o, meta)
        for h, o, meta in zip(html, output, documents_meta_data)
    ]

def postprocess(self, output):
def postprocess(self, html, output, meta):
results = {}
if 'author' in output and len(output['author']) > 0:
author_text, confidence = output['author'][0]
Expand Down Expand Up @@ -69,4 +108,12 @@ def postprocess(self, output):
else:
# is list of tuple (string, float) format
results[attribute] = [val[0] for val in value ]

results = priority_merge(results, meta)

if self.has_post:
for pipeline in self.postprocess_pipelines:
post_ml_results_ = pipeline(html, results)
results = priority_merge(post_ml_results_, ml_results)

return results
5 changes: 1 addition & 4 deletions test/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ def html():


def test_extractor(html):
auth_clf = joblib.load('extractnet/models/author_extractor.pkl.gz')
date_clf = joblib.load('extractnet/models/datePublishedRaw_extractor.pkl.gz')
extractor = Extractor('extractnet/models/final_extractor.pkl.gz',
auth_clf, date_clf)
extractor = Extractor()

results = extractor.extract(html, metadata_mining=False)

Expand Down

0 comments on commit f6f458f

Please sign in to comment.