🎨 use parallelism instead of threads param #623
AlongWY committed Dec 30, 2022
1 parent 9655365 · commit a683360
Showing 13 changed files with 390 additions and 134 deletions.
4 changes: 3 additions & 1 deletion python/extension/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "ltp-extension"
version = "0.1.9"
version = "0.1.10"
edition = "2021"
authors = ["ylfeng <[email protected]>"]
description = "Rust Extension For Language Technology Platform(Python)."
@@ -16,7 +16,9 @@ name = "ltp_extension"
crate-type = ["cdylib"]

[dependencies]
+libc = { version = "0.2" }
rayon = { version = "1.5" }
+rayon-cond = { version = "0.2" }
anyhow = { version = "1.0" }
serde = { version = "1.0", features = ["derive"] }
pyo3 = { version = "0.17", features = ["extension-module", "anyhow", "serde"] }
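The new rayon-cond dependency is what makes a boolean switch practical: its CondIterator drives the same iterator chain either through rayon's thread pool or serially, chosen at runtime, so a single parallelism flag can replace the old per-call thread count (libc is pulled in for the pthread_atfork registration in lib.rs below). A minimal standalone sketch of the CondIterator behavior (not LTP code; the function and sample batch are invented for illustration):

use rayon_cond::CondIterator;

fn char_counts(batch: Vec<&str>, parallel: bool) -> Vec<usize> {
    // CondIterator::new dispatches to rayon's parallel iterator when
    // `parallel` is true and to a plain sequential iterator otherwise.
    CondIterator::new(batch, parallel)
        .map(|text| text.chars().count())
        .collect()
}

fn main() {
    // Same results either way; only the execution strategy differs.
    assert_eq!(char_counts(vec!["他叫汤姆", "去拿外衣。"], true), vec![4, 5]);
    assert_eq!(char_counts(vec!["他叫汤姆", "去拿外衣。"], false), vec![4, 5]);
}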
40 changes: 25 additions & 15 deletions python/extension/ltp_extension/perceptron/perceptron.pyi
@@ -18,39 +18,49 @@ class CWSModel:
自定义新feature
"""
pass
-def batch_predict(self, batch_text, threads=8):
+def batch_predict(self, batch_text, parallelism=True):
"""
Predict batched sentences
"""
pass
-def disable_cut(self, a, b):
+def disable_feature_rule(self, core, feature, s, b, m, e):
"""
-关闭连续不同类型之间的强制切分
+移除自定义新 feature
"""
pass
-def disable_cut_d(self, a, b):
+def disable_type_rule(self, a, b):
"""
-关闭连续不同类型之间的强制切分
+关闭连续不同类型之间的强制连接/切分
"""
pass
-def disable_feature_rule(self, core, feature, s, b, m, e):
+def disable_type_rule_d(self, a, b):
"""
-移除自定义新 feature
+关闭连续不同类型之间的强制连接/切分(双向)
"""
pass
-def enable_cut(self, a, b):
+def enable_feature_rule(self, core, feature):
"""
-开启连续不同类型之间的强制切分
+启用自定义新 feature
"""
pass
-def enable_cut_d(self, a, b):
+def enable_type_concat(self, a, b):
"""
开启连续不同类型之间的强制连接
"""
pass
+def enable_type_concat_d(self, a, b):
+"""
+开启连续不同类型之间的强制连接(双向)
+"""
+pass
+def enable_type_cut(self, a, b):
+"""
+开启连续不同类型之间的强制切分
+"""
+pass
-def enable_feature_rule(self, core, feature):
+def enable_type_cut_d(self, a, b):
"""
-启用自定义新 feature
+开启连续不同类型之间的强制切分(双向)
"""
pass
@staticmethod
@@ -155,7 +165,7 @@ class CharacterType:
class Model:
def __init__(self, path, model_type=ModelType.Auto):
pass
-def batch_predict(self, *args, threads=8):
+def batch_predict(self, *args, parallelism=True):
"""
Predict batched sentences
"""
@@ -189,7 +199,7 @@ class ModelType:
class NERModel:
def __init__(self, path):
pass
-def batch_predict(self, batch_words, batch_pos, threads=8):
+def batch_predict(self, batch_words, batch_pos, parallelism=True):
"""
Predict batched sentences
"""
@@ -286,7 +296,7 @@ class NERTrainer:
class POSModel:
def __init__(self, path):
pass
-def batch_predict(self, batch_words, threads=8):
+def batch_predict(self, batch_words, parallelism=True):
"""
Predict batched sentences
"""
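Besides the threads-to-parallelism switch, this stub records a rename of CWSModel's rule-toggling API. Going by the docstrings: enable_type_cut and enable_type_cut_d enable forced splitting between consecutive characters of different types (强制切分; the _d suffix means bidirectional), enable_type_concat and enable_type_concat_d enable forced concatenation (强制连接), disable_type_rule and disable_type_rule_d turn those type rules back off, and enable_feature_rule / disable_feature_rule add or remove a custom feature.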
33 changes: 33 additions & 0 deletions python/extension/src/lib.rs
@@ -9,6 +9,7 @@ mod algorithms;
mod hook;
mod perceptron;
mod stnsplit;
+mod utils;

use crate::perceptron::{ModelType, PyModel, PyTrainer};
pub use algorithms::{py_eisner, py_get_entities, py_viterbi_decode_postprocess};
@@ -20,9 +21,41 @@ pub use perceptron::{
use pyo3::prelude::*;
use stnsplit::StnSplit;

+pub const VERSION: &str = env!("CARGO_PKG_VERSION");
+
+// For users using multiprocessing in python, it is quite easy to fork the process running
+// tokenizers, ending up with a deadlock because we internally make use of multithreading. So
+// we register a callback to be called in the event of a fork so that we can warn the user.
+static mut REGISTERED_FORK_CALLBACK: bool = false;
+extern "C" fn child_after_fork() {
+use utils::parallelism::*;
+if has_parallelism_been_used() && !is_parallelism_configured() {
+println!(
+"LTP: The current process just got forked, after parallelism has \
+already been used. Disabling parallelism to avoid deadlocks..."
+);
+println!("To disable this warning, you can either:");
+println!(
+"\t- Avoid using `LTP/legacy` model before the fork if possible\n\
+\t- Explicitly set the environment variable {}=(true | false)",
+ENV_VARIABLE
+);
+set_parallelism(false);
+}
+}
+
/// LTP Module
#[pymodule]
fn ltp_extension(py: Python, m: &PyModule) -> PyResult<()> {
+// Register the fork callback
+#[cfg(target_family = "unix")]
+unsafe {
+if !REGISTERED_FORK_CALLBACK {
+libc::pthread_atfork(None, None, Some(child_after_fork));
+REGISTERED_FORK_CALLBACK = true;
+}
+}
+
m.add("__version__", env!("CARGO_PKG_VERSION"))?;

// Algorithms Module
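The new utils module referenced above is not expanded on this page. Judging from the names it must export (ENV_VARIABLE, has_parallelism_been_used, is_parallelism_configured, set_parallelism) and from the tokenizers crate this fork-handling pattern follows, a plausible sketch of utils/parallelism.rs is below; the variable name LTP_PARALLELISM and the true-by-default behavior are assumptions, not the committed code:

use std::sync::atomic::{AtomicBool, Ordering};

// Hypothetical name; the real constant lives in the unexpanded file.
pub const ENV_VARIABLE: &str = "LTP_PARALLELISM";

// Flipped to true the first time a batch call actually fans out to rayon,
// so the fork handler can tell whether a deadlock is even possible.
pub static USED_PARALLELISM: AtomicBool = AtomicBool::new(false);

pub fn has_parallelism_been_used() -> bool {
    USED_PARALLELISM.load(Ordering::SeqCst)
}

// "Configured" means the user set the variable explicitly; in that case
// the fork handler respects the choice and stays silent.
pub fn is_parallelism_configured() -> bool {
    std::env::var(ENV_VARIABLE).is_ok()
}

pub fn set_parallelism(val: bool) {
    std::env::set_var(ENV_VARIABLE, if val { "true" } else { "false" })
}

pub fn get_parallelism() -> bool {
    match std::env::var(ENV_VARIABLE) {
        Ok(v) => !matches!(v.to_lowercase().as_str(), "" | "0" | "off" | "false"),
        Err(_) => true, // parallel by default, matching parallelism=True in the stubs
    }
}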
77 changes: 40 additions & 37 deletions python/extension/src/perceptron/model.rs
@@ -1,8 +1,8 @@
use crate::perceptron::Perceptron;
+use crate::utils::parallelism::MaybeParallelIterator;
use ltp::{CWSDefinition, ModelSerde, NERDefinition, POSDefinition};
use pyo3::prelude::*;
use pyo3::types::{PyList, PyString, PyTuple};
-use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter};

@@ -136,9 +136,18 @@ impl PyModel {
#[pyo3(text_signature = "(self)")]
pub fn specialize(&self, py: Python) -> PyResult<PyObject> {
match &self.model {
-EnumModel::CWS(model) => Ok(crate::perceptron::specialization::PyCWSModel { model: model.clone() }.into_py(py)),
-EnumModel::POS(model) => Ok(crate::perceptron::specialization::PyPOSModel { model: model.clone() }.into_py(py)),
-EnumModel::NER(model) => Ok(crate::perceptron::specialization::PyNERModel { model: model.clone() }.into_py(py)),
+EnumModel::CWS(model) => Ok(crate::perceptron::specialization::PyCWSModel {
+model: model.clone(),
+}
+.into_py(py)),
+EnumModel::POS(model) => Ok(crate::perceptron::specialization::PyPOSModel {
+model: model.clone(),
+}
+.into_py(py)),
+EnumModel::NER(model) => Ok(crate::perceptron::specialization::PyNERModel {
+model: model.clone(),
+}
+.into_py(py)),
}
}

@@ -159,8 +168,8 @@ impl PyModel {
Ok(())
}

#[args(args = "*", threads = 8)]
pub fn __call__(&self, py: Python, args: &PyTuple, threads: usize) -> PyResult<PyObject> {
#[args(args = "*", parallelism = true)]
pub fn __call__(&self, py: Python, args: &PyTuple, parallelism: bool) -> PyResult<PyObject> {
let first = args.get_item(0)?;
let is_single = match &self.model {
EnumModel::CWS(_) => match first.get_type().name()? {
Expand Down Expand Up @@ -195,7 +204,7 @@ impl PyModel {

match is_single {
true => self.predict(py, args),
-false => self.batch_predict(py, args, threads),
+false => self.batch_predict(py, args, parallelism),
}
}

@@ -213,7 +222,7 @@ impl PyModel {
.into_iter()
.map(|s| PyString::new(py, s)),
)
-    .into()
+.into()
}
EnumModel::POS(model) => {
let words: Vec<&str> = args.get_item(0)?.extract()?;
@@ -224,7 +233,7 @@ impl PyModel {
.into_iter()
.map(|s| PyString::new(py, s)),
)
-    .into()
+.into()
}
EnumModel::NER(model) => {
let words: Vec<&str> = args.get_item(0)?.extract()?;
@@ -236,29 +245,27 @@ impl PyModel {
.into_iter()
.map(|s| PyString::new(py, s)),
)
-    .into()
+.into()
}
})
}

/// Predict batched sentences
#[pyo3(text_signature = "(self, *args, threads=8)")]
#[args(args = "*", threads = "8")]
pub fn batch_predict(&self, py: Python, args: &PyTuple, threads: usize) -> PyResult<PyObject> {
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(threads)
.build()
.unwrap();

#[pyo3(text_signature = "(self, *args, parallelism = True)")]
#[args(args = "*", parallelism = true)]
pub fn batch_predict(
&self,
py: Python,
args: &PyTuple,
parallelism: bool,
) -> PyResult<PyObject> {
let result = match &self.model {
EnumModel::CWS(model) => {
let batch_text: Vec<_> = args.get_item(0)?.extract()?;
-let result: Result<Vec<Vec<_>>, _> = pool.install(|| {
-batch_text
-.into_par_iter()
-.map(|text| model.predict(text))
-.collect()
-});
+let result: Result<Vec<Vec<_>>, _> = batch_text
+.into_maybe_par_iter_cond(parallelism)
+.map(|text| model.predict(text))
+.collect();
let result = result?;
let res = PyList::new(py, Vec::<&PyList>::with_capacity(result.len()));
for snt in result {
@@ -272,12 +279,10 @@ impl PyModel {
}
EnumModel::POS(model) => {
let batch_words: Vec<Vec<&str>> = args.get_item(0)?.extract()?;
-let result: Result<Vec<Vec<_>>, _> = pool.install(|| {
-batch_words
-.into_par_iter()
-.map(|words| model.predict(&words))
-.collect()
-});
+let result: Result<Vec<Vec<_>>, _> = batch_words
+.into_maybe_par_iter_cond(parallelism)
+.map(|words| model.predict(&words))
+.collect();
let result = result?;
let res = PyList::new(py, Vec::<&PyList>::with_capacity(result.len()));
for snt in result {
@@ -292,13 +297,11 @@ impl PyModel {
EnumModel::NER(model) => {
let batch_words: Vec<Vec<&str>> = args.get_item(0)?.extract()?;
let batch_pos: Vec<Vec<&str>> = args.get_item(1)?.extract()?;
-let result: Result<Vec<Vec<_>>, _> = pool.install(|| {
-batch_words
-.into_par_iter()
-.zip(batch_pos)
-.map(|(words, tags)| model.predict((&words, &tags)))
-.collect()
-});
+let result: Result<Vec<Vec<_>>, _> = batch_words
+.into_maybe_par_iter_cond(parallelism)
+.zip(batch_pos)
+.map(|(words, tags)| model.predict((&words, &tags)))
+.collect();
let result = result?;
let res = PyList::new(py, Vec::<&PyList>::with_capacity(result.len()));
for snt in result {
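Note that into_maybe_par_iter_cond, used in the batch_predict bodies above and again in cws.rs below, is not a rayon method: it presumably comes from a MaybeParallelIterator extension trait in the unexpanded utils::parallelism module. Assuming it mirrors the tokenizers implementation this pattern is modeled on, the trait wraps rayon-cond and records global usage, roughly (a sketch relying on get_parallelism and USED_PARALLELISM from the sketch above):

use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon_cond::CondIterator;
use std::sync::atomic::Ordering;

pub trait MaybeParallelIterator<P, S>
where
    P: ParallelIterator,
    S: Iterator<Item = P::Item>,
{
    // Parallel iff the global env-var setting allows it.
    fn into_maybe_par_iter(self) -> CondIterator<P, S>;
    // Additionally gated by a per-call flag: the new `parallelism` argument.
    fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S>;
}

impl<P, S, I> MaybeParallelIterator<P, S> for I
where
    I: IntoParallelIterator<Iter = P, Item = P::Item>
        + IntoIterator<IntoIter = S, Item = S::Item>,
    P: ParallelIterator,
    S: Iterator<Item = P::Item>,
{
    fn into_maybe_par_iter(self) -> CondIterator<P, S> {
        let parallelism = get_parallelism();
        if parallelism {
            USED_PARALLELISM.store(true, Ordering::SeqCst);
        }
        CondIterator::new(self, parallelism)
    }

    fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S> {
        if cond {
            self.into_maybe_par_iter()
        } else {
            CondIterator::new(self, false)
        }
    }
}

Whatever the exact implementation, the practical change in batch_predict is visible above: the old code built a dedicated rayon ThreadPool on every call (and that per-call pool is precisely what could deadlock in a forked child), while the new code reuses rayon's global pool and can fall back to serial execution at runtime.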
28 changes: 11 additions & 17 deletions python/extension/src/perceptron/specialization/cws.rs
@@ -1,9 +1,9 @@
use crate::impl_model;
use crate::perceptron::{Perceptron, PyAlgorithm};
+use crate::utils::parallelism::MaybeParallelIterator;
use ltp::perceptron::{CWSDefinition as Definition, Trainer};
use pyo3::prelude::*;
use pyo3::types::{PyList, PyString, PyTuple};
-use rayon::prelude::*;
use serde::{Deserialize, Serialize};

pub type Model = Perceptron<Definition>;
@@ -123,8 +123,8 @@ impl PyCWSModel {
Ok(())
}

#[args(args = "*", threads = 8)]
pub fn __call__(&self, py: Python, args: &PyTuple, threads: usize) -> PyResult<PyObject> {
#[args(args = "*", parallelism = true)]
pub fn __call__(&self, py: Python, args: &PyTuple, parallelism: bool) -> PyResult<PyObject> {
let first = args.get_item(0)?;
let is_single = match first.get_type().name()? {
"str" => true,
@@ -139,7 +139,7 @@ impl PyCWSModel {

match is_single {
true => self.predict(py, args.get_item(0)?.extract()?),
-false => self.batch_predict(py, args.get_item(0)?.extract()?, threads),
+false => self.batch_predict(py, args.get_item(0)?.extract()?, parallelism),
}
}

@@ -157,24 +157,18 @@ impl PyCWSModel {
}

/// Predict batched sentences
#[args(threads = "8")]
#[pyo3(text_signature = "(self, batch_text, threads=8)")]
#[args(parallelism = true)]
#[pyo3(text_signature = "(self, batch_text, parallelism=True)")]
pub fn batch_predict(
&self,
py: Python,
batch_text: Vec<&str>,
-threads: usize,
+parallelism: bool,
) -> PyResult<PyObject> {
-let pool = rayon::ThreadPoolBuilder::new()
-.num_threads(threads)
-.build()
-.unwrap();
-let result: Result<Vec<Vec<_>>, _> = pool.install(|| {
-batch_text
-.into_par_iter()
-.map(|text| self.model.predict(text))
-.collect()
-});
+let result: Result<Vec<Vec<_>>, _> = batch_text
+.into_maybe_par_iter_cond(parallelism)
+.map(|text| self.model.predict(text))
+.collect();
let result = result?;
let res = PyList::new(py, Vec::<&PyList>::with_capacity(result.len()));
for snt in result {
(The remaining 8 changed files are not expanded in this view.)
