Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Commit

Permalink
chore: cleanup in the rest of the bot framework
Browse files Browse the repository at this point in the history
  • Loading branch information
tafaust committed Apr 10, 2022
1 parent 073afc4 commit 86678aa
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 134 deletions.
105 changes: 61 additions & 44 deletions src/execution/controller.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import asyncio
from tempfile import NamedTemporaryFile
import base64
from typing import AsyncIterator, Optional
from typing import Union

from PIL.GifImagePlugin import GifImageFile
from asyncpraw.models import Message
from asyncpraw.reddit import Comment
from imgurpython import ImgurClient
# from imgurpython import ImgurClient

import src.execution.task as t
from src.execution.result import Result
from src.client.imgur import ImgurClient
from src.model.result import Result
from src.client.reddit import RedditClient
from src.model.execution_mode import ExecutionMode
from src.model.media_type import MediaType
from src.model.task_state import TaskConfigState
from src.model.task_state import TaskState
from src.util import decorator, config
Expand All @@ -26,8 +27,7 @@ def log_broad_exception(err, logger=cut_logger) -> None:


class AioController(object):
# todo asyncio queue's: https://stackoverflow.com/a/24704950/2402281
# todo reduce praw's Message model to make it pickle'able for multiprocessing.Queue to work
# asyncio queue's: https://stackoverflow.com/a/24704950/2402281
input_queue: asyncio.Queue
output_queue: asyncio.Queue

