Skip to content

Commit

Permalink
DONE - pre-commit and pylint
Browse files Browse the repository at this point in the history
  • Loading branch information
Crivella committed Jul 21, 2023
1 parent c964dd1 commit dc94f2b
Show file tree
Hide file tree
Showing 12 changed files with 253 additions and 121 deletions.
35 changes: 15 additions & 20 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,18 @@ repos:
- id: isort
exclude: *exclude

# - repo: local
# hooks:
# - id: pylint
# name: pylint
# entry: pylint
# language: system
# types: [python]
# exclude: >
# (?x)^(
# docs/.*|
# tests/.*(?<!\.py)$
# )$
# args:
# - "--load-plugins=pylint_pytest"
# - "--load-plugins=pylint_django"
# # - "--disable=W0212" # protected-access. Needed for parser
# # - "--disable=W0612" # unused-variable. Needed for clarity when expanding variables
# # - "--disable=R1710" # inconsistent return statements. Needed for raising errors in workchain
# # - "--disable=C0411" # import order. Isort will take care of this
# # - "--disable=W0707" # Consider explicitly re-raising.
- repo: local
hooks:
- id: pylint
name: pylint
entry: pylint
language: system
types: [python]
exclude: >
(?x)^(
manage\.py|
mysite/.*|
.*/migrations/.*|
docs/.*|
tests/.*(?<!\.py)$
)$
2 changes: 1 addition & 1 deletion ocr_translate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@
# #
# Home: https://github.com/Crivella/ocr_translate #
###################################################################################

"""OCR and translation of images."""

__version__ = '0.1.2'
6 changes: 5 additions & 1 deletion ocr_translate/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,26 @@
# #
# Home: https://github.com/Crivella/ocr_translate #
###################################################################################
"""Admin interface for ocr_translate app."""
from django.contrib import admin

from . import models as m

# Register your models here.

class LanguageAdmin(admin.ModelAdmin):
"""Admin interface for Language model"""
list_display = ('name', 'iso1', 'iso2b', 'iso2t', 'iso3')

class OCRBoxModelAdmin(admin.ModelAdmin):
"""Admin interface for OCRBoxModel model"""
list_display = ('name',)

class OCRModelAdmin(admin.ModelAdmin):
"""Admin interface for OCRModel model"""
list_display = ('name',)

class TSLModelAdmin(admin.ModelAdmin):
"""Admin interface for TSLModel model"""
list_display = ('name',)


Expand Down
4 changes: 3 additions & 1 deletion ocr_translate/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# #
# Home: https://github.com/Crivella/ocr_translate #
###################################################################################
"""App configuration for ocr_translate."""
from django.apps import AppConfig


class OCR_TSLconfig(AppConfig):
class OCRTSLconfig(AppConfig):
"""App configuration for ocr_translate."""
default_auto_field = 'django.db.models.BigAutoField'
name = 'ocr_translate'
131 changes: 96 additions & 35 deletions ocr_translate/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# #
# Home: https://github.com/Crivella/ocr_translate #
###################################################################################
"""Messaging and worker+queue system for ocr_translate."""
import logging
import queue
import threading
Expand All @@ -26,23 +27,26 @@


class NotHandled():
pass
"""Dummy object to be used as default response of an unresolved message."""

class Message():
"""Message object to be used in WorkerMessageQueue. This class is used to send a message in the queue, but
also allow the sending function to wait for the response of the message."""
NotHandled = NotHandled
def __init__(
self, id_: Hashable, msg: dict, handler: Callable,
batch_args: tuple = (), batch_kwargs: Iterable = ()
):
"""Message object to be used in WorkerMessageQueue.
Params:
id (Hashable): Message id. Used to identify messages with the same id.
Args:
id_ (Hashable): Message id. Used to identify messages with the same id.
msg (dict): Message to be passed to the handler.
handler (Callable): Handler function to be called with the message.
batch_args (tuple, optional): Indexes of the args to be batched. Defaults to ().
batch_kwargs (Iterable, optional): Keys of the kwargs to be batched. Defaults to ().
"""
self.id = id_
self.id_ = id_
self.msg = msg
self.handler = handler
self.batch_args = batch_args
Expand All @@ -61,20 +65,21 @@ def resolve(self):

