Skip to content

Commit 8cb3673

Browse files
committed
Factor out common logic in internal RSA padding interface.
Factor out the duplicate checks that `m` is fully consumed.
1 parent ba8199b commit 8cb3673

File tree

2 files changed

+124
-122
lines changed

2 files changed

+124
-122
lines changed

src/rsa/padding.rs

+121-120
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pub trait Encoding: Sync {
2424

2525
/// The term "Verification" comes from RFC 3447.
2626
pub trait Verification: Sync {
27-
fn verify(&self, msg: untrusted::Input, m: untrusted::Input,
27+
fn verify(&self, msg: untrusted::Input, m: &mut untrusted::Reader,
2828
mod_bits: usize) -> Result<(), error::Unspecified>;
2929
}
3030

@@ -63,47 +63,48 @@ impl Encoding for PKCS1 {
6363
}
6464

6565
impl Verification for PKCS1 {
66-
fn verify(&self, msg: untrusted::Input, m: untrusted::Input,
66+
fn verify(&self, msg: untrusted::Input, m: &mut untrusted::Reader,
6767
_mod_bits: usize) -> Result<(), error::Unspecified> {
68-
m.read_all(error::Unspecified, |em| {
69-
if try!(em.read_byte()) != 0 ||
70-
try!(em.read_byte()) != 1 {
71-
return Err(error::Unspecified);
72-
}
68+
let em = m;
7369

74-
let mut ps_len = 0;
75-
loop {
76-
match try!(em.read_byte()) {
77-
0xff => {
78-
ps_len += 1;
79-
},
80-
0x00 => {
81-
break;
82-
},
83-
_ => {
84-
return Err(error::Unspecified);
85-
},
86-
}
87-
}
88-
if ps_len < 8 {
89-
return Err(error::Unspecified);
90-
}
70+
if try!(em.read_byte()) != 0 ||
71+
try!(em.read_byte()) != 1 {
72+
return Err(error::Unspecified);
73+
}
9174

92-
let em_digestinfo_prefix = try!(em.skip_and_get_input(
93-
self.digestinfo_prefix.len()));
94-
if em_digestinfo_prefix != self.digestinfo_prefix {
95-
return Err(error::Unspecified);
75+
let mut ps_len = 0;
76+
loop {
77+
match try!(em.read_byte()) {
78+
0xff => {
79+
ps_len += 1;
80+
},
81+
0x00 => {
82+
break;
83+
},
84+
_ => {
85+
return Err(error::Unspecified);
86+
},
9687
}
88+
}
89+
if ps_len < 8 {
90+
return Err(error::Unspecified);
91+
}
9792

98-
let digest_alg = self.digest_alg;
99-
let decoded_digest =
100-
try!(em.skip_and_get_input(digest_alg.output_len));
101-
let digest = digest::digest(digest_alg, msg.as_slice_less_safe());
102-
if decoded_digest != digest.as_ref() {
103-
return Err(error::Unspecified);
104-
}
105-
Ok(())
106-
})
93+
let em_digestinfo_prefix = try!(em.skip_and_get_input(
94+
self.digestinfo_prefix.len()));
95+
if em_digestinfo_prefix != self.digestinfo_prefix {
96+
return Err(error::Unspecified);
97+
}
98+
99+
let digest_alg = self.digest_alg;
100+
let decoded_digest =
101+
try!(em.skip_and_get_input(digest_alg.output_len));
102+
let digest = digest::digest(digest_alg, msg.as_slice_less_safe());
103+
if decoded_digest != digest.as_ref() {
104+
return Err(error::Unspecified);
105+
}
106+
107+
Ok(())
107108
}
108109
}
109110

@@ -175,102 +176,100 @@ const PSS_PREFIX_ZEROS: [u8; 8] = [0u8; 8];
175176
impl Verification for PSS {
176177
// RSASSA-PSS-VERIFY from https://tools.ietf.org/html/rfc3447#section-8.1.2
177178
// where steps 1, 2(a), and 2(b) have been done for us.
178-
fn verify(&self, msg: untrusted::Input, m: untrusted::Input,
179+
fn verify(&self, msg: untrusted::Input, m: &mut untrusted::Reader,
179180
mod_bits: usize) -> Result<(), error::Unspecified> {
180-
m.read_all(error::Unspecified, |m| {
181-
// RSASSA-PSS-VERIFY Step 2(c). The `m` this function is given is
182-
// the big-endian-encoded value of `m` from the specification,
183-
// padded to `k` bytes, where `k` is the length in bytes of the
184-
// public modulus. The spec says "Note that emLen will be one less
185-
// than k if modBits - 1 is divisible by 8 and equal to k
186-
// otherwise," where `k` is the length in octets of the RSA public
187-
// modulus `n`. In other words, `em` might have an extra leading
188-
// zero byte that we need to strip before we start the PSS decoding
189-
// steps which is an artifact of the `Verification` interface.
190-
if (mod_bits - 1) % 8 == 0 {
191-
if try!(m.read_byte()) != 0 {
192-
return Err(error::Unspecified);
193-
}
194-
};
195-
let em = m;
196-
let em_bits = mod_bits - 1;
197-
let em_len = (em_bits + 7) / 8;
198-
let top_byte_mask = 0xffu8 >> ((8 * em_len) - em_bits);
181+
// RSASSA-PSS-VERIFY Step 2(c). The `m` this function is given is the
182+
// big-endian-encoded value of `m` from the specification, padded to
183+
// `k` bytes, where `k` is the length in bytes of the public modulus.
184+
// The spec. says "Note that emLen will be one less than k if
185+
// modBits - 1 is divisible by 8 and equal to k otherwise," where `k`
186+
// is the length in octets of the RSA public modulus `n`. In other
187+
// words, `em` might have an extra leading zero byte that we need to
188+
// strip before we start the PSS decoding steps which is an artifact of
189+
// the `Verification` interface.
190+
if (mod_bits - 1) % 8 == 0 {
191+
if try!(m.read_byte()) != 0 {
192+
return Err(error::Unspecified);
193+
}
194+
};
195+
let em = m;
196+
let em_bits = mod_bits - 1;
197+
let em_len = (em_bits + 7) / 8;
198+
let top_byte_mask = 0xffu8 >> ((8 * em_len) - em_bits);
199199

200-
// The rest of this function is EMSA-PSS-VERIFY from
201-
// https://tools.ietf.org/html/rfc3447#section-9.1.2.
200+
// The rest of this function is EMSA-PSS-VERIFY from
201+
// https://tools.ietf.org/html/rfc3447#section-9.1.2.
202202

203-
// Steps 1 and 2.
204-
let digest_len = self.digest_alg.output_len;
205-
let m_hash = digest::digest(self.digest_alg,
206-
msg.as_slice_less_safe());
203+
// Steps 1 and 2.
204+
let digest_len = self.digest_alg.output_len;
205+
let m_hash = digest::digest(self.digest_alg, msg.as_slice_less_safe());
207206

208-
// Step 3: where we assume the digest and salt are of equal length.
209-
if em_len < 2 + (2 * digest_len) {
210-
return Err(error::Unspecified);
211-
}
207+
// Step 3: where we assume the digest and salt are of equal length.
208+
if em_len < 2 + (2 * digest_len) {
209+
return Err(error::Unspecified);
210+
}
212211

213-
// Steps 4 and 5: Parse encoded message as:
214-
// masked_db || h_hash || 0xbc
215-
let db_len = em_len - digest_len - 1;
216-
let masked_db = try!(em.skip_and_get_input(db_len));
217-
let h_hash = try!(em.skip_and_get_input(digest_len));
218-
if try!(em.read_byte()) != 0xbc {
219-
return Err(error::Unspecified);
220-
}
212+
// Steps 4 and 5: Parse encoded message as:
213+
// masked_db || h_hash || 0xbc
214+
let db_len = em_len - digest_len - 1;
215+
let masked_db = try!(em.skip_and_get_input(db_len));
216+
let h_hash = try!(em.skip_and_get_input(digest_len));
217+
if try!(em.read_byte()) != 0xbc {
218+
return Err(error::Unspecified);
219+
}
221220

222-
// Step 7.
223-
let mut db = [0u8; super::PUBLIC_MODULUS_MAX_LEN / 8];
224-
let db = &mut db[..db_len];
221+
// Step 7.
222+
let mut db = [0u8; super::PUBLIC_MODULUS_MAX_LEN / 8];
223+
let db = &mut db[..db_len];
225224

226-
try!(mgf1(self.digest_alg, h_hash.as_slice_less_safe(), db));
225+
try!(mgf1(self.digest_alg, h_hash.as_slice_less_safe(), db));
227226

228-
try!(masked_db.read_all(error::Unspecified, |masked_bytes| {
229-
// Step 6. Check the top bits of first byte are zero.
230-
let b = try!(masked_bytes.read_byte());
231-
if b & !top_byte_mask != 0 {
232-
return Err(error::Unspecified);
233-
}
234-
db[0] ^= b;
235-
236-
// Step 8.
237-
for i in 1..db.len() {
238-
db[i] ^= try!(masked_bytes.read_byte());
239-
}
240-
Ok(())
241-
}));
242-
243-
// Step 9.
244-
db[0] &= top_byte_mask;
245-
246-
// Step 10.
247-
let pad_len = db.len() - digest_len - 1;
248-
for i in 0..pad_len {
249-
if db[i] != 0 {
250-
return Err(error::Unspecified);
251-
}
252-
}
253-
if db[pad_len] != 1 {
227+
try!(masked_db.read_all(error::Unspecified, |masked_bytes| {
228+
// Step 6. Check the top bits of first byte are zero.
229+
let b = try!(masked_bytes.read_byte());
230+
if b & !top_byte_mask != 0 {
254231
return Err(error::Unspecified);
255232
}
233+
db[0] ^= b;
256234

257-
// Step 11.
258-
let salt = &db[db.len() - digest_len..];
235+
// Step 8.
236+
for i in 1..db.len() {
237+
db[i] ^= try!(masked_bytes.read_byte());
238+
}
239+
Ok(())
240+
}));
259241

260-
// Step 12 and 13: compute hash value of:
261-
// (0x)00 00 00 00 00 00 00 00 || m_hash || salt
262-
let mut ctx = digest::Context::new(self.digest_alg);
263-
ctx.update(&PSS_PREFIX_ZEROS);
264-
ctx.update(m_hash.as_ref());
265-
ctx.update(salt);
266-
let h_hash_check = ctx.finish();
242+
// Step 9.
243+
db[0] &= top_byte_mask;
267244

268-
// Step 14.
269-
if h_hash != h_hash_check.as_ref() {
245+
// Step 10.
246+
let pad_len = db.len() - digest_len - 1;
247+
for i in 0..pad_len {
248+
if db[i] != 0 {
270249
return Err(error::Unspecified);
271250
}
272-
Ok(())
273-
})
251+
}
252+
if db[pad_len] != 1 {
253+
return Err(error::Unspecified);
254+
}
255+
256+
// Step 11.
257+
let salt = &db[db.len() - digest_len..];
258+
259+
// Step 12 and 13: compute hash value of:
260+
// (0x)00 00 00 00 00 00 00 00 || m_hash || salt
261+
let mut ctx = digest::Context::new(self.digest_alg);
262+
ctx.update(&PSS_PREFIX_ZEROS);
263+
ctx.update(m_hash.as_ref());
264+
ctx.update(salt);
265+
let h_hash_check = ctx.finish();
266+
267+
// Step 14.
268+
if h_hash != h_hash_check.as_ref() {
269+
return Err(error::Unspecified);
270+
}
271+
272+
Ok(())
274273
}
275274
}
276275

@@ -314,7 +313,7 @@ rsa_pss_padding!(RSA_PSS_SHA512, &digest::SHA512,
314313

315314
#[cfg(test)]
316315
mod test {
317-
use test;
316+
use {error, test};
318317
use super::*;
319318
use untrusted;
320319

@@ -341,7 +340,9 @@ mod test {
341340
let bit_len = test_case.consume_usize("Len");
342341
let expected_result = test_case.consume_string("Result");
343342

344-
let actual_result = alg.verify(msg, encoded, bit_len);
343+
let actual_result =
344+
encoded.read_all(error::Unspecified,
345+
|m| alg.verify(msg, m, bit_len));
345346
assert_eq!(actual_result.is_ok(), expected_result == "P");
346347

347348
Ok(())

src/rsa/verification.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,9 @@ pub fn verify_rsa(params: &RSAParameters,
124124
params.min_bits, PUBLIC_MODULUS_MAX_LEN)
125125
}));
126126

127-
params.padding_alg.verify(msg, untrusted::Input::from(decoded),
128-
n.length_in_bits())
127+
untrusted::Input::from(decoded).read_all(
128+
error::Unspecified,
129+
|m| params.padding_alg.verify(msg, m, n.length_in_bits()))
129130
}
130131

131132
extern {

0 commit comments

Comments
 (0)