Skip to content

Commit

Permalink
perf: replay part file download
Browse files Browse the repository at this point in the history
  • Loading branch information
LeeEirc authored and BaiJiangJie committed Sep 14, 2024
1 parent 7da8224 commit 134f1a4
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 42 deletions.
35 changes: 21 additions & 14 deletions apps/common/storage/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,32 @@
import os

import jms_storage

from django.conf import settings
from django.core.files.storage import default_storage

from terminal.models import default_storage, ReplayStorage
from common.utils import get_logger, make_dirs
from terminal.models import ReplayStorage

logger = get_logger(__name__)


def get_multi_object_storage():
    """Build a multi-object storage client from the configured replay storages.

    Every ``ReplayStorage`` row contributes its ``config`` under its ``name``,
    except rows flagged ``type_sftp`` or ``type_null_or_server``. When
    ``settings.SERVER_REPLAY_STORAGE`` is truthy it is added under the
    ``'SERVER_REPLAY_STORAGE'`` key.

    :return: the object returned by ``jms_storage.get_multi_object_storage``,
        or ``None`` when no storage configuration was collected.
    """
    configs = {
        replay_storage.name: replay_storage.config
        for replay_storage in ReplayStorage.objects.all()
        if not (replay_storage.type_sftp or replay_storage.type_null_or_server)
    }

    server_replay_storage = settings.SERVER_REPLAY_STORAGE
    if server_replay_storage:
        configs['SERVER_REPLAY_STORAGE'] = settings.SERVER_REPLAY_STORAGE

    if configs:
        return jms_storage.get_multi_object_storage(configs)
    return None


class BaseStorageHandler(object):
NAME = ''

Expand All @@ -24,20 +41,10 @@ def find_local(self):
raise NotImplementedError

def download(self):
replay_storages = ReplayStorage.objects.all()
configs = {}
for storage in replay_storages:
if storage.type_sftp:
continue
if storage.type_null_or_server:
continue
configs[storage.name] = storage.config
if settings.SERVER_REPLAY_STORAGE:
configs['SERVER_REPLAY_STORAGE'] = settings.SERVER_REPLAY_STORAGE
if not configs:
storage = get_multi_object_storage()
if not storage:
msg = f"Not found {self.NAME} file, and not remote storage set"
return None, msg
storage = jms_storage.get_multi_object_storage(configs)

remote_path, local_path = self.get_file_path(storage=storage)
if not remote_path:
Expand Down
83 changes: 81 additions & 2 deletions apps/common/storage/replay.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import json
import os
import tarfile
from itertools import chain

from terminal.models import default_storage
from .base import BaseStorageHandler
from django.core.files.storage import default_storage

from common.utils import make_dirs, get_logger
from terminal.models import Session
from .base import BaseStorageHandler, get_multi_object_storage

logger = get_logger(__name__)


class ReplayStorageHandler(BaseStorageHandler):
Expand Down Expand Up @@ -29,3 +37,74 @@ def find_local(self):
url = default_storage.url(_local_path)
return _local_path, url
return None, f'{self.NAME} not found.'