def batch_resolve(self, others: Iterable['Message']):
"""Resolve multiple messages with one call to the handler.
The handler must be able to handle the specified batched args and kwargs, as either the expected type or a list of the expected type.
The handler must be able to handle the specified batched args and kwargs,
as either the expected type or a list of the expected type.
The handler must return a list of the same length as the number of messages to be resolved, with the same order.
"""
logger.debug(f'MSG Batch Resolving {self.msg} with {len(others)} other messages')
# Check if these checks are necessary (maybe just let the handler fail)
# Main problem would be running messages with different non batched args that produce wrong results
# But having all these checks might slow down the batching too much
if any([_.handler != self.handler for _ in others]):
if any(_.handler != self.handler for _ in others):
raise ValueError('All messages must have the same handler')
check = len(self.msg['args'])
if any([len(_.msg['args']) != check for _ in others]):
if any(len(_.msg['args']) != check for _ in others):
raise ValueError('All messages must have the same number of args')
check = set(self.msg['kwargs'].keys())
if any([set(_.msg['kwargs'].keys()) != check for _ in others]):
if any(set(_.msg['kwargs'].keys()) != check for _ in others):
raise ValueError('All messages must have the same kwargs')
# Should also check that the non-batched args and kwargs are the same for all messages

Expand All @@ -90,21 +95,39 @@ def batch_resolve(self, others: Iterable['Message']):
respones = self.handler(*args, **kwargs)
for msg, r in zip([self, *others], respones):
logger.debug(f'MSG Batch Resolved {msg.msg} -> {r}')
msg._response = r
msg.set_response(r)

# Make sure to dereference the message to avoid keeping raw images in memory
# since I am going to keep the message in the queue after it is resolved (for msg caching)
del msg.msg

@property
def is_resolved(self):
def is_resolved(self) -> bool:
"""Whether the message has been resolved or not."""
return self._response is not NotHandled

def set_response(self, response):
"""Set the response of the message."""
self._response = response

def response(self, timeout: float = 0, poll: float = 0.2):
"""Get the response of the message.
Args:
timeout (float, optional): Timeout in seconds to wait for the message to be resolved.
Defaults to 0 (no timeout).
poll (float, optional): Polling interval in seconds. Defaults to 0.2.
Raises:
TimeoutError: If the message is not resolved after the timeout.
Returns:
Any: The response of the message (return value of the handler called on the msg content).
"""
start = time.time()

while not self.is_resolved:
if timeout > 0 and time.time() - start > timeout:
if time.time() - start > timeout > 0:
raise TimeoutError('Message resolution timed out')
time.sleep(poll)

Expand All @@ -114,17 +137,19 @@ def __str__(self):
return f'Message({self.msg}), Handler: {self.handler.__name__}'

class Worker():
def __init__(self, q: queue.SimpleQueue[Message]):
self.q = q
"""Worker object to be used in WorkerMessageQueue."""
def __init__(self, attached_queue: queue.SimpleQueue[Message]):
self.queue = attached_queue
self.kill = False
self.running = False
self.thread = None

def _worker(self):
"""Worker function that consumes messages from the queue and resolves them."""
self.running = True
while not self.kill:
try:
msg = self.q.get(timeout=1)
msg = self.queue.get(timeout=1)
except queue.Empty:
continue
logger.debug(f'Worker consuming {msg}')
Expand All @@ -142,14 +167,20 @@ def _worker(self):
self.running = False

def start(self):
"""Start the worker thread."""
self.thread = threading.Thread(target=self._worker, daemon=True)
self.thread.start()

def stop(self):
"""Stop the worker thread."""
self.kill = True
self.thread.join()

