Skip to content

Commit fbaeecc

Browse files
authored
Merge pull request RustPython#5709 from coolreader18/bz2
Switch to `libbz2-rs-sys` and finish bz2 impl
2 parents 974c54e + f0d46bf commit fbaeecc

File tree

8 files changed

+314
-214
lines changed

8 files changed

+314
-214
lines changed

Cargo.lock

Lines changed: 9 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ flame-it = ["rustpython-vm/flame-it", "flame", "flamescope"]
1919
freeze-stdlib = ["stdlib", "rustpython-vm/freeze-stdlib", "rustpython-pylib?/freeze-stdlib"]
2020
jit = ["rustpython-vm/jit"]
2121
threading = ["rustpython-vm/threading", "rustpython-stdlib/threading"]
22-
bz2 = ["stdlib", "rustpython-stdlib/bz2"]
2322
sqlite = ["rustpython-stdlib/sqlite"]
2423
ssl = ["rustpython-stdlib/ssl"]
2524
ssl-vendor = ["ssl", "rustpython-stdlib/ssl-vendor"]

Lib/test/test_bz2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,8 @@ def testCompress4G(self, size):
676676
finally:
677677
data = None
678678

679+
# TODO: RUSTPYTHON
680+
@unittest.expectedFailure
679681
def testPickle(self):
680682
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
681683
with self.assertRaises(TypeError):
@@ -734,6 +736,8 @@ def testDecompress4G(self, size):
734736
compressed = None
735737
decompressed = None
736738

739+
# TODO: RUSTPYTHON
740+
@unittest.expectedFailure
737741
def testPickle(self):
738742
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
739743
with self.assertRaises(TypeError):
@@ -1001,6 +1005,8 @@ def test_encoding_error_handler(self):
10011005
as f:
10021006
self.assertEqual(f.read(), "foobar")
10031007

1008+
# TODO: RUSTPYTHON
1009+
@unittest.expectedFailure
10041010
def test_newline(self):
10051011
# Test with explicit newline (universal newline mode disabled).
10061012
text = self.TEXT.decode("ascii")

stdlib/Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ license.workspace = true
1414
default = ["compiler"]
1515
compiler = ["rustpython-vm/compiler"]
1616
threading = ["rustpython-common/threading", "rustpython-vm/threading"]
17-
bz2 = ["bzip2"]
1817
sqlite = ["dep:libsqlite3-sys"]
1918
ssl = ["openssl", "openssl-sys", "foreign-types-shared", "openssl-probe"]
2019
ssl-vendor = ["ssl", "openssl/vendored"]
@@ -80,7 +79,7 @@ adler32 = "1.2.0"
8079
crc32fast = "1.3.2"
8180
flate2 = { version = "1.1", default-features = false, features = ["zlib-rs"] }
8281
libz-sys = { package = "libz-rs-sys", version = "0.5" }
83-
bzip2 = { version = "0.4", optional = true }
82+
bzip2 = { version = "0.5", features = ["libbz2-rs-sys"] }
8483

8584
# tkinter
8685
tk-sys = { git = "https://github.com/arihant2math/tkinter.git", tag = "v0.1.0", optional = true }

stdlib/src/bz2.rs

Lines changed: 46 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,48 @@ mod _bz2 {
1212
object::{PyPayload, PyResult},
1313
types::Constructor,
1414
};
15+
use crate::zlib::{
16+
DecompressArgs, DecompressError, DecompressState, DecompressStatus, Decompressor,
17+
};
1518
use bzip2::{Decompress, Status, write::BzEncoder};
19+
use rustpython_vm::convert::ToPyException;
1620
use std::{fmt, io::Write};
1721

18-
// const BUFSIZ: i32 = 8192;
19-
20-
struct DecompressorState {
21-
decoder: Decompress,
22-
eof: bool,
23-
needs_input: bool,
24-
// input_buffer: Vec<u8>,
25-
// output_buffer: Vec<u8>,
26-
}
22+
const BUFSIZ: usize = 8192;
2723

