# Copyright 2016-2021 The NATS Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
import asyncio
from typing import (
TYPE_CHECKING,
AsyncIterator,
Awaitable,
Callable,
List,
Optional,
)
from uuid import uuid4
from nats import errors
# Msg is the message type delivered to subscriptions (the default pending limits of subscriptions are defined after the imports).
from nats.aio.msg import Msg
if TYPE_CHECKING:
from nats.js import JetStreamContext
# Default pending limits of subscriptions.
# Maximum number of buffered messages per subscription (524,288).
DEFAULT_SUB_PENDING_MSGS_LIMIT = 512 << 10
# Maximum number of buffered payload bytes per subscription (128 MiB).
DEFAULT_SUB_PENDING_BYTES_LIMIT = 128 << 20
class Subscription:
    """
    A Subscription represents interest in a particular subject.

    A Subscription should not be constructed directly, rather
    `connection.subscribe()` should be used to get a subscription.

    ::

        nc = await nats.connect()

        # Async Subscription
        async def cb(msg):
            print('Received', msg)
        await nc.subscribe('foo', cb=cb)

        # Sync Subscription
        sub = await nc.subscribe('foo')
        msg = await sub.next_msg()
        print('Received', msg)

    """

    def __init__(
        self,
        conn,
        id: int = 0,
        subject: str = '',
        queue: str = '',
        cb: Optional[Callable[[Msg], Awaitable[None]]] = None,
        future: Optional[asyncio.Future] = None,
        max_msgs: int = 0,
        pending_msgs_limit: int = DEFAULT_SUB_PENDING_MSGS_LIMIT,
        pending_bytes_limit: int = DEFAULT_SUB_PENDING_BYTES_LIMIT,
    ) -> None:
        self._conn = conn
        self._id = id
        self._subject = subject
        self._queue = queue
        # Auto-unsubscribe limit; 0 means unlimited.
        self._max_msgs = max_msgs
        # Number of messages delivered so far. NOTE(review): never incremented
        # in this class — presumably updated by the connection on delivery.
        self._received = 0
        self._cb = cb
        self._future = future
        self._closed = False

        # Per subscription message buffer. The queue's maxsize bounds the
        # number of buffered messages; the byte limit is only stored here
        # (it is not enforced anywhere in this class).
        self._pending_msgs_limit = pending_msgs_limit
        self._pending_bytes_limit = pending_bytes_limit
        self._pending_queue: asyncio.Queue[Msg] = asyncio.Queue(
            maxsize=pending_msgs_limit
        )
        # If no callback, then this is a sync subscription which will
        # require tracking the next_msg calls inflight for cancelling.
        if cb is None:
            self._pending_next_msgs_calls = {}
        else:
            self._pending_next_msgs_calls = None
        # Sum of len(msg.data) over buffered messages. Decremented here as
        # messages are consumed; presumably incremented by the connection.
        self._pending_size = 0
        self._wait_for_msgs_task = None
        self._message_iterator = None

        # For JetStream enabled subscriptions.
        self._jsi: Optional[JetStreamContext._JSI] = None

    @property
    def subject(self) -> str:
        """
        Returns the subject of the `Subscription`.
        """
        return self._subject

    @property
    def queue(self) -> str:
        """
        Returns the queue name of the `Subscription` if part of a queue group.
        """
        return self._queue

    @property
    def messages(self) -> AsyncIterator[Msg]:
        """
        Retrieves an async iterator for the messages from the subscription.

        This is only available if a callback isn't provided when creating a
        subscription.

        ::

            nc = await nats.connect()
            sub = await nc.subscribe('foo')
            # Use `async for` which implicitly awaits messages
            async for msg in sub.messages:
                print('Received', msg)

        :raises nats.errors.Error: if the subscription has no message
            iterator (e.g. it was created with a callback).
        """
        if not self._message_iterator:
            raise errors.Error(
                "cannot iterate over messages with a non iteration subscription type"
            )

        return self._message_iterator

    @property
    def pending_msgs(self) -> int:
        """
        Number of delivered messages by the NATS Server that are being buffered
        in the pending queue.
        """
        return self._pending_queue.qsize()

    @property
    def pending_bytes(self) -> int:
        """
        Size of data sent by the NATS Server that is being buffered
        in the pending queue.
        """
        return self._pending_size

    @property
    def delivered(self) -> int:
        """
        Number of delivered messages to this subscription so far.
        """
        return self._received

    async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg:
        """
        Wait for and return the next buffered message.

        next_msg can be used to retrieve the next message from a stream of
        messages using await syntax; this only works when not passing a
        callback on `subscribe`::

            sub = await nc.subscribe('hello')
            msg = await sub.next_msg(timeout=1)

        :param timeout: Time in seconds to wait for next message before timing
            out (passed to ``asyncio.wait_for``).
        :raises nats.errors.ConnectionClosedError: if the connection is, or
            becomes, closed while waiting.
        :raises nats.errors.Error: if the subscription has a callback.
        :raises nats.errors.TimeoutError: if no message arrives in time.
        """
        if self._conn.is_closed:
            raise errors.ConnectionClosedError

        if self._cb:
            raise errors.Error(
                'nats: next_msg cannot be used in async subscriptions'
            )

        task_name = str(uuid4())
        try:
            # Run the wait as a task and register it so that in-flight
            # next_msg calls can be cancelled.
            future = asyncio.create_task(
                asyncio.wait_for(self._pending_queue.get(), timeout)
            )
            self._pending_next_msgs_calls[task_name] = future
            msg = await future
        except asyncio.TimeoutError:
            if self._conn.is_closed:
                raise errors.ConnectionClosedError
            raise errors.TimeoutError
        except asyncio.CancelledError:
            if self._conn.is_closed:
                raise errors.ConnectionClosedError
            raise
        else:
            self._pending_size -= len(msg.data)
            # For sync subscriptions we will consider a message
            # to be done once it has been consumed by the client
            # regardless of whether it has been processed.
            self._pending_queue.task_done()
            return msg
        finally:
            self._pending_next_msgs_calls.pop(task_name, None)

    def _start(self, error_cb):
        """
        Creates the resources for the subscription to start processing messages.

        - With a callback: validates it is a coroutine function (or a wrapper
          exposing a coroutine ``func``) and spawns ``_wait_for_msgs``.
        - With a future (single request response): nothing to start.
        - Otherwise: creates the async message iterator.

        :raises nats.errors.Error: if the callback is not a coroutine.
        """
        if self._cb:
            if not asyncio.iscoroutinefunction(self._cb) and \
                not (hasattr(self._cb, "func") and asyncio.iscoroutinefunction(self._cb.func)):
                raise errors.Error(
                    "nats: must use coroutine for subscriptions"
                )

            self._wait_for_msgs_task = asyncio.get_running_loop().create_task(
                self._wait_for_msgs(error_cb)
            )

        elif self._future:
            # Used to handle the single response from a request.
            pass
        else:
            self._message_iterator = _SubscriptionMessageIterator(self)

    async def drain(self):
        """
        Removes interest in a subject, but will process remaining messages.

        :raises nats.errors.ConnectionClosedError: if the connection is closed.
        :raises nats.errors.ConnectionDrainingError: if the connection is draining.
        :raises nats.errors.BadSubscriptionError: if already closed.
        """
        if self._conn.is_closed:
            raise errors.ConnectionClosedError
        if self._conn.is_draining:
            raise errors.ConnectionDrainingError
        if self._closed:
            raise errors.BadSubscriptionError
        await self._drain()

    async def _drain(self) -> None:
        """
        Unsubscribe, wait for every buffered message to be marked done, then
        stop processing and remove the subscription. The subscription is
        marked closed even if this is interrupted or fails.
        """
        try:
            # Announce server that no longer want to receive more
            # messages in this sub and just process the ones remaining.
            await self._conn._send_unsubscribe(self._id)

            # Roundtrip to ensure that the server has sent all messages.
            await self._conn.flush()

            # Wait until no more messages are left (every get() has been
            # matched by task_done()), then cancel the subscription task.
            await self._pending_queue.join()

            # stop waiting for messages
            self._stop_processing()

            # Subscription is done and won't be receiving further
            # messages so can throw it away now.
            self._conn._remove_sub(self._id)
        finally:
            self._closed = True

    async def unsubscribe(self, limit: int = 0):
        """
        Removes interest in a subject, remaining messages will be discarded.

        If `limit` is greater than zero, interest is not immediately removed,
        rather, interest will be automatically removed after `limit` messages
        are received.

        :param limit: Max number of messages to receive before unsubscribing.
        :raises nats.errors.ConnectionClosedError: if the connection is closed.
        :raises nats.errors.ConnectionDrainingError: if the connection is draining.
        :raises nats.errors.BadSubscriptionError: if already closed.
        """
        if self._conn.is_closed:
            raise errors.ConnectionClosedError
        if self._conn.is_draining:
            raise errors.ConnectionDrainingError
        if self._closed:
            raise errors.BadSubscriptionError

        self._max_msgs = limit
        # Close locally right away if unsubscribing now, or if the limit has
        # already been reached and nothing is left to consume.
        if limit == 0 or (self._received >= limit
                          and self._pending_queue.empty()):
            self._closed = True
            self._stop_processing()
            self._conn._remove_sub(self._id)

        # The server is told about the unsubscribe (with the limit) unless
        # the connection is reconnecting.
        if not self._conn.is_reconnecting:
            await self._conn._send_unsubscribe(self._id, limit=limit)

    def _stop_processing(self) -> None:
        """
        Stops the subscription from processing new messages.

        Cancels the callback task (if still running) and ends the message
        iterator (if any).
        """
        if self._wait_for_msgs_task and not self._wait_for_msgs_task.done():
            self._wait_for_msgs_task.cancel()
        if self._message_iterator:
            self._message_iterator._cancel()

    async def _wait_for_msgs(self, error_cb) -> None:
        """
        A coroutine to read and process messages if a callback is provided.

        Should be called as a task. Exceptions raised by the callback are
        passed to ``error_cb`` (if given); cancellation ends the loop.
        """
        assert self._cb, "_wait_for_msgs can be called only from _start"
        while True:
            try:
                msg = await self._pending_queue.get()
                self._pending_size -= len(msg.data)

                try:
                    # Invoke depending of type of handler.
                    await self._cb(msg)
                except asyncio.CancelledError:
                    # In case the coroutine handler gets cancelled
                    # then stop task loop and return.
                    break
                except Exception as e:
                    # All errors from calling a handler
                    # are async errors.
                    if error_cb:
                        await error_cb(e)
                finally:
                    # indicate the message finished processing so drain can continue.
                    self._pending_queue.task_done()

                # Apply auto unsubscribe checks after having processed last msg.
                # empty() must be *called*: the bound method itself is always
                # truthy, which would stop processing while messages received
                # within the limit are still buffered.
                if (
                    self._max_msgs > 0
                    and self._received >= self._max_msgs
                    and self._pending_queue.empty()
                ):
                    self._stop_processing()
            except asyncio.CancelledError:
                break