class SessionPartReplayStorageHandler(object):
    """Locate, fetch and bundle the part files of a session replay.

    Part files live next to the ``<session id>.replay.json`` meta file; their
    local and remote paths are derived from the session object via
    ``get_replay_part_file_local_storage_path`` and
    ``get_replay_part_file_relative_path``.

    The lookup methods return a ``(local_path, url)`` tuple on success and
    ``(None, error_message)`` on failure.
    """
    Name = 'SessionPartReplayStorageHandler'

    def __init__(self, obj: 'Session'):
        self.obj = obj

    @staticmethod
    def _is_safe_part_filename(part_filename):
        """Return True when ``part_filename`` is a bare file name.

        Part file names can come from a request query parameter, so anything
        that could escape the session's replay directory (path separators,
        ``.``/``..``, empty or non-string values) is rejected.
        """
        if not isinstance(part_filename, str) or not part_filename:
            return False
        if part_filename in ('.', '..'):
            return False
        forbidden = ('/', '\\', os.sep, '\x00')
        return not any(token in part_filename for token in forbidden)

    def find_local_part_file_path(self, part_filename):
        """Look the part file up in the default storage.

        :param part_filename: bare name of the part file.
        :return: ``(local_path, url)`` if the file exists locally, otherwise
            ``(None, error_message)``.
        """
        if not self._is_safe_part_filename(part_filename):
            return None, 'Invalid part filename: {}'.format(part_filename)
        local_path = self.obj.get_replay_part_file_local_storage_path(part_filename)
        if default_storage.exists(local_path):
            url = default_storage.url(local_path)
            return local_path, url
        return None, '{} not found.'.format(part_filename)

    def download_part_file(self, part_filename):
        """Download the part file from the remote replay storage.

        :param part_filename: bare name of the part file.
        :return: ``(local_path, url)`` on success, otherwise
            ``(None, error_message)``.
        """
        # Validate before touching any storage: the name is joined into both
        # the remote path and the local target path below.
        if not self._is_safe_part_filename(part_filename):
            return None, 'Invalid part filename: {}'.format(part_filename)

        storage = get_multi_object_storage()
        if not storage:
            msg = "Not found {} file, and not remote storage set".format(part_filename)
            return None, msg
        local_path = self.obj.get_replay_part_file_local_storage_path(part_filename)
        remote_path = self.obj.get_replay_part_file_relative_path(part_filename)

        # Absolute path under the default storage where the file is saved
        target_path = os.path.join(default_storage.base_location, local_path)

        target_dir = os.path.dirname(target_path)
        if not os.path.isdir(target_dir):
            make_dirs(target_dir, exist_ok=True)

        ok, err = storage.download(remote_path, target_path)
        if not ok:
            msg = 'Failed download {} file: {}'.format(part_filename, err)
            logger.error(msg)
            return None, msg
        url = default_storage.url(local_path)
        return local_path, url

    def get_part_file_path_url(self, part_filename):
        """Return the part file from local storage, downloading it if missing.

        :return: ``(local_path, url)`` on success, otherwise
            ``(None, error_message)``.
        """
        local_path, url = self.find_local_part_file_path(part_filename)
        if local_path is None:
            local_path, url = self.download_part_file(part_filename)
        return local_path, url

    def prepare_offline_tar_file(self):
        """Bundle the meta file and every listed part file into ``<id>.tar``.

        The archive is created next to the meta file and reused if it already
        exists.

        :return: the archive opened for binary reading; the caller owns the
            file object and is responsible for closing it.
        :raises FileNotFoundError: if the meta file or any part file cannot be
            found locally or downloaded, or if the meta file is empty.
        """
        replay_meta_filename = '{}.replay.json'.format(self.obj.id)
        meta_local_path, url_or_error = self.get_part_file_path_url(replay_meta_filename)
        if not meta_local_path:
            raise FileNotFoundError(f'{replay_meta_filename} not found: {url_or_error}')
        meta_local_abs_path = os.path.join(default_storage.base_location, meta_local_path)
        # Binary mode lets json detect the encoding itself instead of relying
        # on the process locale.
        with open(meta_local_abs_path, 'rb') as f:
            meta_data = json.load(f)
        if not meta_data:
            raise FileNotFoundError(f'{replay_meta_filename} is empty')

        # Drop entries without a name once, so the availability check and the
        # archive loop below work on the same list.
        part_filenames = [
            name
            for name in (part_file.get('name') for part_file in meta_data.get('files') or [])
            if name
        ]
        for part_filename in part_filenames:
            local_path, url_or_error = self.get_part_file_path_url(part_filename)
            if not local_path:
                raise FileNotFoundError(f'{part_filename} not found: {url_or_error}')

        dir_path = os.path.dirname(meta_local_abs_path)
        offline_filename = '{}.tar'.format(self.obj.id)
        offline_filename_abs_path = os.path.join(dir_path, offline_filename)
        if not os.path.exists(offline_filename_abs_path):
            try:
                with tarfile.open(offline_filename_abs_path, 'w') as tar:
                    tar.add(str(meta_local_abs_path), arcname=replay_meta_filename)
                    for part_filename in part_filenames:
                        local_abs_path = os.path.join(dir_path, part_filename)
                        tar.add(local_abs_path, arcname=part_filename)
            except Exception:
                # Do not leave a truncated archive behind: the existence check
                # above would hand it out on every later call.
                if os.path.exists(offline_filename_abs_path):
                    os.remove(offline_filename_abs_path)
                raise
        return open(offline_filename_abs_path, 'rb')
