Skip to content

Commit

Permalink
Add Python Streaming AEAD wrapper.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 327794731
  • Loading branch information
juergw authored and copybara-github committed Aug 21, 2020
1 parent 866270b commit e9052d9
Show file tree
Hide file tree
Showing 7 changed files with 377 additions and 39 deletions.
19 changes: 19 additions & 0 deletions python/tink/streaming_aead/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ py_test(
":_raw_streaming_aead",
":streaming_aead",
requirement("absl-py"),
"//tink:cleartext_keyset_handle",
"//tink:tink_python",
"//tink/cc/pybind:tink_bindings",
"//tink/core",
"//tink/proto:aes_ctr_hmac_streaming_py_pb2",
Expand Down Expand Up @@ -161,11 +163,28 @@ py_library(
srcs_version = "PY3",
deps = [
":_raw_streaming_aead",
":_rewindable_input_stream",
":_streaming_aead",
"//tink/core",
],
)

# Test target that runs _streaming_aead_wrapper_test.py against the
# :streaming_aead library. The fake StreamingAead, the BytesIO test helpers
# and the generic test helper all come from //tink/testing.
py_test(
name = "_streaming_aead_wrapper_test",
timeout = "short",
srcs = ["_streaming_aead_wrapper_test.py"],
srcs_version = "PY3",
deps = [
":streaming_aead",
requirement("absl-py"),
"//tink/core",
"//tink/proto:tink_py_pb2",
"//tink/testing:bytes_io",
"//tink/testing:fake_streaming_aead",
"//tink/testing:helper",
],
)

py_library(
name = "_rewindable_input_stream",
srcs = ["_rewindable_input_stream.py"],
Expand Down
12 changes: 8 additions & 4 deletions python/tink/streaming_aead/_rewindable_input_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class RewindableInputStream(io.RawIOBase):
"""Implements a readable io.RawIOBase wrapper that supports rewinding.
The wrapped input_stream can either be a io.RawIOBase or io.BufferedIOBase.
input_stream.read may return None on some calls, but it is required to
eventually return some data, or return b'' if EOF is reached.
"""

def __init__(self, input_stream: BinaryIO):
Expand All @@ -38,6 +40,11 @@ def __init__(self, input_stream: BinaryIO):
def read(self, size: int = -1) -> Optional[bytes]:
"""Read and return up to size bytes when size >= 0.
This function may return None on some calls, but it will eventually return
some data, or return b'' if EOF is reached. Since all data is buffered when
the stream is still rewindable, it is also guaranteed that None will not be
returned on previously read data.
Args:
size: Maximum number of bytes to be returned, if >= 0. If size is smaller
than 0 or None, return the whole content of the file.
Expand Down Expand Up @@ -66,17 +73,14 @@ def read(self, size: int = -1) -> Optional[bytes]:
if data is None:
# self._input_stream is a RawIOBase and has currently no data
return None
if self._rewindable and not self._input_stream.seekable():
if self._rewindable:
self._buffer.extend(data)
self._pos += len(data)
return data

def rewind(self):
if not self._rewindable:
raise ValueError('rewind is disabled')
if self._input_stream.seekable():
self._input_stream.seek(0)
return
self._pos = 0

def disable_rewind(self):
Expand Down
28 changes: 6 additions & 22 deletions python/tink/streaming_aead/_rewindable_input_stream_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,24 +67,16 @@ def test_readall(self, seekable):
with _rewindable(b'The quick brown fox', seekable) as f:
self.assertEqual(b'The quick brown fox', f.readall())

def test_rewind_read_non_seekable(self):
with _rewindable(b'The quick brown fox', seekable=False) as f:
@parameterized.parameters([False, True])
def test_rewind_read(self, seekable):
with _rewindable(b'The quick brown fox', seekable) as f:
self.assertEqual(b'The quick', f.read(9))
f.rewind()
self.assertEqual(b'The ', f.read(4))
# this only reads the rest of current buffer content.
self.assertEqual(b'quick', f.read(100))
self.assertEqual(b' brown fox', f.read())

def test_rewind_read_seekable(self):
with _rewindable(b'The quick brown fox', seekable=True) as f:
self.assertEqual(b'The quick', f.read(9))
f.rewind()
self.assertEqual(b'The ', f.read(4))
# no buffering, so this reads the rest.
self.assertEqual(b'quick brown fox', f.read(100))
self.assertEqual(b'', f.read())

@parameterized.parameters([False, True])
def test_rewind_readall(self, seekable):
with _rewindable(b'The quick brown fox', seekable) as f:
Expand All @@ -103,8 +95,9 @@ def test_rewind_twice(self, seekable):
f.rewind()
self.assertEqual(b'The quick brown fox', f.read())

def test_disable_rewind_non_seekable(self):
with _rewindable(b'The quick brown fox', seekable=False) as f:
@parameterized.parameters([False, True])
def test_disable_rewind(self, seekable):
with _rewindable(b'The quick brown fox', seekable) as f:
self.assertEqual(b'The q', f.read(5))
f.rewind()
f.disable_rewind()
Expand All @@ -114,15 +107,6 @@ def test_disable_rewind_non_seekable(self):
self.assertEmpty(f._buffer)
self.assertEqual(b'ick brown fox', f.read())

def test_disable_rewind_seekable(self):
with _rewindable(b'The quick brown fox', seekable=True) as f:
self.assertEqual(b'The q', f.read(5))
f.rewind()
f.disable_rewind()
# no buffering, so this reads everything
self.assertEqual(b'The quick brown fox', f.read(100))
self.assertEqual(b'', f.read())

@parameterized.parameters([False, True])
def test_disable_rewind_readall(self, seekable):
with _rewindable(b'The quick brown fox', seekable) as f:
Expand Down
67 changes: 62 additions & 5 deletions python/tink/streaming_aead/_streaming_aead_key_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
from typing import BinaryIO, cast

from absl.testing import absltest
from absl.testing import parameterized
from tink.proto import aes_ctr_hmac_streaming_pb2
from tink.proto import aes_gcm_hkdf_streaming_pb2
from tink.proto import common_pb2
from tink.proto import tink_pb2
import tink
from tink import cleartext_keyset_handle
from tink import core
from tink import streaming_aead
from tink.streaming_aead import _raw_streaming_aead
Expand All @@ -38,7 +41,7 @@ def setUpModule():
streaming_aead.register()


class StreamingAeadKeyManagerTest(absltest.TestCase):
class StreamingAeadKeyManagerTest(parameterized.TestCase):

def setUp(self):
super(StreamingAeadKeyManagerTest, self).setUp()
Expand Down Expand Up @@ -103,7 +106,7 @@ def test_invalid_aes_ctr_hmac_params_throw_exception(self):
'key_size must not be smaller than'):
self.key_manager_ctr.new_key_data(key_template)

def test_encrypt_decrypt(self):
def test_raw_encrypt_decrypt(self):
raw_primitive = self.key_manager_ctr.primitive(
self.key_manager_ctr.new_key_data(
streaming_aead.streaming_aead_key_templates
Expand All @@ -128,7 +131,7 @@ def test_encrypt_decrypt(self):
self.assertEqual(ct_source.closed, close_ciphertext_source)
self.assertEqual(output, plaintext)

def test_read_after_eof_returns_empty_bytes(self):
def test_raw_read_after_eof_returns_empty_bytes(self):
raw_primitive = self.key_manager_ctr.primitive(
self.key_manager_ctr.new_key_data(
streaming_aead.streaming_aead_key_templates
Expand All @@ -146,7 +149,7 @@ def test_read_after_eof_returns_empty_bytes(self):
_ = ds.readall()
self.assertEqual(ds.read(100), b'')

def test_encrypt_decrypt_tempfile(self):
def test_raw_encrypt_decrypt_tempfile(self):
raw_primitive = self.key_manager_ctr.primitive(
self.key_manager_ctr.new_key_data(
streaming_aead.streaming_aead_key_templates
Expand All @@ -170,7 +173,7 @@ def test_encrypt_decrypt_tempfile(self):
os.unlink(encryptedfile_name)
self.assertEqual(output, plaintext)

def test_encrypt_decrypt_wrong_aad(self):
def test_raw_encrypt_decrypt_wrong_aad(self):
raw_primitive = self.key_manager_ctr.primitive(
self.key_manager_ctr.new_key_data(
streaming_aead.streaming_aead_key_templates
Expand All @@ -191,6 +194,60 @@ def test_encrypt_decrypt_wrong_aad(self):
with self.assertRaises(core.TinkError):
ds.read()

@parameterized.parameters(
    [io.BytesIO, bytes_io.SlowBytesIO, bytes_io.SlowReadableRawBytes])
def test_wrapped_encrypt_decrypt_two_keys(self, input_stream_factory):
  """Exercises the wrapped primitive across a key rotation.

  Two keysets are built: one holding a single key, and a copy of it with a
  second key added and made primary. A ciphertext written with the first key
  must be readable through both keysets; a ciphertext written with the second
  key must be readable only through the keyset that contains it.

  Args:
    input_stream_factory: Callable that turns ciphertext bytes into the
      readable stream handed to new_decrypting_stream.
  """
  template = (
      streaming_aead.streaming_aead_key_templates.AES128_CTR_HMAC_SHA256_4KB)

  def add_primary_key(keyset, key_id):
    # Appends a freshly generated, enabled key and makes it the primary.
    key = keyset.key.add()
    key.key_data.CopyFrom(tink.core.Registry.new_key_data(template))
    key.status = tink_pb2.ENABLED
    key.key_id = key_id
    key.output_prefix_type = template.output_prefix_type
    keyset.primary_key_id = key.key_id

  def encrypt(primitive, plaintext, aad):
    # Returns the complete ciphertext produced for plaintext.
    destination = bytes_io.BytesIOWithValueAfterClose()
    with primitive.new_encrypting_stream(destination, aad) as es:
      es.write(plaintext)
    return destination.value_after_close()

  def decrypting_stream(primitive, ciphertext, aad):
    # Opens a decrypting stream over a fresh source built by the factory.
    source = cast(BinaryIO, input_stream_factory(ciphertext))
    return primitive.new_decrypting_stream(source, aad)

  old_keyset = tink_pb2.Keyset()
  add_primary_key(old_keyset, 1234)
  old_keyset_handle = cleartext_keyset_handle.from_keyset(old_keyset)
  old_primitive = old_keyset_handle.primitive(streaming_aead.StreamingAead)

  # The rotated keyset keeps the old key and gets a new primary key.
  new_keyset = tink_pb2.Keyset()
  new_keyset.CopyFrom(old_keyset)
  add_primary_key(new_keyset, 5678)
  new_keyset_handle = cleartext_keyset_handle.from_keyset(new_keyset)
  new_primitive = new_keyset_handle.primitive(streaming_aead.StreamingAead)

  plaintext1 = b' '.join(b'%d' % i for i in range(100 * 1000))
  ciphertext1 = encrypt(old_primitive, plaintext1, b'aad1')

  plaintext2 = b' '.join(b'%d' % i for i in range(100 * 1001))
  ciphertext2 = encrypt(new_primitive, plaintext2, b'aad2')

  # old_primitive can read the 1st ciphertext, but not the 2nd.
  with decrypting_stream(old_primitive, ciphertext1, b'aad1') as ds:
    self.assertEqual(ds.read(), plaintext1)
  with decrypting_stream(old_primitive, ciphertext2, b'aad2') as ds:
    with self.assertRaises(tink.TinkError):
      ds.read()

  # new_primitive can read both ciphertexts.
  with decrypting_stream(new_primitive, ciphertext1, b'aad1') as ds:
    self.assertEqual(ds.read(), plaintext1)
  with decrypting_stream(new_primitive, ciphertext2, b'aad2') as ds:
    self.assertEqual(ds.read(), plaintext2)

if __name__ == '__main__':
absltest.main()
110 changes: 104 additions & 6 deletions python/tink/streaming_aead/_streaming_aead_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,115 @@
from __future__ import print_function

import io
from typing import cast, BinaryIO, Type
from typing import cast, BinaryIO, Optional, Type

from tink import core
from tink.streaming_aead import _raw_streaming_aead
from tink.streaming_aead import _rewindable_input_stream
from tink.streaming_aead import _streaming_aead


class _DecryptingStreamWrapper(io.RawIOBase):
  """A file-like object which decrypts reads from an underlying object.

  It uses a primitive set of streaming AEADs, and decrypts the stream with the
  matching key in the keyset. Closing this wrapper also closes
  ciphertext_source.

  The ciphertext source is wrapped in a RewindableInputStream so that each raw
  primitive can be tried against the beginning of the stream. Once one of them
  has produced data, rewinding is disabled and all further reads are delegated
  to that primitive's decrypting stream.

  NOTE(review): this class only relies on RewindableInputStream.rewind() and
  never calls seek() on the source itself; whether non-seekable sources are
  fully supported depends on RewindableInputStream -- confirm.
  """

  def __init__(self, primitive_set: core.PrimitiveSet,
               ciphertext_source: BinaryIO, associated_data: bytes):
    """Create a new _DecryptingStreamWrapper.

    Args:
      primitive_set: The primitive set of StreamingAead primitives.
      ciphertext_source: A readable file-like object from which ciphertext bytes
        will be read.
      associated_data: The associated data to use for decryption.

    Raises:
      ValueError: If ciphertext_source is not readable.
    """
    super(_DecryptingStreamWrapper, self).__init__()
    if not ciphertext_source.readable():
      raise ValueError('ciphertext_source must be readable')
    self._ciphertext_source = _rewindable_input_stream.RewindableInputStream(
        ciphertext_source)
    self._associated_data = associated_data
    # Set once a primitive has successfully decrypted the start of the stream.
    self._matching_stream = None
    self._primitive_set = primitive_set

  def read(self, size: int = -1) -> Optional[bytes]:
    """Read and return up to size bytes, where size is an int.

    Args:
      size: Maximum number of bytes to read. As a convenience, if size is
        unspecified or -1, all bytes until EOF are returned.

    Returns:
      Bytes read. An empty bytes object is returned if the stream is already at
      EOF. None is returned if no data is available at the moment.

    Raises:
      TinkError if there was a permanent error.
      ValueError if the file is closed.
    """
    if self.closed:  # pylint:disable=using-constant-test
      raise ValueError('read on closed file.')
    if size == 0:
      return bytes()
    # Compare against None explicitly: a stream object must never be mistaken
    # for "no match yet" just because it happens to evaluate as falsy.
    if self._matching_stream is not None:
      return self._matching_stream.read(size)
    # If self._matching_stream is not set, no data has been read successfully
    # and self._ciphertext_source is at the beginning.
    for entry in self._primitive_set.raw_primitives():
      try:
        # ciphertext_source should never be closed by any of the raw decrypting
        # streams. It will be closed in close(), and only there.
        attempted_stream = entry.primitive.new_raw_decrypting_stream(
            self._ciphertext_source,
            self._associated_data,
            close_ciphertext_source=False)
        data = attempted_stream.read(size)
      except core.TinkError:
        # This key does not match. Go back to the start and try another key.
        self._ciphertext_source.rewind()
        continue
      if data is None:
        # No data at the moment. Not clear if decryption was successful.
        # Try again on the next call.
        # To not end up in an infinite loop, we need self._ciphertext_source
        # to make progress, even if rewind() is called in between calls to
        # read().
        self._ciphertext_source.rewind()
        return None
      # Any value other than None means that decryption was successful.
      # (b'' indicates that the plaintext is an empty string.)
      self._matching_stream = attempted_stream
      self._ciphertext_source.disable_rewind()
      return data
    raise core.TinkError(
        'No matching key found for the ciphertext in the stream')

  def readinto(self, b: bytearray) -> Optional[int]:
    """Read bytes into a pre-allocated bytes-like object b.

    Args:
      b: Buffer to fill; at most len(b) bytes are requested from read().

    Returns:
      The number of bytes written into b, or None if read() returned None.
    """
    data = self.read(len(b))
    if data is None:
      return None
    n = len(data)
    b[:n] = data
    return n

  def close(self) -> None:
    """Close the matching stream (if any), the ciphertext source and self.

    Calling close() on an already closed wrapper is a no-op. The ciphertext
    source and the wrapper itself are closed even if closing an earlier layer
    raises; the exception still propagates to the caller.
    """
    if self.closed:  # pylint:disable=using-constant-test
      return
    try:
      if self._matching_stream is not None:
        self._matching_stream.close()
    finally:
      try:
        self._ciphertext_source.close()
      finally:
        super(_DecryptingStreamWrapper, self).close()

  def readable(self) -> bool:
    """Return True; this wrapper always supports reading."""
    return True


class _WrappedStreamingAead(_streaming_aead.StreamingAead):
"""_WrappedStreamingAead."""
"""Implements StreamingAead by wrapping a set of RawStreamingAead."""

def __init__(self, primitives_set: core.PrimitiveSet):
self._primitive_set = primitives_set
Expand All @@ -38,10 +138,8 @@ def new_encrypting_stream(self, ciphertext_destination: BinaryIO,

def new_decrypting_stream(self, ciphertext_source: BinaryIO,
associated_data: bytes) -> BinaryIO:
# TODO(juerg): Implement a proper wrapper.
# This implementation only works for keysets with a single key!
raw = self._primitive_set.primary().primitive.new_raw_decrypting_stream(
ciphertext_source, associated_data, close_ciphertext_source=True)
raw = _DecryptingStreamWrapper(self._primitive_set, ciphertext_source,
associated_data)
return cast(BinaryIO, io.BufferedReader(raw))


Expand Down
Loading

0 comments on commit e9052d9

Please sign in to comment.