Skip to content

Commit

Permalink
get_file() with tar, tgz, tar.bz, zip and sha256, resolves keras-team…
Browse files Browse the repository at this point in the history
…#5861. (keras-team#5882)

* get_file() with tar, tgz, tar.bz, zip and sha256, resolves keras-team#5861.

The changes were designed to preserve backwards compatibility while adding support
for .tar.gz, .tgz, .tar.bz, and .zip files.
sha256 hash is now supported in addition to md5.

* get_file() improve large file performance keras-team#5861.

* get_file() extract parameter fix (keras-team#5861)

* extract_archive() py3 fix (keras-team#5861)

* get_file() tarfile fix (keras-team#5861)

* data_utils.py and data_utils_test.py updated based on review (keras-team#5861)
# This is a combination of 4 commits.
# The first commit's message is:
get_file() with tar, tgz, tar.bz, zip and sha256, resolves keras-team#5861.

The changes were designed to preserve backwards compatibility while adding support
for .tar.gz, .tgz, .tar.bz, and .zip files.
Adds extract_archive() and hash_file() functions.
sha256 hash is now supported in addition to md5.
adds data_utils_test.py to test new functionality

# This is the 2nd commit message:

extract_archive() redundant open (keras-team#5861)

# This is the 3rd commit message:

data_utils.py and data_utils_test.py updated based on review (keras-team#5861)
test creates its own tiny file to download and extract locally.
test covers md5 sha256 zip and tar
_hash_file() now private
_extract_archive() now private

# This is the 4th commit message:

data_utils.py and data_utils_test.py updated based on review (keras-team#5861)
test creates its own tiny file to download and extract locally.
test covers md5 sha256 zip and tar
_hash_file() now private
_extract_archive() now private

* data_utils.py and data_utils_test.py updated based on review (keras-team#5861)

* data_utils.py get_file() cache_dir docs (keras-team#5861)

* data_utils.py address docs comments (keras-team#5861)

* get_file() comment link, path, & typo fix
  • Loading branch information
ahundt authored and fchollet committed Apr 4, 2017
1 parent 64d2421 commit 4fe78f3
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 32 deletions.
179 changes: 147 additions & 32 deletions keras/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

import functools
import tarfile
import zipfile
import os
import sys
import shutil
import hashlib
import six
from six.moves.urllib.request import urlopen
from six.moves.urllib.error import URLError
from six.moves.urllib.error import HTTPError
Expand Down Expand Up @@ -55,24 +57,105 @@ def chunk_read(response, chunk_size=8192, reporthook=None):
from six.moves.urllib.request import urlretrieve


def _extract_archive(file_path, path='.', archive_format='auto'):
    """Extracts an archive if it matches the tar, tar.gz, tar.bz, or zip formats.

    # Arguments
        file_path: path to the archive file
        path: path where the archive contents are extracted
        archive_format: Archive format to try for extracting the file.
            Options are 'auto', 'tar', 'zip', and None.
            'tar' includes tar, tar.gz, and tar.bz files.
            The default 'auto' is ['tar', 'zip'].
            None or an empty list will return no matches found.

    # Returns
        True if a match was found and an archive extraction was completed,
        False otherwise.
    """
    if archive_format is None:
        return False
    # NOTE: string equality must use `==`, not `is`; identity comparison of
    # string literals only works by accident of CPython interning.
    if archive_format == 'auto':
        archive_format = ['tar', 'zip']
    if isinstance(archive_format, six.string_types):
        archive_format = [archive_format]

    for archive_type in archive_format:
        if archive_type == 'tar':
            open_fn = tarfile.open
            is_match_fn = tarfile.is_tarfile
        elif archive_type == 'zip':
            open_fn = zipfile.ZipFile
            is_match_fn = zipfile.is_zipfile
        else:
            # Skip unrecognized format names instead of falling through with
            # `open_fn`/`is_match_fn` unbound (UnboundLocalError).
            continue

        if is_match_fn(file_path):
            with open_fn(file_path) as archive:
                try:
                    archive.extractall(path)
                except (tarfile.TarError, RuntimeError,
                        KeyboardInterrupt):
                    # Remove a partial extraction so a corrupt archive does
                    # not leave stale files behind.
                    if os.path.exists(path):
                        if os.path.isfile(path):
                            os.remove(path)
                        else:
                            shutil.rmtree(path)
                    raise
            return True
    return False


def get_file(fname, origin, untar=False,
md5_hash=None, cache_subdir='datasets'):
md5_hash=None, cache_subdir='datasets',
file_hash=None,
hash_algorithm='auto',
extract=False,
archive_format='auto',
cache_dir=None):
"""Downloads a file from a URL if it not already in the cache.
Passing the MD5 hash will verify the file after download
as well as if it is already present in the cache.
By default the file at the url `origin` is downloaded to the
cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
and given the filename `fname`. The final location of a file
`example.txt` would therefore be `~/.keras/datasets/example.txt`.
Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
Passing a hash will verify the file after download. The command line
programs `shasum` and `sha256sum` can compute the hash.
# Arguments
fname: name of the file
origin: original URL of the file
untar: boolean, whether the file should be decompressed
md5_hash: MD5 hash of the file for verification
cache_subdir: directory being used as the cache
fname: Name of the file. If an absolute path `/path/to/file.txt` is
specified the file will be saved at that location.
origin: Original URL of the file.
untar: Deprecated in favor of 'extract'.
boolean, whether the file should be decompressed
md5_hash: Deprecated in favor of 'file_hash'.
md5 hash of the file for verification
file_hash: The expected hash string of the file after download.
The sha256 and md5 hash algorithms are both supported.
cache_subdir: Subdirectory under the Keras cache dir where the file is
saved. If an absolute path `/path/to/folder` is
specified the file will be saved at that location.
hash_algorithm: Select the hash algorithm to verify the file.
options are 'md5', 'sha256', and 'auto'.
The default 'auto' detects the hash algorithm in use.
extract: True tries extracting the file as an Archive, like tar or zip.
archive_format: Archive format to try for extracting the file.
Options are 'auto', 'tar', 'zip', and None.
'tar' includes tar, tar.gz, and tar.bz files.
The default 'auto' is ['tar', 'zip'].
None or an empty list will return no matches found.
cache_dir: Location to store cached files, when None it
defaults to the [Keras Directory](/faq/#where-is-the-keras-configuration-filed-stored).
# Returns
Path to the downloaded file
"""
datadir_base = os.path.expanduser(os.path.join('~', '.keras'))
if cache_dir is None:
cache_dir = os.path.expanduser(os.path.join('~', '.keras'))
if md5_hash is not None and file_hash is None:
file_hash = md5_hash
hash_algorithm = 'md5'
datadir_base = os.path.expanduser(cache_dir)
if not os.access(datadir_base, os.W_OK):
datadir_base = os.path.join('/tmp', '.keras')
datadir = os.path.join(datadir_base, cache_subdir)
Expand All @@ -88,10 +171,12 @@ def get_file(fname, origin, untar=False,
download = False
if os.path.exists(fpath):
# File found; verify integrity if a hash was provided.
if md5_hash is not None:
if not validate_file(fpath, md5_hash):
if file_hash is not None:
if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
print('A local file was found, but it seems to be '
'incomplete or outdated.')
'incomplete or outdated because the ' + hash_algorithm +
' file hash does not match the original value of ' +
file_hash + ' so we will re-download the data.')
download = True
else:
download = True
Expand Down Expand Up @@ -123,38 +208,68 @@ def dl_progress(count, block_size, total_size, progbar=None):

if untar:
if not os.path.exists(untar_fpath):
print('Untaring file...')
tfile = tarfile.open(fpath, 'r:gz')
try:
tfile.extractall(path=datadir)
except (Exception, KeyboardInterrupt) as e:
if os.path.exists(untar_fpath):
if os.path.isfile(untar_fpath):
os.remove(untar_fpath)
else:
shutil.rmtree(untar_fpath)
raise
tfile.close()
_extract_archive(fpath, datadir, archive_format='tar')
return untar_fpath

if extract:
_extract_archive(fpath, datadir, archive_format)

return fpath


def validate_file(fpath, md5_hash):
"""Validates a file against a MD5 hash.
def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
    """Calculates a file sha256 or md5 hash.

    # Example
    ```python
    >>> from keras.utils.data_utils import _hash_file
    >>> _hash_file('/path/to/file.zip')
    'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
    ```

    # Arguments
        fpath: path to the file being hashed
        algorithm: hash algorithm, one of 'auto', 'sha256', or 'md5'.
            'auto' defaults to sha256 since there is no expected hash
            string here whose length could identify the algorithm.
        chunk_size: Bytes to read at a time, important for large files.

    # Returns
        The file hash (hex digest string)
    """
    # BUG FIX: the original condition was
    # `algorithm is 'auto' and len(hash) is 64`, which referenced the
    # *builtin* `hash` function (TypeError at runtime) and compared strings
    # with `is`. With no expected hash available, 'auto' means sha256.
    if algorithm in ('sha256', 'auto'):
        hasher = hashlib.sha256()
    else:
        hasher = hashlib.md5()

    with open(fpath, 'rb') as fpath_file:
        # Read fixed-size chunks so large files are never fully in memory.
        for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
            hasher.update(chunk)

    return hasher.hexdigest()


def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
    """Validates a file against a sha256 or md5 hash.

    # Arguments
        fpath: path to the file being validated
        file_hash: The expected hash string of the file.
            The sha256 and md5 hash algorithms are both supported.
        algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
            The default 'auto' detects the hash algorithm in use.
        chunk_size: Bytes to read at a time, important for large files.

    # Returns
        Whether the file is valid
    """
    # A sha256 hex digest is 64 characters, an md5 digest is 32, so the
    # length of the expected hash identifies the algorithm for 'auto'.
    # Use `==` rather than `is`: identity comparison of strings and small
    # ints only works by accident of CPython caching.
    if algorithm == 'sha256' or (algorithm == 'auto' and len(file_hash) == 64):
        hasher = 'sha256'
    else:
        hasher = 'md5'

    return str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash)
59 changes: 59 additions & 0 deletions tests/keras/utils/data_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Tests for functions in data_utils.py.
"""
import os
import pytest
import tarfile
import zipfile
from six.moves.urllib.request import pathname2url
from six.moves.urllib.parse import urljoin
from keras.utils.data_utils import get_file
from keras.utils.data_utils import validate_file
from keras.utils.data_utils import _hash_file
from keras import activations
from keras import regularizers


def test_data_utils():
    """Tests get_file from a url, plus extraction and validation.
    """
    dirname = 'data_utils'

    # Build a tiny local fixture: one text file wrapped in both a gzipped
    # tarball and a zip archive, so no network access is needed.
    with open('test.txt', 'w') as out:
        out.write('Float like a butterfly, sting like a bee.')

    with tarfile.open('test.tar.gz', 'w:gz') as archive:
        archive.add('test.txt')

    with zipfile.ZipFile('test.zip', 'w') as archive:
        archive.write('test.txt')

    # Serve the tarball through a file:// URL.
    origin = urljoin('file://', pathname2url(os.path.abspath('test.tar.gz')))

    # Exercise the deprecated untar path, then hashed re-downloads.
    path = get_file(dirname, origin, untar=True)
    filepath = path + '.tar.gz'
    hashval_sha256 = _hash_file(filepath)
    hashval_md5 = _hash_file(filepath, algorithm='md5')
    path = get_file(dirname, origin, md5_hash=hashval_md5, untar=True)
    path = get_file(filepath, origin, file_hash=hashval_sha256, extract=True)
    assert os.path.exists(filepath)
    assert validate_file(filepath, hashval_sha256)
    assert validate_file(filepath, hashval_md5)
    os.remove(filepath)
    os.remove('test.tar.gz')

    # Repeat the download/extract/validate cycle with the zip archive.
    origin = urljoin('file://', pathname2url(os.path.abspath('test.zip')))

    hashval_sha256 = _hash_file('test.zip')
    hashval_md5 = _hash_file('test.zip', algorithm='md5')
    path = get_file(dirname, origin, md5_hash=hashval_md5, extract=True)
    path = get_file(dirname, origin, file_hash=hashval_sha256, extract=True)
    assert os.path.exists(path)
    assert validate_file(path, hashval_sha256)
    assert validate_file(path, hashval_md5)

    os.remove(path)
    os.remove('test.txt')
    os.remove('test.zip')

# Allow running this test module directly, outside of a pytest session.
if __name__ == '__main__':
    pytest.main([__file__])

0 comments on commit 4fe78f3

Please sign in to comment.