3 changes: 2 additions & 1 deletion apps/ops/signal_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from common.utils.connection import RedisPubSub
from jumpserver.utils import get_current_request
from orgs.utils import get_current_org_id, set_current_org
from .ansible.runner import interface
from .celery import app
from .models import CeleryTaskExecution, CeleryTask, Job
from .ansible.runner import interface

logger = get_logger(__name__)

Expand Down Expand Up @@ -63,6 +63,7 @@ def check_registered_tasks(*args, **kwargs):
'common.utils.verify_code.send_sms_async', 'assets.tasks.nodes_amount.check_node_assets_amount_period_task',
'users.tasks.check_user_expired', 'orgs.tasks.refresh_org_cache_task',
'terminal.tasks.upload_session_replay_to_external_storage', 'terminal.tasks.clean_orphan_session',
'terminal.tasks.upload_session_replay_file_to_external_storage',
'audits.tasks.clean_audits_log_period', 'authentication.tasks.clean_django_sessions'
]

Expand Down
38 changes: 25 additions & 13 deletions apps/terminal/api/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from common.drf.filters import DatetimeRangeFilterBackend
from common.drf.renders import PassthroughRenderer
from common.permissions import IsServiceAccount
from common.storage.replay import ReplayStorageHandler
from common.storage.replay import ReplayStorageHandler, SessionPartReplayStorageHandler
from common.utils import data_to_json, is_uuid, i18n_fmt
from common.utils import get_logger, get_object_or_none
from common.views.mixins import RecordViewLogMixin
Expand Down Expand Up @@ -124,33 +124,37 @@ def prepare_offline_file(session, local_path):
os.chdir(current_dir)
return file

def get_storage(self):
return ReplayStorageHandler(self.get_object())

@action(methods=[GET], detail=True, renderer_classes=(PassthroughRenderer,), url_path='replay/download',
url_name='replay-download')
def download(self, request, *args, **kwargs):
storage = self.get_storage()
session = self.get_object()
storage = ReplayStorageHandler(session)
local_path, url = storage.get_file_path_url()
if local_path is None:
# url => error message
return Response({'error': url}, status=404)

file = self.prepare_offline_file(storage.obj, local_path)
# 如果获取的录像文件类型是 .replay.json 则使用 part 的方式下载
if url.endswith('.replay.json'):
# part 的方式录像存储, 通过 part_storage 的方式下载
part_storage = SessionPartReplayStorageHandler(session)
file = part_storage.prepare_offline_tar_file()
else:
file = self.prepare_offline_file(session, local_path)
response = FileResponse(file)
response['Content-Type'] = 'application/octet-stream'
# 这里要注意哦,网上查到的方法都是response['Content-Disposition']='attachment;filename="filename.py"',
# 但是如果文件名是英文名没问题,如果文件名包含中文,下载下来的文件名会被改为url中的path。
filename = escape_uri_path('{}.tar'.format(storage.obj.id))
filename = escape_uri_path('{}.tar'.format(session.id))
disposition = "attachment; filename*=UTF-8''{}".format(filename)
response["Content-Disposition"] = disposition

detail = i18n_fmt(
REPLAY_OP, self.request.user, _('Download'), str(storage.obj)
REPLAY_OP, self.request.user, _('Download'), str(session)
)
self.record_logs(
[storage.obj.asset_id], ActionChoices.download, detail,
model=Session, resource_display=str(storage.obj)
[session.asset_id], ActionChoices.download, detail,
model=Session, resource_display=str(session)
)
return response