Expand All @@ -39,18 +39,18 @@ def __init__(self, input_queue: asyncio.Queue, output_queue: asyncio.Queue,
self.reddit = None
self.imgur = None

def _init_reddit_host(self) -> None:
def _init_reddit_client(self) -> None:
if self.reddit is None:
root_logger.info('Initializing async reddit client.')
self.reddit: RedditClient = RedditClient()

def _init_imgur_host(self) -> None:
def _init_imgur_client(self) -> None:
if self.imgur is None:
root_logger.info('Initializing sync imgur client.')
self.imgur: ImgurClient = ImgurClient(config.IMGUR_CLIENT_ID, config.IMGUR_CLIENT_SECRET)

async def run(self, *args, **kwargs) -> None:
self._init_reddit_host()
self._init_reddit_client()
root_logger.debug('Calling controller.run ...')
await self.fetch()
await self.work()
Expand All @@ -60,12 +60,12 @@ async def fetch(self) -> None:
"""Must not be called more than every 2 seconds to not violate reddit rate limits.
"""
# todo custom logger
self._init_reddit_host()
self._init_reddit_client()
root_logger.info('Fetching new messages...')
try:
await self._fill_task_queue_from_reddit()
except Exception as err:
log_broad_exception(err, root_logger)
log_broad_exception(err, logger=root_logger)

@decorator.run_in_executor
def work(self) -> None:
Expand Down Expand Up @@ -95,26 +95,30 @@ def work(self) -> None:

# @decorator.run_in_executor
async def upload_and_answer(self) -> None:
_result: Result
if self._mode == ExecutionMode.TEST:
_result: Result
_result = self._read_result_from_output_queue()
_result = self._read_result_from_output_queue(logger=upload_logger)
if _result is None: # fixme
return
filename = 'test.gif'
filename = 'test.gif' # fixme change extension by hand when in TEST mode
with open(filename, mode='wb') as fp:
# save GIF into named temp file for upload with deprecated imgur lib...
_result.gif.save(fp=fp, format='GIF', save_all=True, duration=_result.gif_duration)
# if _result.media_type == MediaType.GIF:
# # save GIF into named temp file for upload with deprecated imgur lib...
# _result.media_stream.save(fp=fp, format='GIF', save_all=True)
# else:
fp.write(_result.media_stream.getvalue())
upload_logger.debug(f'Created file: {filename}')
return
self._init_imgur_host()
self._init_reddit_host()
_result: Result
_result = self._read_result_from_output_queue()
self._init_imgur_client() # make sure imgur client is connected
self._init_reddit_client() # make sure reddit client is connected
_result = self._read_result_from_output_queue(logger=upload_logger)
if _result is None: # fixme
upload_logger.warning('Received NoneType result.')
return
_result = self._upload_to_imgur(result=_result)
await self._answer_in_reddit(result=_result)
else:
upload_logger.info(f'Uploading result: {_result}')
upload_link = await self._upload_to_imgur(result=_result)
await self._answer_in_reddit(message=_result.message, upload_link=upload_link)

async def _fill_task_queue_from_reddit(self) -> None:
if not await self.reddit.has_new_message():
Expand Down Expand Up @@ -165,13 +169,15 @@ def _read_from_input_queue(self) -> t.Task:
else:
return _task

# noinspection PyMethodMayBeStatic
def _exert_task(self, task: t.Task) -> Result:
if task.is_state(TaskState.VALID):
_gif: GifImageFile
# _gif: GifImageFile
cut_logger.info('Handling task...')
try:
_gif, _gif_duration = task.handle()
return Result(message=task.config.message, gif=_gif, gif_duration=_gif_duration)
result: Result = task.handle()
return result
# return Result(message=task.config.message, gif=_gif, gif_duration=_gif_duration)
except TaskFailureException as err:
cut_logger.error(f'Task failed: {err}')
except Exception as err:
Expand All @@ -193,43 +199,54 @@ def _write_result_to_output_queue(self, result: Result) -> bool:
return True
return False

def _read_result_from_output_queue(self) -> Optional[Result]:
def _read_result_from_output_queue(self, logger) -> Optional[Result]:
# todo custom logger
_result: Result
try:
upload_logger.info('Attempting to immediately get task from output queue without blocking...')
_result = self.output_queue.get_nowait()
logger.info('Attempting to immediately get task from output queue without blocking...')
_result: Result = self.output_queue.get_nowait()
except ValueError:
upload_logger.error(f'Queue is closed.')
logger.error(f'Queue is closed.')
except asyncio.QueueEmpty:
upload_logger.warning(f'Queue is empty.')
upload_logger.debug(f'Output queue: {self.output_queue}')
logger.warning(f'Queue is empty.')
logger.debug(f'Output queue: {self.output_queue}')
except Exception as err:
log_broad_exception(err)
else:
return _result

# @decorator.run_in_executor
def _upload_to_imgur(self, result: Result) -> Result:
# todo custom logger
async def _upload_to_imgur(self, result: Result) -> str:
# self.imgur
# todo worth to have this async because of io
with NamedTemporaryFile(mode='wb', suffix='.gif') as fp:
# save GIF into named temp file for upload with deprecated imgur lib...
result.gif.save(fp=fp, format='GIF', save_all=True, duration=result.gif_duration)
res = self.imgur.upload_from_path(path=fp.name, anon=False)
upload_logger.debug(f'Upload to imgur: {res.get("link")}')
# todo error handling with res
result.gif_link = res.get('link')
return result
# with NamedTemporaryFile(mode='wb', suffix='.gif') as fp:
# save GIF into named temp file for upload with deprecated imgur lib...
# result.gif.save(fp=fp, format='GIF', save_all=True, duration=result)
# res = self.imgur.upload_from_path(path=fp.name, anon=False)
anon = False
if result.media_type == MediaType.GIF:
payload = {'image': base64.b64encode(result.media_stream.getvalue()), 'type': 'base64'}
anon = True
# payload = {'image': result.media_stream}
else:
payload = {
'type': 'file',
'disable_audio': '0',
'video': result.media_stream
}
res = self.imgur.upload(upload_payload=payload, anon=anon)
upload_logger.debug(f'Upload to imgur: {res.get("link")}')
# todo error handling with res
return res.get('link')

# @decorator.run_in_executor
async def _answer_in_reddit(self, result: Result) -> None:
# noinspection PyMethodMayBeStatic
async def _answer_in_reddit(self, message: Message, upload_link: str) -> None:
# todo need a custom result data type to answer the message T_T
# reply with link to the just cut gif and mark as unread
issue_link = f'https://www.reddit.com/message/compose/?to=domac&subject={config.REDDIT_USERNAME}%20issue&message=' \
f'Add a link to the gif or comment in your message%2C I%27m not always sure which request is ' \
f'being reported. Thanks for helping me out! '
bot_footer = f"---\n\n^(I am a bot.) [^(Report an issue)]({issue_link})"
await result.message.reply(f'Here is your cut GIF: {result.gif_link}\n{bot_footer}')
await message.reply(f'Here is your cut GIF: {upload_link}\n{bot_footer}')
# m.mark_read() # done
upload_logger.info('Reddit reply sent!')
102 changes: 62 additions & 40 deletions src/execution/task.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,76 @@
import math
import os
import re
from dataclasses import dataclass
from io import BytesIO
from typing import Dict
from typing import Dict, Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import PIL
import requests
from PIL.GifImagePlugin import GifImageFile
from PIL.Image import Image
from praw.models import Message
from praw.models import Submission
from asyncpraw.models import Message
from asyncpraw.models import Submission

from src.handler import base
from src.handler.gif import GifCutHandler
from src.handler.video import VideoCutHandler
import src.handler as handler_pkg
import src.model.result as result_pkg
from src import gif_utilities
from src.model.media_type import MediaType
from src.model.task_state import TaskConfigState
from src.model.task_state import TaskState
from src.util import exception
from src.util.aux import noop_image
from src.util.aux import Watermark
from src.util.aux import noop_image, fix_start_end_swap
from src.util.aux import watermark_image
from src.util.exception import TaskFailureException
from src.util.logger import root_logger


class TaskConfig(object):
@dataclass
class TaskConfig:
"""A task configuration which is not serializable.
Attributes:
message The reddit message object.
media_type The media type of the resource requested for cutting.
start The start time in milliseconds from where to cut the MediaType.
end The end time in milliseconds to stop the cut of the MediaType.
watermark An optional callable that watermarks the cut media.
state The state of this :class:~`TaskConfig`.
is_video A flag indicating if the media is a video.
is_gif A flag indicating if the media is a gif.
is_crosspost A flag indicating if the media is crossposted.
media_url The url to the media.
duration The total duration of the media in seconds read from the `message`.
extension The file extension of the media.
"""
message: Message
media_type: MediaType
start: float
end: Optional[float]
watermark: Callable[[PIL.Image.Image], PIL.Image.Image]

def __init__(
self, message: Message, start: float, end: float, media_type: MediaType, watermark:
self, message: Message, start: float, end: Optional[float], media_type: MediaType, watermark:
Optional[Watermark] = None
):
self.message = message
self.media_type = media_type
self.start = start
self.end = end
self.watermark = noop_image if watermark is None else lambda img: watermark_image(img, watermark)
self._state = TaskConfigState.VALID

self.__is_video = self.media_type in [MediaType.MP4, MediaType.MOV, MediaType.WEBM]
self.__is_gif = self.media_type == MediaType.GIF
if hasattr(message.submission, 'crosspost_parent'):
self.__is_crosspost = self.message.submission.crosspost_parent is not None
self.__is_crosspost = message.submission.crosspost_parent is not None
else:
self.__is_crosspost = False
start_ms, end_ms = fix_start_end_swap(start=start, end=end)
start_ms = max(start_ms, 0) # put a realistic lower bound on end
end_ms = min(end_ms or math.inf, self.duration * 1000) # put a realistic upper bound on end
self.start = start_ms
self.end = end_ms
self.watermark = noop_image if watermark is None else lambda img: watermark_image(img, watermark)
self._state = TaskConfigState.VALID

# there is no advantage in using a slotted class, thus resorting to __dict__
# self.__dict__ = {
Expand All @@ -62,7 +88,7 @@ def __init__(
# def __getattr__(self, values):
# yield from [getattr(self, i) for i in values.split('_')]

def __str__(self) -> str:
def __repr__(self) -> str:
return f'TaskConfig(message: {self.message}, media_type: {self.media_type}, start: {self.start}, ' \
f'end: {self.end}, watermark: {self.watermark}, state: {self.state}, is_video: {self.is_video}, ' \
f'is_gif: {self.is_gif}, is_crosspost: {self.is_crosspost}, duration: {self.duration}, extension: ' \
Expand Down Expand Up @@ -93,7 +119,7 @@ def media_url(self) -> str:
_submission: Submission = self.message.submission
if self.is_gif:
if self.is_crosspost:
return '' # dunno
return '' # todo
else:
return _submission.url
elif self.is_video:
Expand All @@ -113,7 +139,16 @@ def duration(self) -> Optional[float]:
# todo do this in __init__ and store in a "_variable"
if self.is_gif:
# AFAIK there is no duration sent when we are dealing with a GIF
return None
with requests.get(self.media_url, stream=True) as resp:
if resp.ok:
self._state = TaskConfigState.VALID
# read whole file via StreamReader into BytesIO
_stream = BytesIO(resp.raw.read())
_stream.seek(0)
return gif_utilities.get_gif_duration(image=PIL.Image.open(_stream))
else:
self._state = TaskConfigState.INVALID
return math.nan
elif self.is_video:
_submission: Submission = self.message.submission
if self.is_crosspost:
Expand Down Expand Up @@ -189,7 +224,7 @@ def __parse_start_and_end(cls, message: Message) -> Dict[str, float]:
pattern = re.compile(r'(s|start)=([\d]+) (e|end)=([\d]+)', re.IGNORECASE)
matches = pattern.search(message.body)
if matches is None:
root_logger.warn('Skipping message because no match was found.')
root_logger.warning('Skipping message because no match was found.')
cls.state = TaskConfigState.INVALID
return {}
root_logger.debug(f'Found pattern matches: {matches.groups()}')
Expand Down Expand Up @@ -238,24 +273,21 @@ def is_state(self, state: Union[TaskState, List[TaskState]]) -> bool:
def _select_handler(self):
mt: MediaType = self.__config.media_type
if mt == MediaType.GIF:
self._task_handler = GifCutHandler()
self._task_handler = handler_pkg.gif.GifCutHandler()
elif mt in [MediaType.MP4, MediaType.MOV, MediaType.WEBM]:
self._task_handler = VideoCutHandler()
self._task_handler = handler_pkg.video.VideoCutHandler()
else:
self._task_state = TaskState.DROP
# self._task_handler = TestCutHandler()
root_logger.warn(f'No handler for media type: {mt}')
root_logger.warning(f'No handler for media type: {mt}')

def handle(self) -> Tuple[GifImageFile, float]:
def handle(self) -> result_pkg.Result:
_stream: Optional[BytesIO] = self._fetch_stream()
if self._task_state == TaskState.INVALID:
raise TaskFailureException('Failed to fetch stream from host!')
image: List[Image]
avg_fps: float
image, avg_fps = self._task_handler.cut(stream=_stream, config=self.__config)
gif: GifImageFile = base.post_cut_hook(r=(image, avg_fps))
_result: result_pkg.Result = self._task_handler.cut(stream=_stream, config=self.__config)
self._task_state = TaskState.DONE
return gif, avg_fps
return _result

def _fetch_stream(self) -> Optional[BytesIO]:
_stream: BytesIO
Expand All @@ -268,16 +300,6 @@ def _fetch_stream(self) -> Optional[BytesIO]:
self._task_state = TaskState.INVALID
return None
return _stream
# if self.__config.is_video:
# with open(f'foo.mp4', 'wb') as f:
# f.write(requests.get(url, stream=True).raw.read())
# with open(f'foo.mp4', 'rb') as f:
# _stream = BytesIO(f.read())
# elif self.__config.is_gif:
# with requests.get(url, stream=True) as r:
# _stream = Image.open(r.raw)
# else:
# root_logger.error('No valid input! Neither received a video or gif.')

@property
def config(self):
Expand Down
Loading

0 comments on commit 86678aa

Please sign in to comment.