1
1
use crate :: function:: OptionalArg ;
2
- use crate :: obj:: objbytes:: PyBytesRef ;
3
- use crate :: pyobject:: { PyObjectRef , PyResult } ;
2
+ use crate :: obj:: objbytearray:: { PyByteArray , PyByteArrayRef } ;
3
+ use crate :: obj:: objbyteinner:: PyBytesLike ;
4
+ use crate :: obj:: objbytes:: { PyBytes , PyBytesRef } ;
5
+ use crate :: obj:: objstr:: { PyString , PyStringRef } ;
6
+ use crate :: pyobject:: { PyObjectRef , PyResult , TryFromObject , TypeProtocol } ;
4
7
use crate :: vm:: VirtualMachine ;
8
+
5
9
use crc:: { crc32, Hasher32 } ;
10
+ use itertools:: Itertools ;
11
+
12
+ enum SerializedData {
13
+ Bytes ( PyBytesRef ) ,
14
+ Buffer ( PyByteArrayRef ) ,
15
+ Ascii ( PyStringRef ) ,
16
+ }
17
+
18
+ impl TryFromObject for SerializedData {
19
+ fn try_from_object ( vm : & VirtualMachine , obj : PyObjectRef ) -> PyResult < Self > {
20
+ match_class ! ( match obj {
21
+ b @ PyBytes => Ok ( SerializedData :: Bytes ( b) ) ,
22
+ b @ PyByteArray => Ok ( SerializedData :: Buffer ( b) ) ,
23
+ a @ PyString => {
24
+ if a. as_str( ) . is_ascii( ) {
25
+ Ok ( SerializedData :: Ascii ( a) )
26
+ } else {
27
+ Err ( vm. new_value_error(
28
+ "string argument should contain only ASCII characters" . to_string( ) ,
29
+ ) )
30
+ }
31
+ }
32
+ obj => Err ( vm. new_type_error( format!(
33
+ "argument should be bytes, buffer or ASCII string, not '{}'" ,
34
+ obj. class( ) . name,
35
+ ) ) ) ,
36
+ } )
37
+ }
38
+ }
39
+
40
+ impl SerializedData {
41
+ #[ inline]
42
+ pub fn with_ref < R > ( & self , f : impl FnOnce ( & [ u8 ] ) -> R ) -> R {
43
+ match self {
44
+ SerializedData :: Bytes ( b) => f ( b. get_value ( ) ) ,
45
+ SerializedData :: Buffer ( b) => f ( & b. inner . borrow ( ) . elements ) ,
46
+ SerializedData :: Ascii ( a) => f ( a. as_str ( ) . as_bytes ( ) ) ,
47
+ }
48
+ }
49
+ }
6
50
7
51
fn hex_nibble ( n : u8 ) -> u8 {
8
52
match n {
@@ -12,15 +56,15 @@ fn hex_nibble(n: u8) -> u8 {
12
56
}
13
57
}
14
58
15
- fn binascii_hexlify ( data : PyBytesRef , vm : & VirtualMachine ) -> PyResult {
16
- let bytes = data. get_value ( ) ;
17
- let mut hex = Vec :: < u8 > :: with_capacity ( bytes. len ( ) * 2 ) ;
18
- for b in bytes. iter ( ) {
19
- hex. push ( hex_nibble ( b >> 4 ) ) ;
20
- hex. push ( hex_nibble ( b & 0xf ) ) ;
21
- }
22
-
23
- Ok ( vm . ctx . new_bytes ( hex ) )
59
+ fn binascii_hexlify ( data : PyBytesLike , _vm : & VirtualMachine ) -> Vec < u8 > {
60
+ data. with_ref ( |bytes| {
61
+ let mut hex = Vec :: < u8 > :: with_capacity ( bytes. len ( ) * 2 ) ;
62
+ for b in bytes. iter ( ) {
63
+ hex. push ( hex_nibble ( b >> 4 ) ) ;
64
+ hex. push ( hex_nibble ( b & 0xf ) ) ;
65
+ }
66
+ hex
67
+ } )
24
68
}
25
69
26
70
fn unhex_nibble ( c : u8 ) -> Option < u8 > {
@@ -32,37 +76,66 @@ fn unhex_nibble(c: u8) -> Option<u8> {
32
76
}
33
77
}
34
78
35
- fn binascii_unhexlify ( hexstr : PyBytesRef , vm : & VirtualMachine ) -> PyResult {
36
- // TODO: allow 'str' hexstrings as well
37
- let hex_bytes = hexstr. get_value ( ) ;
38
- if hex_bytes. len ( ) % 2 != 0 {
39
- return Err ( vm. new_value_error ( "Odd-length string" . to_string ( ) ) ) ;
40
- }
79
+ fn binascii_unhexlify ( data : SerializedData , vm : & VirtualMachine ) -> PyResult < Vec < u8 > > {
80
+ data. with_ref ( |hex_bytes| {
81
+ if hex_bytes. len ( ) % 2 != 0 {
82
+ return Err ( vm. new_value_error ( "Odd-length string" . to_string ( ) ) ) ;
83
+ }
41
84
42
- let mut unhex = Vec :: < u8 > :: with_capacity ( hex_bytes. len ( ) / 2 ) ;
43
- for i in ( 0 ..hex_bytes. len ( ) ) . step_by ( 2 ) {
44
- let n1 = unhex_nibble ( hex_bytes[ i] ) ;
45
- let n2 = unhex_nibble ( hex_bytes[ i + 1 ] ) ;
46
- if let ( Some ( n1) , Some ( n2) ) = ( n1, n2) {
47
- unhex. push ( n1 << 4 | n2) ;
48
- } else {
49
- return Err ( vm. new_value_error ( "Non-hexadecimal digit found" . to_string ( ) ) ) ;
85
+ let mut unhex = Vec :: < u8 > :: with_capacity ( hex_bytes. len ( ) / 2 ) ;
86
+ for ( n1, n2) in hex_bytes. iter ( ) . tuples ( ) {
87
+ if let ( Some ( n1) , Some ( n2) ) = ( unhex_nibble ( * n1) , unhex_nibble ( * n2) ) {
88
+ unhex. push ( n1 << 4 | n2) ;
89
+ } else {
90
+ return Err ( vm. new_value_error ( "Non-hexadecimal digit found" . to_string ( ) ) ) ;
91
+ }
50
92
}
51
- }
52
93
53
- Ok ( vm. ctx . new_bytes ( unhex) )
94
+ Ok ( unhex)
95
+ } )
54
96
}
55
97
56
- fn binascii_crc32 ( data : PyBytesRef , value : OptionalArg < u32 > , vm : & VirtualMachine ) -> PyResult {
57
- let bytes = data. get_value ( ) ;
58
- let crc = value. unwrap_or ( 0u32 ) ;
98
+ fn binascii_crc32 ( data : SerializedData , value : OptionalArg < u32 > , vm : & VirtualMachine ) -> PyResult {
99
+ let crc = value. unwrap_or ( 0 ) ;
59
100
60
101
let mut digest = crc32:: Digest :: new_with_initial ( crc32:: IEEE , crc) ;
61
- digest. write ( & bytes) ;
102
+ data . with_ref ( |bytes| digest. write ( & bytes) ) ;
62
103
63
104
Ok ( vm. ctx . new_int ( digest. sum32 ( ) ) )
64
105
}
65
106
107
+ #[ derive( FromArgs ) ]
108
+ struct NewlineArg {
109
+ #[ pyarg( keyword_only, default = "true" ) ]
110
+ newline : bool ,
111
+ }
112
+
113
/// Trims a single trailing newline (`\n`) from the end of the byte string,
/// if one is present.
///
/// Only one newline is removed; any earlier newlines are left in place.
/// Input that does not end in `\n` (including the empty slice) is returned
/// unchanged.
fn trim_newline(b: &[u8]) -> &[u8] {
    // The suffix to test for is exactly `\n`. Matching on the last element
    // also avoids manual index arithmetic on the slice.
    match b.split_last() {
        Some((&b'\n', rest)) => rest,
        _ => b,
    }
}
121
+
122
+ fn binascii_a2b_base64 ( s : SerializedData , vm : & VirtualMachine ) -> PyResult < Vec < u8 > > {
123
+ s. with_ref ( |b| base64:: decode ( trim_newline ( b) ) )
124
+ . map_err ( |err| vm. new_value_error ( format ! ( "error decoding base64: {}" , err) ) )
125
+ }
126
+
127
+ fn binascii_b2a_base64 (
128
+ data : PyBytesLike ,
129
+ NewlineArg { newline } : NewlineArg ,
130
+ _vm : & VirtualMachine ,
131
+ ) -> Vec < u8 > {
132
+ let mut encoded = data. with_ref ( base64:: encode) . into_bytes ( ) ;
133
+ if newline {
134
+ encoded. push ( b'\n' ) ;
135
+ }
136
+ encoded
137
+ }
138
+
66
139
pub fn make_module ( vm : & VirtualMachine ) -> PyObjectRef {
67
140
let ctx = & vm. ctx ;
68
141
@@ -72,5 +145,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
72
145
"unhexlify" => ctx. new_rustfunc( binascii_unhexlify) ,
73
146
"a2b_hex" => ctx. new_rustfunc( binascii_unhexlify) ,
74
147
"crc32" => ctx. new_rustfunc( binascii_crc32) ,
148
+ "a2b_base64" => ctx. new_rustfunc( binascii_a2b_base64) ,
149
+ "b2a_base64" => ctx. new_rustfunc( binascii_b2a_base64) ,
75
150
} )
76
151
}
0 commit comments