Expand Down Expand Up @@ -197,7 +201,7 @@ def get_queryset(self):
# so we need to use select_for_update only for have not prefetch_related and annotate
queryset = queryset.select_for_update()
return queryset

def perform_create(self, serializer):
if hasattr(self.request.user, 'terminal'):
serializer.validated_data["terminal"] = self.request.user.terminal
Expand Down Expand Up @@ -245,6 +249,9 @@ def get_replay_data(session, url):
tp = 'asciicast'
elif url.endswith('.replay.mp4'):
tp = 'mp4'
elif url.endswith('replay.json'):
# 新版本将返回元数据信息
tp = 'parts'
elif (getattr(session.terminal, 'type', None) in all_guacamole_types) or \
(session.protocol in ('rdp', 'vnc')):
tp = 'guacamole'
Expand Down Expand Up @@ -281,9 +288,14 @@ def async_callback(self, *args, **kwargs):
def retrieve(self, request, *args, **kwargs):
session_id = kwargs.get('pk')
session = get_object_or_404(Session, id=session_id)
part_filename = request.query_params.get('part_filename')
if part_filename:
storage = SessionPartReplayStorageHandler(session)
local_path, url = storage.get_part_file_path_url(part_filename)
else:
storage = ReplayStorageHandler(session)
local_path, url = storage.get_file_path_url()

storage = ReplayStorageHandler(session)
local_path, url = storage.get_file_path_url()
if local_path is None:
# url => error message
return Response({"error": url}, status=404)
Expand Down
38 changes: 28 additions & 10 deletions apps/terminal/models/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class LOGIN_FROM(models.TextChoices):
upload_to = 'replay'
ACTIVE_CACHE_KEY_PREFIX = 'SESSION_ACTIVE_{}'
LOCK_CACHE_KEY_PREFIX = 'TOGGLE_LOCKED_SESSION_{}'
SUFFIX_MAP = {1: '.gz', 2: '.replay.gz', 3: '.cast.gz', 4: '.replay.mp4'}
SUFFIX_MAP = {2: '.replay.gz', 3: '.cast.gz', 4: '.replay.mp4', 5: '.replay.json'}
DEFAULT_SUFFIXES = ['.replay.gz', '.cast.gz', '.gz', '.replay.mp4']

# Todo: 将来干掉 local_path, 使用 default storage 实现
Expand All @@ -75,22 +75,22 @@ def get_local_storage_path_by_suffix(self, suffix='.cast.gz'):
"""
local_path: replay/2021-12-08/session_id.cast.gz
通过后缀名获取本地存储的录像文件路径
:param suffix: .cast.gz | '.replay.gz' | '.gz'
:param suffix: .cast.gz | '.replay.gz'
:return:
"""
rel_path = self.get_relative_path_by_suffix(suffix)
if suffix == '.gz':
# 兼容 v1 的版本
return rel_path
return os.path.join(self.upload_to, rel_path)

def get_relative_path_by_suffix(self, suffix='.cast.gz'):
"""
relative_path: 2021-12-08/session_id.cast.gz
通过后缀名获取外部存储录像文件路径
:param suffix: .cast.gz | '.replay.gz' | '.gz'
:param suffix: .cast.gz | '.replay.gz' | '.replay.json'
:return:
"""
if suffix == '.replay.json':
meta_filename = str(self.id) + suffix
return self.get_replay_part_file_relative_path(meta_filename)
date = self.date_start.strftime('%Y-%m-%d')
return os.path.join(date, str(self.id) + suffix)

Expand Down Expand Up @@ -172,17 +172,35 @@ def terminal_display(self):
display = self.terminal.name if self.terminal else ''
return display

def get_replay_dir_relative_path(self):
    """Return the replay directory of this session, relative to the storage root.

    The path has the form ``<YYYY-MM-DD of date_start>/<session id>``.
    """
    day_dir = self.date_start.strftime('%Y-%m-%d')
    session_dir = str(self.id)
    return os.path.join(day_dir, session_dir)

