Skip to content

Commit

Permalink
Add support for downloading files to file pointer, fix for https://gi…
Browse files Browse the repository at this point in the history
  • Loading branch information
OctoNezd authored and delivrance committed Jul 8, 2020
1 parent 55d0b93 commit 1e8c981
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 85 deletions.
173 changes: 91 additions & 82 deletions pyrogram/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.

import io
import logging
import math
import os
Expand Down Expand Up @@ -1231,9 +1231,9 @@ def download_worker(self):

temp_file_path = ""
final_file_path = ""

path = [None]
try:
data, directory, file_name, done, progress, progress_args, path = packet
data, done, progress, progress_args, out, path, to_file = packet

temp_file_path = self.get_file(
media_type=data.media_type,
Expand All @@ -1250,13 +1250,15 @@ def download_worker(self):
file_size=data.file_size,
is_big=data.is_big,
progress=progress,
progress_args=progress_args
progress_args=progress_args,
out=out
)

if temp_file_path:
final_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
os.makedirs(directory, exist_ok=True)
shutil.move(temp_file_path, final_file_path)
if to_file:
final_file_path = out.name
else:
final_file_path = ''
if to_file:
out.close()
except Exception as e:
log.error(e, exc_info=True)

Expand Down Expand Up @@ -1864,7 +1866,8 @@ def get_file(
file_size: int,
is_big: bool,
progress: callable,
progress_args: tuple = ()
progress_args: tuple = (),
out: io.IOBase = None
) -> str:
with self.media_sessions_lock:
session = self.media_sessions.get(dc_id, None)
Expand Down Expand Up @@ -1950,7 +1953,10 @@ def get_file(
limit = 1024 * 1024
offset = 0
file_name = ""

if not out:
f = tempfile.NamedTemporaryFile("wb", delete=False)
else:
f = out
try:
r = session.send(
functions.upload.GetFile(
Expand All @@ -1961,36 +1967,37 @@ def get_file(
)

if isinstance(r, types.upload.File):
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
if hasattr(f, "name"):
file_name = f.name

while True:
chunk = r.bytes
while True:
chunk = r.bytes

if not chunk:
break
if not chunk:
break

f.write(chunk)
f.write(chunk)

offset += limit
offset += limit

if progress:
progress(
min(offset, file_size)
if file_size != 0
else offset,
file_size,
*progress_args
)
if progress:
progress(

r = session.send(
functions.upload.GetFile(
location=location,
offset=offset,
limit=limit
)
min(offset, file_size)
if file_size != 0
else offset,
file_size,
*progress_args
)

r = session.send(
functions.upload.GetFile(
location=location,
offset=offset,
limit=limit
)
)

elif isinstance(r, types.upload.FileCdnRedirect):
with self.media_sessions_lock:
cdn_session = self.media_sessions.get(r.dc_id, None)
Expand All @@ -2003,78 +2010,80 @@ def get_file(
self.media_sessions[r.dc_id] = cdn_session

try:
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
if hasattr(f, "name"):
file_name = f.name

while True:
r2 = cdn_session.send(
functions.upload.GetCdnFile(
file_token=r.file_token,
offset=offset,
limit=limit
)
while True:
r2 = cdn_session.send(
functions.upload.GetCdnFile(
file_token=r.file_token,
offset=offset,
limit=limit
)
)

if isinstance(r2, types.upload.CdnFileReuploadNeeded):
try:
session.send(
functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
)
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
try:
session.send(
functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
)
except VolumeLocNotFound:
break
else:
continue
)
except VolumeLocNotFound:
break
else:
continue

chunk = r2.bytes
chunk = r2.bytes

# https://core.telegram.org/cdn#decrypting-files
decrypted_chunk = AES.ctr256_decrypt(
chunk,
r.encryption_key,
bytearray(
r.encryption_iv[:-4]
+ (offset // 16).to_bytes(4, "big")
)
# https://core.telegram.org/cdn#decrypting-files
decrypted_chunk = AES.ctr256_decrypt(
chunk,
r.encryption_key,
bytearray(
r.encryption_iv[:-4]
+ (offset // 16).to_bytes(4, "big")
)
)

hashes = session.send(
functions.upload.GetCdnFileHashes(
file_token=r.file_token,
offset=offset
)
hashes = session.send(
functions.upload.GetCdnFileHashes(
file_token=r.file_token,
offset=offset
)
)

# https://core.telegram.org/cdn#verifying-files
for i, h in enumerate(hashes):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
# https://core.telegram.org/cdn#verifying-files
for i, h in enumerate(hashes):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)

f.write(decrypted_chunk)
f.write(decrypted_chunk)

offset += limit
offset += limit

if progress:
progress(
min(offset, file_size)
if file_size != 0
else offset,
file_size,
*progress_args
)
if progress:
progress(

if len(chunk) < limit:
break
min(offset, file_size)
if file_size != 0
else offset,
file_size,
*progress_args
)

if len(chunk) < limit:
break
except Exception as e:
raise e
except Exception as e:
if not isinstance(e, Client.StopTransmission):
log.error(e, exc_info=True)

try:
os.remove(file_name)
if out:
os.remove(file_name)
except OSError:
pass

Expand Down
13 changes: 13 additions & 0 deletions pyrogram/client/methods/messages/download_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.

import binascii
import io
import os
import re
import struct
import time
from datetime import datetime
Expand All @@ -37,6 +39,7 @@ def download_media(
message: Union["pyrogram.Message", str],
file_ref: str = None,
file_name: str = DEFAULT_DOWNLOAD_DIR,
out: io.IOBase = None,
block: bool = True,
progress: callable = None,
progress_args: tuple = ()
Expand All @@ -58,6 +61,9 @@ def download_media(
You can also specify a path for downloading files in a custom location: paths that end with "/"
are considered directories. All non-existent folders will be created automatically.
out (``io.IOBase``, *optional*):
A custom *file-like object* to be used when downloading file. Overrides file_name
block (``bool``, *optional*):
Blocks the code execution until the file has been downloaded.
Defaults to True.
Expand Down Expand Up @@ -238,6 +244,13 @@ def get_existing_attributes() -> dict:
extension
)

if not out:
out = open(os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name))), 'wb')
os.makedirs(directory, exist_ok=True)
to_file = True
else:
to_file = False
self.download_queue.put((data, done, progress, progress_args, out, path, to_file))
# Cast to string because Path objects aren't supported by Python 3.5
self.download_queue.put((data, str(directory), str(file_name), done, progress, progress_args, path))

Expand Down
11 changes: 8 additions & 3 deletions pyrogram/client/types/messages_and_media/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.

import io
from functools import partial
from typing import List, Match, Union

Expand Down Expand Up @@ -2964,7 +2964,7 @@ def retract_vote(
chat_id=message.chat.id,
message_id=message_id,
)
Example:
.. code-block:: python
Expand All @@ -2985,6 +2985,7 @@ def retract_vote(
def download(
self,
file_name: str = "",
out: io.IOBase = None,
block: bool = True,
progress: callable = None,
progress_args: tuple = ()
Expand All @@ -3009,6 +3010,9 @@ def download(
You can also specify a path for downloading files in a custom location: paths that end with "/"
are considered directories. All non-existent folders will be created automatically.
out (``io.IOBase``, *optional*):
A custom *file-like object* to be used when downloading file. Overrides file_name
block (``bool``, *optional*):
Blocks the code execution until the file has been downloaded.
Defaults to True.
Expand Down Expand Up @@ -3045,6 +3049,7 @@ def download(
return self._client.download_media(
message=self,
file_name=file_name,
out=out,
block=block,
progress=progress,
progress_args=progress_args,
Expand Down Expand Up @@ -3074,7 +3079,7 @@ def vote(
Parameters:
option (``int``):
Index of the poll option you want to vote for (0 to 9).
Returns:
:obj:`Poll`: On success, the poll with the chosen option is returned.
Expand Down

0 comments on commit 1e8c981

Please sign in to comment.