Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add max_attempts_at_message #395

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/examples/extending/broker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import AsyncGenerator, Union

from taskiq import AckableMessage, AsyncBroker, BrokerMessage
from taskiq import WrappedMessage, AsyncBroker, BrokerMessage


class MyBroker(AsyncBroker):
Expand All @@ -23,7 +23,7 @@ async def kick(self, message: BrokerMessage) -> None:
# Send a message.message.
pass

async def listen(self) -> AsyncGenerator[Union[bytes, AckableMessage], None]:
async def listen(self) -> AsyncGenerator[Union[bytes, WrappedMessage], None]:
while True:
# Get new message.
new_message: bytes = ... # type: ignore
Expand Down
6 changes: 3 additions & 3 deletions taskiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.abc.result_backend import AsyncResultBackend
from taskiq.abc.schedule_source import ScheduleSource
from taskiq.acks import AckableMessage
from taskiq.brokers.inmemory_broker import InMemoryBroker
from taskiq.brokers.shared_broker import async_shared_broker
from taskiq.brokers.zmq_broker import ZeroMQBroker
Expand All @@ -24,7 +23,7 @@
TaskiqResultTimeoutError,
)
from taskiq.funcs import gather
from taskiq.message import BrokerMessage, TaskiqMessage
from taskiq.message import BrokerMessage, MessageMetadata, TaskiqMessage, WrappedMessage
from taskiq.middlewares.prometheus_middleware import PrometheusMiddleware
from taskiq.middlewares.retry_middleware import SimpleRetryMiddleware
from taskiq.result import TaskiqResult
Expand All @@ -35,14 +34,14 @@