def get_replay_part_file_relative_path(self, filename):
    """Return the path of ``filename`` inside this session's replay directory.

    The result is relative: the replay directory relative path joined with
    ``filename``.
    """
    replay_dir = self.get_replay_dir_relative_path()
    return os.path.join(replay_dir, filename)

def get_replay_part_file_local_storage_path(self, filename):
    """Return the local-storage path of a replay part file.

    This is ``upload_to`` joined with the part file's relative path.
    """
    relative_path = self.get_replay_part_file_relative_path(filename)
    return os.path.join(self.upload_to, relative_path)

def save_replay_to_storage_with_version(self, f, version=2):
suffix = self.SUFFIX_MAP.get(version, '.cast.gz')
local_path = self.get_local_storage_path_by_suffix(suffix)
if version <= 4:
# compatible old API and deprecated in future version
suffix = self.SUFFIX_MAP.get(version, '.cast.gz')
rel_path = self.get_relative_path_by_suffix(suffix)
local_path = self.get_local_storage_path_by_suffix(suffix)
else:
# 文件名依赖 上传的文件名,不再使用默认的文件名
filename = f.name
rel_path = self.get_replay_part_file_relative_path(filename)
local_path = self.get_replay_part_file_local_storage_path(filename)
try:
name = default_storage.save(local_path, f)
except OSError as e:
return None, e

if settings.SERVER_REPLAY_STORAGE:
from terminal.tasks import upload_session_replay_to_external_storage
upload_session_replay_to_external_storage.delay(str(self.id))
from terminal.tasks import upload_session_replay_file_to_external_storage
upload_session_replay_file_to_external_storage.delay(str(self.id), local_path, rel_path)
return name, None

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion apps/terminal/serializers/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class Meta(SessionSerializer.Meta):

class ReplaySerializer(serializers.Serializer):
file = serializers.FileField(allow_empty_file=True)
version = serializers.IntegerField(write_only=True, required=False, min_value=2, max_value=4)
version = serializers.IntegerField(write_only=True, required=False, min_value=2, max_value=5)


class SessionJoinValidateSerializer(serializers.Serializer):
Expand Down
2 changes: 1 addition & 1 deletion apps/terminal/signal_handlers/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@receiver(pre_save, sender=Session)
def on_session_pre_save(sender, instance, **kwargs):
def on_session_pre_save(sender, instance,**kwargs):
if instance.need_update_cmd_amount:
instance.cmd_amount = instance.compute_command_amount()

Expand Down
21 changes: 21 additions & 0 deletions apps/terminal/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,27 @@ def upload_session_replay_to_external_storage(session_id):
return


@shared_task(
    verbose_name=_('Upload session replay part file to external storage'),
    description=_(
        """If SERVER_REPLAY_STORAGE is configured in the config.txt, session commands and
recordings will be uploaded to external storage"""
    ))
def upload_session_replay_file_to_external_storage(session_id, local_path, remote_path):
    """Upload one locally stored replay file to the external replay storage.

    After a successful upload the local copy is removed from the default
    storage. Removal is best-effort: a failure is logged and does not fail
    the task.

    :param session_id: id of the session the file belongs to (not used by the
        body; kept as part of the task signature).
    :param local_path: path of the file relative to the default storage.
    :param remote_path: destination path in the external storage.
    :return: ``None`` in every case.
    """
    abs_path = default_storage.path(local_path)
    ok, err = server_replay_storage.upload(abs_path, remote_path)
    if not ok:
        logger.error(f'Session replay file {local_path} upload to external error: {err}')
        return

    try:
        default_storage.delete(local_path)
    except Exception as e:
        # The upload already succeeded, so a leftover local copy is harmless;
        # record why it could not be removed instead of hiding the error.
        logger.warning(f'Session replay file {local_path} uploaded, but removing local copy failed: {e}')
    return



@shared_task(
verbose_name=_('Run applet host deployment'),
activity_callback=lambda self, did, *args, **kwargs: ([did],),
Expand Down

0 comments on commit 134f1a4

Please sign in to comment.