class WorkerMessageQueue(queue.SimpleQueue):
"""Message queue with worker threads to resolve messages. This class extends queue.SimpleQueue, by adding:
- Message caching/reuse (When a new message with the same id is put in the queue, the old one is returned)
- Message batching (Messages with the same batch_id are grouped together and resolved with one handler call)
"""
def __init__(
self,
*args,
Expand All @@ -161,15 +192,16 @@ def __init__(
batch_args: tuple = (), batch_kwargs: Iterable = (),
**kwargs
):
"""
"""Create a new WorkerMessageQueue.
Args:
num_workers (int, optional): Number of workers to spawn. Defaults to 1.
reuse_msg (bool, optional): Whether to reuse messages with the same id. Defaults to True.
max_len (int, optional): Max number of messages in queue before starting to remove solved messages from cache.
Defaults to 0 (no limit).
max_len (int, optional): Max number of messages in queue before starting to remove solved messages
from cache. Defaults to 0 (no limit).
allow_batching (bool, optional): Whether to allow batching of messages. Defaults to False.
Batching is done by grouping messages with the same args, kwargs and handler (excluding the arguments to be batched).
batch_timeout (float, optional): Timeout for batching. Defaults to 0.5.
batch_timeout (float, optional): Timeout for batching. When get is called, wait `batch_timeout` seconds for other
incoming messages. Defaults to 0.5.
batch_args (tuple, optional): Indexes of the args to be batched. Defaults to ().
batch_kwargs (Iterable, optional): Keys of the kwargs to be batched. Defaults to ().
"""
Expand All @@ -187,12 +219,26 @@ def __init__(
self.workers = [Worker(self) for _ in range(num_workers)]

def put(self, id_: Hashable, msg: dict, handler: Callable, batch_id: Hashable = None) -> Message:
"""Put a new message in the queue.
Args:
id_ (Hashable): Id of the message. Used to identify messages with the same id.
msg (dict): Message to be passed to the handler.
handler (Callable): Handler function to be called with the message.
batch_id (Hashable, optional): Id of the batch to which the message belongs. Defaults to None.
Raises:
NotImplementedError: If the max_len is reached.
Returns:
Message: The message object.
"""
if self.reuse_msg and id_ in self.registered:
logger.debug(f'Reusing message {id_}')
return self.registered[id_]
if self.max_len > 0 and self.qsize() > self.max_len:
# TODO: Remove solved messages from cache
# Only 1by1 or all?
# Remove solved messages from cache
# Only 1by1 or all?
raise NotImplementedError('Max len reached')

res = Message(id_, msg, handler, batch_args=self.batch_args, batch_kwargs=self.batch_kwargs)
Expand All @@ -208,36 +254,51 @@ def put(self, id_: Hashable, msg: dict, handler: Callable, batch_id: Hashable =
return res

def get(self, *args, **kwargs) -> Union[Message, list[Message]]:
"""Get a message or list of messages from the queue.
Returns:
Union[Message, list[Message]]: A message or a list of messages, depending on whether batching is enabled.
"""
msg = super().get(*args, **kwargs)
while msg.id in self.batch_resolve_flagged:
self.batch_resolve_flagged.remove(msg.id)
while msg.id_ in self.batch_resolve_flagged:
self.batch_resolve_flagged.remove(msg.id_)
msg = super().get(*args, **kwargs)

if self.allow_batching and msg.id in self.msg_to_batch_pool:
if self.allow_batching and msg.id_ in self.msg_to_batch_pool:
# Wait for more messages to come
logger.debug(f'Batching message {msg.id}')
logger.debug(f'Batching message {msg.id_}')
time.sleep(self.batch_timeout)

logger.debug(f'Batching message {msg.id} done')
pool_id = self.msg_to_batch_pool[msg.id]
logger.debug(f'Batching message {msg.id} pool id {pool_id}')
logger.debug(f'Batching message {msg.id_} done')
pool_id = self.msg_to_batch_pool[msg.id_]
logger.debug(f'Batching message {msg.id_} pool id {pool_id}')
pool = self.batch_pools.pop(pool_id)
logger.debug(f'Batching message {msg.id} pool {pool}')
logger.debug(f'Batching message {msg.id_} pool {pool}')
for msg in pool:
self.msg_to_batch_pool.pop(msg.id)
self.batch_resolve_flagged.append(msg.id)
self.msg_to_batch_pool.pop(msg.id_)
self.batch_resolve_flagged.append(msg.id_)
if len(pool) > 1:
return pool

return msg

def get_msg(self, msg_id: str):
"""Get a message from the cache. If the message is not in the cache, return None.
Args:
msg_id (str): Id of the message.
Returns:
Message | None: The message object, or None if the id is not in the cache.
"""
return self.registered.get(msg_id, None)

def start_workers(self):
for w in self.workers:
w.start()
"""Start all the worker threads registered to this queue."""
for worker in self.workers:
worker.start()

def stop_workers(self):
for w in self.workers:
w.stop()
"""Stop all the worker threads registered to this queue."""
for worker in self.workers:
worker.stop()
4 changes: 3 additions & 1 deletion ocr_translate/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# #
# Home: https://github.com/Crivella/ocr_translate #
###################################################################################
"""Django models for the ocr_translate app."""
from django.db import models

lang_length = 32
LANG_LENGTH = 32

class OptionDict(models.Model):
"""Dictionary of options for OCR and translation"""
Expand Down Expand Up @@ -92,6 +93,7 @@ class BBox(models.Model):

@property
def lbrt(self):
"""Return the bounding box as a tuple of (left, bottom, right, top)"""
return self.l, self.b, self.r, self.t

def __str__(self):
Expand Down
Loading

0 comments on commit dc94f2b

Please sign in to comment.