class _SubscriptionMessageIterator:
    """
    Async iterator over a subscription's pending message queue.

    Each ``__anext__`` races a ``queue.get()`` against an internal
    "unsubscribed" future. Iteration ends with ``StopAsyncIteration`` once
    ``_cancel()`` has resolved that future. ``_cancel()`` is called by the
    iterator itself after the subscription's ``max_msgs`` has been reached,
    and may also be called by its owner.
    """

    def __init__(self, sub: Subscription) -> None:
        self._sub: Subscription = sub
        self._queue: asyncio.Queue[Msg] = sub._pending_queue
        # Resolved by _cancel(); awaited alongside queue.get() in __anext__.
        self._unsubscribed_future: asyncio.Future[bool] = asyncio.Future()

    def _cancel(self) -> None:
        """Signal the iterator to stop. Safe to call more than once."""
        if not self._unsubscribed_future.done():
            self._unsubscribed_future.set_result(True)

    def __aiter__(self) -> _SubscriptionMessageIterator:
        return self

    async def __anext__(self) -> Msg:
        """
        Return the next buffered message.

        :raises StopAsyncIteration: once the iterator has been cancelled.
        """
        get_task = asyncio.get_running_loop().create_task(self._queue.get())
        tasks: List[asyncio.Future] = [get_task, self._unsubscribed_future]
        try:
            finished, _ = await asyncio.wait(
                tasks, return_when=asyncio.FIRST_COMPLETED
            )
        except asyncio.CancelledError:
            # asyncio.wait() does not cancel the awaitables it waits on. If
            # the caller cancels us (e.g. via asyncio.wait_for), an orphaned
            # get() would later dequeue and silently drop the next message.
            get_task.cancel()
            raise
        sub = self._sub

        if get_task in finished:
            self._queue.task_done()
            msg = get_task.result()
            sub._pending_size -= len(msg.data)

            # Unblock the iterator in case it has already received enough messages.
            if sub._max_msgs > 0 and sub._received >= sub._max_msgs:
                self._cancel()
            return msg
        elif self._unsubscribed_future.done():
            get_task.cancel()

        raise StopAsyncIteration