forked from Hvass-Labs/TensorFlow-Tutorials
-
Notifications
You must be signed in to change notification settings - Fork 0
/
download.py
98 lines (76 loc) · 3.09 KB
/
download.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
########################################################################
#
# Functions for downloading and extracting data-files from the internet.
#
# Implemented in Python 3.5
#
########################################################################
#
# This file is part of the TensorFlow Tutorials available at:
#
# https://github.com/Hvass-Labs/TensorFlow-Tutorials
#
# Published under the MIT License. See the file LICENSE for details.
#
# Copyright 2016 by Magnus Erik Hvass Pedersen
#
########################################################################
import sys
import os
import urllib.request
import tarfile
import zipfile
########################################################################
def _print_download_progress(count, block_size, total_size):
    """
    Function used for printing the download progress.
    Used as a call-back function in maybe_download_and_extract().

    :param count:
        Number of blocks transferred so far.

    :param block_size:
        Size of a single block in bytes.

    :param total_size:
        Total size of the file in bytes. urllib passes a non-positive
        value (typically -1) when the server does not report the size.

    :return:
        Nothing.
    """

    # Upper bound on the number of bytes received so far.
    bytes_so_far = count * block_size

    if total_size > 0:
        # Percentage completion. The final block is normally only partly
        # filled, so count * block_size can overshoot total_size; clamp
        # the value so the display never exceeds 100%.
        pct_complete = min(1.0, float(bytes_so_far) / total_size)

        # Status-message. Note the \r which means the line should
        # overwrite itself.
        msg = "\r- Download progress: {0:.1%}".format(pct_complete)
    else:
        # The total size is unknown (or zero), so a percentage would be
        # meaningless or a division by zero. Show the byte count instead.
        msg = "\r- Download progress: {0} bytes".format(bytes_so_far)

    # Print it.
    sys.stdout.write(msg)
    sys.stdout.flush()
########################################################################
def maybe_download_and_extract(url, download_dir):
    """
    Download and extract the data if it doesn't already exist.

    Supports zip-files (.zip) and gzipped tar-balls (.tar.gz, .tgz).
    A file with any other extension is downloaded but not extracted.

    :param url:
        Internet URL for the archive to download.
        Example: "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"

    :param download_dir:
        Directory where the downloaded file is saved and extracted.
        Example: "data/CIFAR-10/"

    :return:
        Nothing.
    """

    # Filename for saving the file downloaded from the internet.
    # Use the filename from the URL and add it to the download_dir.
    filename = url.split('/')[-1]
    file_path = os.path.join(download_dir, filename)

    # Check if the file already exists.
    # If it exists then we assume it has also been extracted,
    # otherwise we need to download and extract it now.
    if not os.path.exists(file_path):
        # Create the download directory if needed. exist_ok avoids the
        # race between checking for the directory and creating it.
        os.makedirs(download_dir, exist_ok=True)

        # Download to a temporary name and only move it to file_path
        # once the transfer has completed. Otherwise an interrupted
        # download would leave a truncated file behind, and the
        # existence check above would treat it as a finished download
        # on the next call.
        partial_path = file_path + ".part"
        try:
            urllib.request.urlretrieve(url=url,
                                       filename=partial_path,
                                       reporthook=_print_download_progress)
            os.replace(partial_path, file_path)
        finally:
            # After a successful os.replace the partial file is gone;
            # it only still exists here if the download failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)

        print()
        print("Download finished. Extracting files.")

        # NOTE(review): the archive comes from the internet and is
        # extracted without validating member paths. Only use URLs you
        # trust, or consider tarfile's extraction filters where the
        # running Python version provides them.
        if file_path.endswith(".zip"):
            # Unpack the zip-file. The with-statement closes the archive
            # handle even if extraction fails.
            with zipfile.ZipFile(file=file_path, mode="r") as archive:
                archive.extractall(download_dir)
        elif file_path.endswith((".tar.gz", ".tgz")):
            # Unpack the tar-ball. The with-statement closes the archive
            # handle even if extraction fails.
            with tarfile.open(name=file_path, mode="r:gz") as archive:
                archive.extractall(download_dir)

        print("Done.")
    else:
        print("Data has apparently already been downloaded and unpacked.")
########################################################################