2824
#[pyattr]
2925
#[pyclass(name = "BZ2Decompressor")]
3026
#[derive(PyPayload)]
3127
struct BZ2Decompressor {
32-
state: PyMutex<DecompressorState>,
28+
state: PyMutex<DecompressState<Decompress>>,
29+
}
30+
31+
impl Decompressor for Decompress {
32+
type Flush = ();
33+
type Status = Status;
34+
type Error = bzip2::Error;
35+
36+
fn total_in(&self) -> u64 {
37+
self.total_in()
38+
}
39+
fn decompress_vec(
40+
&mut self,
41+
input: &[u8],
42+
output: &mut Vec<u8>,
43+
(): Self::Flush,
44+
) -> Result<Self::Status, Self::Error> {
45+
self.decompress_vec(input, output)
46+
}
47+
}
48+
49+
impl DecompressStatus for Status {
50+
fn is_stream_end(&self) -> bool {
51+
*self == Status::StreamEnd
52+
}
3353
}
3454

3555
impl fmt::Debug for BZ2Decompressor {
36-
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
56+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3757
write!(f, "_bz2.BZ2Decompressor")
3858
}
3959
}
@@ -43,13 +63,7 @@ mod _bz2 {
4363

4464
fn py_new(cls: PyTypeRef, _: Self::Args, vm: &VirtualMachine) -> PyResult {
4565
Self {
46-
state: PyMutex::new(DecompressorState {
47-
decoder: Decompress::new(false),
48-
eof: false,
49-
needs_input: true,
50-
// input_buffer: Vec::new(),
51-
// output_buffer: Vec::new(),
52-
}),
66+
state: PyMutex::new(DecompressState::new(Decompress::new(false), vm)),
5367
}
5468
.into_ref_with_type(vm, cls)
5569
.map(Into::into)
@@ -59,107 +73,34 @@ mod _bz2 {
5973
#[pyclass(with(Constructor))]
6074
impl BZ2Decompressor {
6175
#[pymethod]
62-
fn decompress(
63-
&self,
64-
data: ArgBytesLike,
65-
// TODO: PyIntRef
66-
max_length: OptionalArg<i32>,
67-
vm: &VirtualMachine,
68-
) -> PyResult<PyBytesRef> {
69-
let max_length = max_length.unwrap_or(-1);
70-
if max_length >= 0 {
71-
return Err(vm.new_not_implemented_error(
72-
"the max_value argument is not implemented yet".to_owned(),
73-
));
74-
}
75-
// let max_length = if max_length < 0 || max_length >= BUFSIZ {
76-
// BUFSIZ
77-
// } else {
78-
// max_length
79-
// };
76+
fn decompress(&self, args: DecompressArgs, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
77+
let max_length = args.max_length();
78+
let data = &*args.data();
8079

8180
let mut state = self.state.lock();
82-
let DecompressorState {
83-
decoder,
84-
eof,
85-
..
86-
// needs_input,
87-
// input_buffer,
88-
// output_buffer,
89-
} = &mut *state;
90-
91-
if *eof {
92-
return Err(vm.new_exception_msg(
93-
vm.ctx.exceptions.eof_error.to_owned(),
94-
"End of stream already reached".to_owned(),
95-
));
96-
}
97-
98-
// data.with_ref(|data| input_buffer.extend(data));
99-
100-
// If max_length is negative:
101-
// read the input X bytes at a time, compress it and append it to output.
102-
// Once you're out of input, setting needs_input to true and return the
103-
// output as bytes.
104-
//
105-
// TODO:
106-
// If max_length is non-negative:
107-
// Read the input X bytes at a time, compress it and append it to
108-
// the output. If output reaches `max_length` in size, return
109-
// it (up to max_length), and store the rest of the output
110-
// for later.
111-
112-
// TODO: arbitrary choice, not the right way to do it.
113-
let mut buf = Vec::with_capacity(data.len() * 32);
114-
115-
let before = decoder.total_in();
116-
let res = data.with_ref(|data| decoder.decompress_vec(data, &mut buf));
117-
let _written = (decoder.total_in() - before) as usize;
118-
119-
let res = match res {
120-
Ok(x) => x,
121-
// TODO: error message
122-
_ => return Err(vm.new_os_error("Invalid data stream".to_owned())),
123-
};
124-
125-
if res == Status::StreamEnd {
126-
*eof = true;
127-
}
128-
Ok(vm.ctx.new_bytes(buf.to_vec()))
81+
state
82+
.decompress(data, max_length, BUFSIZ, vm)
83+
.map_err(|e| match e {
84+
DecompressError::Decompress(err) => vm.new_os_error(err.to_string()),
85+
DecompressError::Eof(err) => err.to_pyexception(vm),
86+
})
12987
}
13088

13189
#[pygetset]
13290
fn eof(&self) -> bool {
133-
let state = self.state.lock();
134-
state.eof
91+
self.state.lock().eof()
13592
}
13693

13794
#[pygetset]
138-
fn unused_data(&self, vm: &VirtualMachine) -> PyBytesRef {
139-
// Data found after the end of the compressed stream.
140-
// If this attribute is accessed before the end of the stream
141-
// has been reached, its value will be b''.
142-
vm.ctx.new_bytes(b"".to_vec())
143-
// alternatively, be more honest:
144-
// Err(vm.new_not_implemented_error(
145-
// "unused_data isn't implemented yet".to_owned(),
146-
// ))
147-
//
148-
// TODO
149-
// let state = self.state.lock();
150-
// if state.eof {
151-
// vm.ctx.new_bytes(state.input_buffer.to_vec())
152-
// else {
153-
// vm.ctx.new_bytes(b"".to_vec())
154-
// }
95+
fn unused_data(&self) -> PyBytesRef {
96+
self.state.lock().unused_data()
15597
}
15698

15799
#[pygetset]
158100
fn needs_input(&self) -> bool {
159101
// False if the decompress() method can provide more
160102
// decompressed data before requiring new uncompressed input.
161-
let state = self.state.lock();
162-
state.needs_input
103+
self.state.lock().needs_input()
163104
}
164105

165106
// TODO: mro()?
@@ -178,7 +119,7 @@ mod _bz2 {
178119
}
179120

180121
impl fmt::Debug for BZ2Compressor {
181-
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
122+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182123
write!(f, "_bz2.BZ2Compressor")
183124
}
184125
}

stdlib/src/lib.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ mod statistics;
3636
mod suggestions;
3737
// TODO: maybe make this an extension module, if we ever get those
3838
// mod re;
39-
#[cfg(feature = "bz2")]
4039
mod bz2;
4140
#[cfg(not(target_arch = "wasm32"))]
4241
pub mod socket;
@@ -112,6 +111,7 @@ pub fn get_module_inits() -> impl Iterator<Item = (Cow<'static, str>, StdlibInit
112111
"array" => array::make_module,
113112
"binascii" => binascii::make_module,
114113
"_bisect" => bisect::make_module,
114+
"_bz2" => bz2::make_module,
115115
"cmath" => cmath::make_module,
116116
"_contextvars" => contextvars::make_module,
117117
"_csv" => csv::make_module,
@@ -158,10 +158,6 @@ pub fn get_module_inits() -> impl Iterator<Item = (Cow<'static, str>, StdlibInit
158158
{
159159
"_ssl" => ssl::make_module,
160160
}
161-
#[cfg(feature = "bz2")]
162-
{
163-
"_bz2" => bz2::make_module,
164-
}
165161
#[cfg(windows)]
166162
{
167163
"_overlapped" => overlapped::make_module,

0 commit comments

Comments
 (0)