__version__ = version("taskiq")
__all__ = [
"AckableMessage",
"AsyncBroker",
"AsyncResultBackend",
"AsyncTaskiqDecoratedTask",
"AsyncTaskiqTask",
"BrokerMessage",
"Context",
"InMemoryBroker",
"MessageMetadata",
"NoResultError",
"PrometheusMiddleware",
"ResultGetError",
Expand All @@ -62,6 +61,7 @@
"TaskiqResultTimeoutError",
"TaskiqScheduler",
"TaskiqState",
"WrappedMessage",
"ZeroMQBroker",
"__version__",
"async_shared_broker",
Expand Down
17 changes: 10 additions & 7 deletions taskiq/abc/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@

from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.acks import AckableMessage
from taskiq.decor import AsyncTaskiqDecoratedTask
from taskiq.events import TaskiqEvents
from taskiq.formatters.proxy_formatter import ProxyFormatter
from taskiq.message import BrokerMessage
from taskiq.message import BrokerMessage, WrappedMessage
from taskiq.result_backends.dummy import DummyResultBackend
from taskiq.serializers.json_serializer import JSONSerializer
from taskiq.state import TaskiqState
Expand Down Expand Up @@ -77,6 +76,7 @@ def __init__(
self,
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
task_id_generator: Optional[Callable[[], str]] = None,
max_attempts_at_message: Optional[int] = None,
) -> None:
if result_backend is None:
result_backend = DummyResultBackend()
Expand Down Expand Up @@ -113,6 +113,7 @@ def __init__(
self.state = TaskiqState()
self.custom_dependency_context: Dict[Any, Any] = {}
self.dependency_overrides: Dict[Any, Any] = {}
self.max_attempts_at_message = max_attempts_at_message
# True only if broker runs in worker process.
self.is_worker_process = False
# True only if broker runs in scheduler process.
Expand Down Expand Up @@ -237,18 +238,20 @@ async def kick(
"""

@abstractmethod
def listen(self) -> AsyncGenerator[Union[bytes, AckableMessage], None]:
def listen(self) -> AsyncGenerator[Union[bytes, WrappedMessage], None]:
"""
This function listens to new messages and yields them.

This it the main point for workers.
This function is used to get new tasks from the network.

If your broker support acknowledgement, then you
should wrap your message in AckableMessage dataclass.
If your broker support acknowledgements (or negative acknowledgements),
then the returned message should implement the AckableMessage
(or NackableMessage) interface by implementing the `ack` (or
`nack`) callback.

If your messages was wrapped in AckableMessage dataclass,
taskiq will call ack when finish processing message.
If your message has an `ack` callbacks it will be called after the
message is processed.

:yield: incoming messages.
:return: nothing.
Expand Down
12 changes: 11 additions & 1 deletion taskiq/acks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,15 @@ class AckableMessage(BaseModel):
as a whole.
"""

data: bytes
ack: Callable[[], Union[None, Awaitable[None]]]


class NackableMessage(BaseModel):
"""
Message that can be negatively acknowledged.

Message that can be negatively acknowledged, e.g.
sent to a dead-letter queue, etc.
"""

nack: Callable[[], Union[None, Awaitable[None]]]
1 change: 1 addition & 0 deletions taskiq/cli/worker/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
ack_type=args.ack_type,
max_tasks_to_execute=args.max_tasks_per_child,
wait_tasks_timeout=args.wait_tasks_timeout,
max_attempts_at_message=broker.max_attempts_at_message,
**receiver_kwargs, # type: ignore
)
loop.run_until_complete(receiver.listen(shutdown_event))
Expand Down
40 changes: 40 additions & 0 deletions taskiq/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pydantic import BaseModel

from taskiq.acks import AckableMessage, NackableMessage
from taskiq.labels import parse_label


Expand Down Expand Up @@ -42,3 +43,42 @@ class BrokerMessage(BaseModel):
task_name: str
message: bytes
labels: Dict[str, Any]


class MessageMetadata(BaseModel):
"""Incoming message metadata."""

delivery_count: Optional[int] = None


class WrappedMessage(BaseModel): # noqa: D101
message: bytes


class MessageWithMetadata(BaseModel): # noqa: D101
metadata: MessageMetadata


class WrappedMessageWithMetadata(WrappedMessage, MessageWithMetadata): # noqa: D101
...


class AckableWrappedMessage(WrappedMessage, AckableMessage): # noqa: D101
...


class AckableWrappedMessageWithMetadata( # noqa: D101
WrappedMessage,
AckableMessage,
MessageWithMetadata,
):
...


class AckableNackableWrappedMessageWithMetadata( # noqa: D101
WrappedMessage,
AckableMessage,
NackableMessage,
MessageWithMetadata,
):
...
46 changes: 37 additions & 9 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import anyio
from taskiq_dependencies import DependencyGraph

from taskiq.abc.broker import AckableMessage, AsyncBroker
from taskiq.abc.broker import AsyncBroker
from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.acks import AcknowledgeType
from taskiq.acks import AckableMessage, AcknowledgeType, NackableMessage
from taskiq.context import Context
from taskiq.exceptions import NoResultError
from taskiq.message import TaskiqMessage
from taskiq.message import MessageWithMetadata, TaskiqMessage, WrappedMessage
from taskiq.receiver.params_parser import parse_params
from taskiq.result import TaskiqResult
from taskiq.state import TaskiqState
Expand Down Expand Up @@ -58,6 +58,7 @@ def __init__(
on_exit: Optional[Callable[["Receiver"], None]] = None,
max_tasks_to_execute: Optional[int] = None,
wait_tasks_timeout: Optional[float] = None,
max_attempts_at_message: Optional[int] = None,
) -> None:
self.broker = broker
self.executor = executor
Expand All @@ -72,6 +73,7 @@ def __init__(
self.known_tasks: Set[str] = set()
self.max_tasks_to_execute = max_tasks_to_execute
self.wait_tasks_timeout = wait_tasks_timeout
self.max_attempts_at_message = max_attempts_at_message
for task in self.broker.get_all_tasks().values():
self._prepare_task(task.task_name, task.original_func)
self.sem: "Optional[asyncio.Semaphore]" = None
Expand All @@ -86,7 +88,7 @@ def __init__(

async def callback( # noqa: C901, PLR0912
self,
message: Union[bytes, AckableMessage],
message: Union[bytes, WrappedMessage],
raise_err: bool = False,
) -> None:
"""
Expand All @@ -101,7 +103,33 @@ async def callback( # noqa: C901, PLR0912
:param raise_err: raise an error if cannot save result in
result_backend.
"""
message_data = message.data if isinstance(message, AckableMessage) else message
message_data = (
message.message if isinstance(message, WrappedMessage) else message
)
if isinstance(message, MessageWithMetadata):
message_metadata = message.metadata
else:
message_metadata = None

delivery_count = message_metadata.delivery_count if message_metadata else None
if (
delivery_count
and self.max_attempts_at_message
and delivery_count >= self.max_attempts_at_message
):
logger.error(
"Permitted number of attempts at processing message %s "
"has been exhausted after %s attempts.",
message_data,
self.max_attempts_at_message,
)
match message:
case NackableMessage():
await maybe_awaitable(message.nack())
case AckableMessage():
await maybe_awaitable(message.ack())
return

try:
taskiq_msg = self.broker.formatter.loads(message=message_data)
taskiq_msg.parse_labels()
Expand Down Expand Up @@ -331,7 +359,7 @@ async def listen(self, finish_event: asyncio.Event) -> None: # pragma: no cover
if self.run_startup:
await self.broker.startup()
logger.info("Listening started.")
queue: "asyncio.Queue[Union[bytes, AckableMessage]]" = asyncio.Queue()
queue: "asyncio.Queue[Union[bytes, WrappedMessage]]" = asyncio.Queue()

async with anyio.create_task_group() as gr:
gr.start_soon(self.prefetcher, queue, finish_event)
Expand All @@ -342,7 +370,7 @@ async def listen(self, finish_event: asyncio.Event) -> None: # pragma: no cover

async def prefetcher(
self,
queue: "asyncio.Queue[Union[bytes, AckableMessage]]",
queue: "asyncio.Queue[Union[bytes, WrappedMessage]]",
finish_event: asyncio.Event,
) -> None:
"""
Expand All @@ -354,7 +382,7 @@ async def prefetcher(
fetched_tasks: int = 0
iterator = self.broker.listen()
current_message: asyncio.Task[
Union[bytes, AckableMessage]
Union[bytes, WrappedMessage]
] = asyncio.create_task(
iterator.__anext__(), # type: ignore
)
Expand Down Expand Up @@ -394,7 +422,7 @@ async def prefetcher(

async def runner(
self,
queue: "asyncio.Queue[Union[bytes, AckableMessage]]",
queue: "asyncio.Queue[Union[bytes, WrappedMessage]]",
) -> None:
"""
Run tasks.
Expand Down
Loading