From 4da1969c198276fd6af0e768128796ef634b7fac Mon Sep 17 00:00:00 2001 From: Takuma Iwaki Date: Thu, 29 May 2025 03:06:25 +0900 Subject: [PATCH 1/5] FEAT: type hint --- channels_redis/core.py | 102 ++++++++++++++++++++-------------- channels_redis/pubsub.py | 44 +++++++++------ channels_redis/py.typed | 1 + channels_redis/serializers.py | 77 +++++++++++++++---------- channels_redis/utils.py | 36 ++++++++---- conftest.py | 10 ++++ tests/test_core.py | 7 --- 7 files changed, 171 insertions(+), 106 deletions(-) create mode 100644 channels_redis/py.typed create mode 100644 conftest.py diff --git a/channels_redis/core.py b/channels_redis/core.py index a230081..6862d13 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -19,6 +19,13 @@ create_pool, decode_hosts, ) +from typing import TYPE_CHECKING, Dict, List, Tuple, Union, Optional + +if TYPE_CHECKING: + from redis.asyncio.connection import ConnectionPool + from redis.asyncio.client import Redis + from .core import RedisChannelLayer + from typing_extensions import Buffer logger = logging.getLogger(__name__) @@ -32,23 +39,27 @@ class ChannelLock: """ def __init__(self): - self.locks = collections.defaultdict(asyncio.Lock) - self.wait_counts = collections.defaultdict(int) + self.locks: "collections.defaultdict[str, asyncio.Lock]" = ( + collections.defaultdict(asyncio.Lock) + ) + self.wait_counts: "collections.defaultdict[str, int]" = collections.defaultdict( + int + ) - async def acquire(self, channel): + async def acquire(self, channel: str) -> bool: """ Acquire the lock for the given channel. """ self.wait_counts[channel] += 1 return await self.locks[channel].acquire() - def locked(self, channel): + def locked(self, channel: str) -> bool: """ Return ``True`` if the lock for the given channel is acquired. """ return self.locks[channel].locked() - def release(self, channel): + def release(self, channel: str): """ Release the lock for the given channel. 
""" @@ -73,12 +84,12 @@ def put_nowait(self, item): class RedisLoopLayer: - def __init__(self, channel_layer): + def __init__(self, channel_layer: "RedisChannelLayer"): self._lock = asyncio.Lock() self.channel_layer = channel_layer - self._connections = {} + self._connections: "Dict[int, Redis]" = {} - def get_connection(self, index): + def get_connection(self, index: int) -> "Redis": if index not in self._connections: pool = self.channel_layer.create_pool(index) self._connections[index] = aioredis.Redis(connection_pool=pool) @@ -134,7 +145,7 @@ def __init__( symmetric_encryption_keys=symmetric_encryption_keys, ) # Cached redis connection pools and the event loop they are from - self._layers = {} + self._layers: "Dict[asyncio.AbstractEventLoop, RedisLoopLayer]" = {} # Normal channels choose a host index by cycling through the available hosts self._receive_index_generator = itertools.cycle(range(len(self.hosts))) self._send_index_generator = itertools.cycle(range(len(self.hosts))) @@ -143,33 +154,33 @@ def __init__( # Number of coroutines trying to receive right now self.receive_count = 0 # The receive lock - self.receive_lock = None + self.receive_lock: "Optional[asyncio.Lock]" = None # Event loop they are trying to receive on - self.receive_event_loop = None + self.receive_event_loop: "Optional[asyncio.AbstractEventLoop]" = None # Buffered messages by process-local channel name - self.receive_buffer = collections.defaultdict( - functools.partial(BoundedQueue, self.capacity) + self.receive_buffer: "collections.defaultdict[str, BoundedQueue]" = ( + collections.defaultdict(functools.partial(BoundedQueue, self.capacity)) ) # Detached channel cleanup tasks - self.receive_cleaners = [] + self.receive_cleaners: "List[asyncio.Task]" = [] # Per-channel cleanup locks to prevent a receive starting and moving # a message back into the main queue before its cleanup has completed self.receive_clean_locks = ChannelLock() - def create_pool(self, index): + def create_pool(self, index: int) -> "ConnectionPool": return create_pool(self.hosts[index]) ### Channel layer API ### extensions = ["groups", "flush"] - async def send(self, channel, message): + async def send(self, channel: str, message): """ Send a message onto a (general or specific) channel. """ # Typecheck assert isinstance(message, dict), "message is not a dict" - assert self.valid_channel_name(channel), "Channel name not valid" + assert self.require_valid_channel_name(channel), "Channel name not valid" # Make sure the message does not contain reserved keys assert "__asgi_channel__" not in message # If it's a process-local channel, strip off local part and stick full name in message @@ -203,13 +214,15 @@ async def send(self, channel, message): await connection.zadd(channel_key, {self.serialize(message): time.time()}) await connection.expire(channel_key, int(self.expiry)) - def _backup_channel_name(self, channel): + def _backup_channel_name(self, channel: str) -> str: """ Construct the key used as a backup queue for the given channel. """ return channel + "$inflight" - async def _brpop_with_clean(self, index, channel, timeout): + async def _brpop_with_clean( + self, index: int, channel: str, timeout: "Union[int, float, bytes, str]" + ): """ Perform a Redis BRPOP and manage the backup processing queue. In case of cancellation, make sure the message is not lost. 
@@ -240,7 +253,7 @@ async def _brpop_with_clean(self, index, channel, timeout): return member - async def _clean_receive_backup(self, index, channel): + async def _clean_receive_backup(self, index: int, channel: str): """ Pop the oldest message off the channel backup queue. The result isn't interesting as it was already processed. @@ -248,7 +261,7 @@ async def _clean_receive_backup(self, index, channel): connection = self.connection(index) await connection.zpopmin(self._backup_channel_name(channel)) - async def receive(self, channel): + async def receive(self, channel: str): """ Receive the first message that arrives on the channel. If more than one coroutine waits on the same channel, the first waiter @@ -256,7 +269,7 @@ async def receive(self, channel): """ # Make sure the channel name is valid then get the non-local part # and thus its index - assert self.valid_channel_name(channel) + assert self.require_valid_channel_name(channel) if "!" in channel: real_channel = self.non_local_name(channel) assert real_channel.endswith( @@ -372,12 +385,14 @@ async def receive(self, channel): # Do a plain direct receive return (await self.receive_single(channel))[1] - async def receive_single(self, channel): + async def receive_single(self, channel: str) -> "Tuple": """ Receives a single message off of the channel and returns it. """ # Check channel name - assert self.valid_channel_name(channel, receive=True), "Channel name invalid" + assert self.require_valid_channel_name( + channel, receive=True + ), "Channel name invalid" # Work out the connection to use if "!" in channel: assert channel.endswith("!") @@ -408,7 +423,7 @@ async def receive_single(self, channel): ) self.receive_cleaners.append(cleaner) - def _cleanup_done(cleaner): + def _cleanup_done(cleaner: "asyncio.Task"): self.receive_cleaners.remove(cleaner) self.receive_clean_locks.release(channel_key) @@ -427,7 +442,7 @@ def _cleanup_done(cleaner): del message["__asgi_channel__"] return channel, message - async def new_channel(self, prefix="specific"): + async def new_channel(self, prefix: str = "specific") -> str: """ Returns a new channel name that can be used by something in our process as a specific channel. @@ -477,13 +492,13 @@ async def wait_received(self): ### Groups extension ### - async def group_add(self, group, channel): + async def group_add(self, group: str, channel: str): """ Adds the channel name to a group. 
""" # Check the inputs - assert self.valid_group_name(group), "Group name not valid" - assert self.valid_channel_name(channel), "Channel name not valid" + assert self.require_valid_group_name(group), True + assert self.require_valid_channel_name(channel), True # Get a connection to the right shard group_key = self._group_key(group) connection = self.connection(self.consistent_hash(group)) @@ -493,22 +508,22 @@ async def group_add(self, group, channel): # it at this point is guaranteed to expire before that await connection.expire(group_key, self.group_expiry) - async def group_discard(self, group, channel): + async def group_discard(self, group: str, channel: str): """ Removes the channel from the named group if it is in the group; does nothing otherwise (does not error) """ - assert self.valid_group_name(group), "Group name not valid" - assert self.valid_channel_name(channel), "Channel name not valid" + assert self.require_valid_group_name(group), "Group name not valid" + assert self.require_valid_channel_name(channel), "Channel name not valid" key = self._group_key(group) connection = self.connection(self.consistent_hash(group)) await connection.zrem(key, channel) - async def group_send(self, group, message): + async def group_send(self, group: str, message): """ Sends a message to the entire group. """ - assert self.valid_group_name(group), "Group name not valid" + assert self.require_valid_group_name(group), "Group name not valid" # Retrieve list of all channel names key = self._group_key(group) connection = self.connection(self.consistent_hash(group)) @@ -573,7 +588,12 @@ async def group_send(self, group, message): channels_over_capacity = await connection.eval( group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args ) - if channels_over_capacity > 0: + _channels_over_capacity = -1 + try: + _channels_over_capacity = float(channels_over_capacity) + except Exception: + pass + if _channels_over_capacity > 0: logger.info( "%s of %s channels over capacity in group %s", channels_over_capacity, @@ -631,21 +651,19 @@ def _map_channel_keys_to_connection(self, channel_names, message): channel_key_to_capacity, ) - def _group_key(self, group): + def _group_key(self, group: str) -> bytes: """ Common function to make the storage key for the group. """ return f"{self.prefix}:group:{group}".encode("utf8") - ### Serialization ### - - def serialize(self, message): + def serialize(self, message) -> bytes: """ Serializes message to a byte string. """ return self._serializer.serialize(message) - def deserialize(self, message): + def deserialize(self, message: bytes): """ Deserializes from a byte string. """ @@ -653,7 +671,7 @@ def deserialize(self, message): ### Internal functions ### - def consistent_hash(self, value): + def consistent_hash(self, value: "Union[str, Buffer]") -> int: return _consistent_hash(value, self.ring_size) def __str__(self): @@ -661,7 +679,7 @@ def __str__(self): ### Connection handling ### - def connection(self, index): + def connection(self, index: int) -> "Redis": """ Returns the correct connection for the index given. Lazily instantiates pools. 
diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py index a80e12d..e9958ab 100644 --- a/channels_redis/pubsub.py +++ b/channels_redis/pubsub.py @@ -13,11 +13,15 @@ create_pool, decode_hosts, ) +from typing import TYPE_CHECKING, Dict, Union, Optional, Any, Iterable +if TYPE_CHECKING: + from redis.asyncio.client import Redis + from typing_extensions import Buffer logger = logging.getLogger(__name__) -async def _async_proxy(obj, name, *args, **kwargs): +async def _async_proxy(obj: "RedisPubSubChannelLayer", name: "str", *args, **kwargs): # Must be defined as a function and not a method due to # https://bugs.python.org/issue38364 layer = obj._get_layer() @@ -28,20 +32,20 @@ class RedisPubSubChannelLayer: def __init__( self, *args, - symmetric_encryption_keys=None, + symmetric_encryption_keys: "Optional[Iterable[Union[str, Buffer]]]" = None, serializer_format="msgpack", **kwargs, - ) -> None: + ): self._args = args self._kwargs = kwargs - self._layers = {} + self._layers: "Dict[asyncio.AbstractEventLoop, RedisPubSubLoopLayer]" = {} # serialization self._serializer = registry.get_serializer( serializer_format, symmetric_encryption_keys=symmetric_encryption_keys, ) - def __getattr__(self, name): + def __getattr__(self, name: str): if name in ( "new_channel", "send", @@ -55,19 +59,19 @@ def __getattr__(self, name): else: return getattr(self._get_layer(), name) - def serialize(self, message): + def serialize(self, message) -> bytes: """ Serializes message to a byte string. """ return self._serializer.serialize(message) - def deserialize(self, message): + def deserialize(self, message: bytes): """ Deserializes from a byte string. """ return self._serializer.deserialize(message) - def _get_layer(self): + def _get_layer(self) -> "RedisPubSubLoopLayer": loop = asyncio.get_running_loop() try: @@ -91,11 +95,11 @@ class RedisPubSubLoopLayer: def __init__( self, - hosts=None, + hosts: "Union[Iterable, str, bytes, None]" = None, prefix="asgi", on_disconnect=None, on_reconnect=None, - channel_layer=None, + channel_layer: "Optional[RedisPubSubChannelLayer]" = None, **kwargs, ): self.prefix = prefix @@ -106,7 +110,7 @@ def __init__( # Each consumer gets its own *specific* channel, created with the `new_channel()` method. # This dict maps `channel_name` to a queue of messages for that channel. - self.channels = {} + self.channels: "Dict[Any, asyncio.Queue]" = {} # A channel can subscribe to zero or more groups. # This dict maps `group_name` to set of channel names who are subscribed to that group. @@ -117,13 +121,15 @@ def __init__( RedisSingleShardConnection(host, self) for host in decode_hosts(hosts) ] - def _get_shard(self, channel_or_group_name): + def _get_shard( + self, channel_or_group_name: "Union[str, Buffer]" + ) -> "RedisSingleShardConnection": """ Return the shard that is used exclusively for this channel or group. """ return self._shards[_consistent_hash(channel_or_group_name, len(self._shards))] - def _get_group_channel_name(self, group): + def _get_group_channel_name(self, group) -> str: """ Return the channel name used by a group. 
Includes '__group__' in the returned @@ -259,16 +265,18 @@ async def flush(self): class RedisSingleShardConnection: - def __init__(self, host, channel_layer): + def __init__(self, host: "Dict[str, Any]", channel_layer: "RedisPubSubLoopLayer"): self.host = host self.channel_layer = channel_layer self._subscribed_to = set() self._lock = asyncio.Lock() - self._redis = None + self._redis: "Optional[Redis]" = None self._pubsub = None - self._receive_task = None + self._receive_task: "Optional[asyncio.Task]" = None - async def publish(self, channel, message): + async def publish( + self, channel: "Union[str, bytes]", message: "Union[str, bytes, int, float]" + ): async with self._lock: self._ensure_redis() await self._redis.publish(channel, message) @@ -327,7 +335,7 @@ async def _do_receiving(self): logger.exception("Unexpected exception in receive task") await asyncio.sleep(1) - def _receive_message(self, message): + def _receive_message(self, message: "Optional[Dict]"): if message is not None: name = message["channel"] data = message["data"] diff --git a/channels_redis/py.typed b/channels_redis/py.typed new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/channels_redis/py.typed @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/channels_redis/serializers.py b/channels_redis/serializers.py index b981797..8f497f5 100644 --- a/channels_redis/serializers.py +++ b/channels_redis/serializers.py @@ -3,11 +3,18 @@ import hashlib import json import random +from typing import TYPE_CHECKING, Dict, Union, Optional, Any, Iterable, Type, cast + +if TYPE_CHECKING: + from typing_extensions import Buffer try: from cryptography.fernet import Fernet, MultiFernet + + _MultiFernet = MultiFernet + _Fernet = Fernet except ImportError: - MultiFernet = Fernet = None + _MultiFernet = _Fernet = None class SerializerDoesNotExist(KeyError): @@ -17,36 +24,42 @@ class SerializerDoesNotExist(KeyError): class BaseMessageSerializer(abc.ABC): def __init__( self, - symmetric_encryption_keys=None, - random_prefix_length=0, - expiry=None, + symmetric_encryption_keys: "Optional[Iterable[Union[str, Buffer]]]" = None, + random_prefix_length: int = 0, + expiry: "Optional[int]" = None, ): self.random_prefix_length = random_prefix_length self.expiry = expiry + self.crypter: "Optional[MultiFernet]" = None # Set up any encryption objects self._setup_encryption(symmetric_encryption_keys) - def _setup_encryption(self, symmetric_encryption_keys): + def _setup_encryption( + self, + symmetric_encryption_keys: "Optional[Union[Iterable[Union[str, Buffer]], str, bytes]]", + ): # See if we can do encryption if they asked - if symmetric_encryption_keys: - if isinstance(symmetric_encryption_keys, (str, bytes)): - raise ValueError( - "symmetric_encryption_keys must be a list of possible keys" - ) - if MultiFernet is None: - raise ValueError( - "Cannot run with encryption without 'cryptography' installed." - ) - sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys] - self.crypter = MultiFernet(sub_fernets) - else: + if not symmetric_encryption_keys: self.crypter = None + return - def make_fernet(self, key): + if isinstance(symmetric_encryption_keys, (str, bytes)): + raise ValueError( + "symmetric_encryption_keys must be a list of possible keys" + ) + keys = cast("Iterable[Union[str, Buffer]]", symmetric_encryption_keys) + if _MultiFernet is None: + raise ValueError( + "Cannot run with encryption without 'cryptography' installed." 
+ ) + sub_fernets = [self.make_fernet(key) for key in keys] + self.crypter = _MultiFernet(sub_fernets) + + def make_fernet(self, key: "Union[str, Buffer]") -> "Fernet": """ Given a single encryption key, returns a Fernet instance using it. """ - if Fernet is None: + if _Fernet is None: raise ValueError( "Cannot run with encryption without 'cryptography' installed." ) @@ -54,7 +67,7 @@ def make_fernet(self, key): if isinstance(key, str): key = key.encode("utf-8") formatted_key = base64.urlsafe_b64encode(hashlib.sha256(key).digest()) - return Fernet(formatted_key) + return _Fernet(formatted_key) @abc.abstractmethod def as_bytes(self, message, *args, **kwargs): @@ -64,7 +77,7 @@ def as_bytes(self, message, *args, **kwargs): def from_bytes(self, message, *args, **kwargs): raise NotImplementedError - def serialize(self, message): + def serialize(self, message) -> bytes: """ Serializes message to a byte string. """ @@ -82,7 +95,7 @@ def serialize(self, message): ) return message - def deserialize(self, message): + def deserialize(self, message: bytes): """ Deserializes from a byte string. """ @@ -97,7 +110,7 @@ def deserialize(self, message): class MissingSerializer(BaseMessageSerializer): - exception = None + exception: "Exception" = Exception(None) def __init__(self, *args, **kwargs): raise self.exception @@ -108,7 +121,7 @@ class JSONSerializer(BaseMessageSerializer): # thus we must force bytes conversion # we use UTF-8 since it is the recommended encoding for interoperability # see https://docs.python.org/3/library/json.html#character-encodings - def as_bytes(self, message, *args, **kwargs): + def as_bytes(self, message, *args, **kwargs) -> bytes: message = json.dumps(message, *args, **kwargs) return message.encode("utf-8") @@ -116,19 +129,23 @@ def as_bytes(self, message, *args, **kwargs): # code ready for a future in which msgpack may become an optional dependency +MsgPackSerializer: "Optional[Type[BaseMessageSerializer]]" = None try: import msgpack except ImportError as exc: - class MsgPackSerializer(MissingSerializer): + class _MsgPackSerializer(MissingSerializer): exception = exc + MsgPackSerializer = _MsgPackSerializer else: - class MsgPackSerializer(BaseMessageSerializer): + class __MsgPackSerializer(BaseMessageSerializer): as_bytes = staticmethod(msgpack.packb) from_bytes = staticmethod(msgpack.unpackb) + MsgPackSerializer = __MsgPackSerializer + class SerializersRegistry: """ @@ -136,9 +153,11 @@ class SerializersRegistry: """ def __init__(self): - self._registry = {} + self._registry: "Dict[Any, Type[BaseMessageSerializer]]" = {} - def register_serializer(self, format, serializer_class): + def register_serializer( + self, format, serializer_class: "Type[BaseMessageSerializer]" + ): """ Register a new serializer for given format """ @@ -155,7 +174,7 @@ def register_serializer(self, format, serializer_class): self._registry[format] = serializer_class - def get_serializer(self, format, *args, **kwargs): + def get_serializer(self, format, *args, **kwargs) -> "BaseMessageSerializer": try: serializer_class = self._registry[format] except KeyError: diff --git a/channels_redis/utils.py b/channels_redis/utils.py index 6f15050..49712c4 100644 --- a/channels_redis/utils.py +++ b/channels_redis/utils.py @@ -1,10 +1,21 @@ import binascii import types - from redis import asyncio as aioredis +from redis.asyncio import sentinel as redis_sentinel +from typing import TYPE_CHECKING, Dict, List, Union, Any, Iterable + +if TYPE_CHECKING: + from typing_extensions import Buffer + + from 
redis.asyncio.connection import ConnectionPool + from redis.asyncio.client import Redis + from asyncio import AbstractEventLoop + + from .pubsub import RedisPubSubChannelLayer + from .core import RedisChannelLayer -def _consistent_hash(value, ring_size): +def _consistent_hash(value: "Union[str, Buffer]", ring_size: int) -> int: """ Maps the value to a node value between 0 and 4095 using CRC, then down to one of the ring nodes. @@ -20,7 +31,10 @@ def _consistent_hash(value, ring_size): return int(bigval / ring_divisor) -def _wrap_close(proxy, loop): +def _wrap_close( + proxy: "Union[RedisPubSubChannelLayer, RedisChannelLayer]", + loop: "AbstractEventLoop", +): original_impl = loop.close def _wrapper(self, *args, **kwargs): @@ -35,7 +49,7 @@ def _wrapper(self, *args, **kwargs): loop.close = types.MethodType(_wrapper, loop) -async def _close_redis(connection): +async def _close_redis(connection: "Redis"): """ Handle compatibility with redis-py 4.x and 5.x close methods """ @@ -45,7 +59,9 @@ async def _close_redis(connection): await connection.close(close_connection_pool=True) -def decode_hosts(hosts): +def decode_hosts( + hosts: "Union[Iterable, str, bytes, None]", +) -> "List[Dict]": """ Takes the value of the "hosts" argument and returns a list of kwargs to use for the Redis connection constructor. @@ -60,7 +76,7 @@ def decode_hosts(hosts): ) # Decode each hosts entry into a kwargs dict - result = [] + result: "List[Dict]" = [] for entry in hosts: if isinstance(entry, dict): result.append(entry) @@ -71,7 +87,7 @@ def decode_hosts(hosts): return result -def create_pool(host): +def create_pool(host: "Dict[str, Any]") -> "ConnectionPool": """ Takes the value of the "host" argument and returns a suited connection pool to the corresponding redis instance. @@ -86,10 +102,10 @@ def create_pool(host): if master_name is not None: sentinels = host.pop("sentinels") sentinel_kwargs = host.pop("sentinel_kwargs", None) - return aioredis.sentinel.SentinelConnectionPool( + return redis_sentinel.SentinelConnectionPool( master_name, - aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs), - **host + redis_sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs), + **host, ) return aioredis.ConnectionPool(**host) diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..bae5ac4 --- /dev/null +++ b/conftest.py @@ -0,0 +1,10 @@ +import asyncio +import pytest + + +@pytest.fixture +def event_loop(request): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() diff --git a/tests/test_core.py b/tests/test_core.py index e5bda1c..92f6923 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -419,12 +419,6 @@ def test_repeated_group_send_with_async_to_sync(channel_layer): pytest.fail(f"repeated async_to_sync wrapped group_send calls raised {exc}") -@pytest.mark.xfail( - reason=""" -Fails with error in redis-py: int() argument must be a string, a bytes-like -object or a real number, not 'NoneType'. Refs: #348 -""" -) @pytest.mark.asyncio async def test_receive_cancel(channel_layer): """ @@ -557,7 +551,6 @@ async def test_message_expiry__group_send(channel_layer): await channel_layer.receive(channel_name) -@pytest.mark.xfail(reason="Fails with timeout. 
Refs: #348") @pytest.mark.asyncio async def test_message_expiry__group_send__one_channel_expires_message(channel_layer): expiry = 3 From 21ff64dc258831dd3c765adb86094aad919cf006 Mon Sep 17 00:00:00 2001 From: Takuma Iwaki Date: Fri, 30 May 2025 02:50:12 +0900 Subject: [PATCH 2/5] FIX REVIEW --- channels_redis/core.py | 48 +++++++++++++++++------------------ channels_redis/pubsub.py | 37 ++++++++++++++++----------- channels_redis/serializers.py | 30 +++++++++++++--------- channels_redis/utils.py | 16 ++++++------ 4 files changed, 71 insertions(+), 60 deletions(-) diff --git a/channels_redis/core.py b/channels_redis/core.py index 6862d13..c5613fa 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -19,12 +19,12 @@ create_pool, decode_hosts, ) -from typing import TYPE_CHECKING, Dict, List, Tuple, Union, Optional -if TYPE_CHECKING: +import typing + +if typing.TYPE_CHECKING: from redis.asyncio.connection import ConnectionPool from redis.asyncio.client import Redis - from .core import RedisChannelLayer from typing_extensions import Buffer logger = logging.getLogger(__name__) @@ -39,10 +39,10 @@ class ChannelLock: """ def __init__(self): - self.locks: "collections.defaultdict[str, asyncio.Lock]" = ( + self.locks: collections.defaultdict[str, asyncio.Lock] = ( collections.defaultdict(asyncio.Lock) ) - self.wait_counts: "collections.defaultdict[str, int]" = collections.defaultdict( + self.wait_counts: collections.defaultdict[str, int] = collections.defaultdict( int ) @@ -87,7 +87,7 @@ class RedisLoopLayer: def __init__(self, channel_layer: "RedisChannelLayer"): self._lock = asyncio.Lock() self.channel_layer = channel_layer - self._connections: "Dict[int, Redis]" = {} + self._connections: typing.Dict[int, "Redis"] = {} def get_connection(self, index: int) -> "Redis": if index not in self._connections: @@ -145,7 +145,7 @@ def __init__( symmetric_encryption_keys=symmetric_encryption_keys, ) # Cached redis connection pools and the event loop they are from - self._layers: "Dict[asyncio.AbstractEventLoop, RedisLoopLayer]" = {} + self._layers: typing.Dict[asyncio.AbstractEventLoop, "RedisLoopLayer"] = {} # Normal channels choose a host index by cycling through the available hosts self._receive_index_generator = itertools.cycle(range(len(self.hosts))) self._send_index_generator = itertools.cycle(range(len(self.hosts))) @@ -154,15 +154,15 @@ def __init__( # Number of coroutines trying to receive right now self.receive_count = 0 # The receive lock - self.receive_lock: "Optional[asyncio.Lock]" = None + self.receive_lock: typing.Optional[asyncio.Lock] = None # Event loop they are trying to receive on - self.receive_event_loop: "Optional[asyncio.AbstractEventLoop]" = None + self.receive_event_loop: typing.Optional[asyncio.AbstractEventLoop] = None # Buffered messages by process-local channel name - self.receive_buffer: "collections.defaultdict[str, BoundedQueue]" = ( + self.receive_buffer: collections.defaultdict[str, BoundedQueue] = ( collections.defaultdict(functools.partial(BoundedQueue, self.capacity)) ) # Detached channel cleanup tasks - self.receive_cleaners: "List[asyncio.Task]" = [] + self.receive_cleaners: typing.List[asyncio.Task] = [] # Per-channel cleanup locks to prevent a receive starting and moving # a message back into the main queue before its cleanup has completed self.receive_clean_locks = ChannelLock() @@ -180,7 +180,7 @@ async def send(self, channel: str, message): """ # Typecheck assert isinstance(message, dict), "message is not a dict" - assert 
self.require_valid_channel_name(channel), "Channel name not valid" + assert self.valid_channel_name(channel), "Channel name not valid" # Make sure the message does not contain reserved keys assert "__asgi_channel__" not in message # If it's a process-local channel, strip off local part and stick full name in message @@ -221,7 +221,7 @@ def _backup_channel_name(self, channel: str) -> str: return channel + "$inflight" async def _brpop_with_clean( - self, index: int, channel: str, timeout: "Union[int, float, bytes, str]" + self, index: int, channel: str, timeout: typing.Union[int, float, bytes, str] ): """ Perform a Redis BRPOP and manage the backup processing queue. @@ -269,7 +269,7 @@ async def receive(self, channel: str): """ # Make sure the channel name is valid then get the non-local part # and thus its index - assert self.require_valid_channel_name(channel) + assert self.valid_channel_name(channel) if "!" in channel: real_channel = self.non_local_name(channel) assert real_channel.endswith( @@ -385,14 +385,12 @@ async def receive(self, channel: str): # Do a plain direct receive return (await self.receive_single(channel))[1] - async def receive_single(self, channel: str) -> "Tuple": + async def receive_single(self, channel: str) -> typing.Tuple: """ Receives a single message off of the channel and returns it. """ # Check channel name - assert self.require_valid_channel_name( - channel, receive=True - ), "Channel name invalid" + assert self.valid_channel_name(channel, receive=True), "Channel name invalid" # Work out the connection to use if "!" in channel: assert channel.endswith("!") @@ -423,7 +421,7 @@ async def receive_single(self, channel: str) -> "Tuple": ) self.receive_cleaners.append(cleaner) - def _cleanup_done(cleaner: "asyncio.Task"): + def _cleanup_done(cleaner: asyncio.Task): self.receive_cleaners.remove(cleaner) self.receive_clean_locks.release(channel_key) @@ -497,8 +495,8 @@ async def group_add(self, group: str, channel: str): Adds the channel name to a group. """ # Check the inputs - assert self.require_valid_group_name(group), True - assert self.require_valid_channel_name(channel), True + assert self.valid_group_name(group), True + assert self.valid_channel_name(channel), True # Get a connection to the right shard group_key = self._group_key(group) connection = self.connection(self.consistent_hash(group)) @@ -513,8 +511,8 @@ async def group_discard(self, group: str, channel: str): Removes the channel from the named group if it is in the group; does nothing otherwise (does not error) """ - assert self.require_valid_group_name(group), "Group name not valid" - assert self.require_valid_channel_name(channel), "Channel name not valid" + assert self.valid_group_name(group), "Group name not valid" + assert self.valid_channel_name(channel), "Channel name not valid" key = self._group_key(group) connection = self.connection(self.consistent_hash(group)) await connection.zrem(key, channel) @@ -523,7 +521,7 @@ async def group_send(self, group: str, message): """ Sends a message to the entire group. 
""" - assert self.require_valid_group_name(group), "Group name not valid" + assert self.valid_group_name(group), "Group name not valid" # Retrieve list of all channel names key = self._group_key(group) connection = self.connection(self.consistent_hash(group)) @@ -671,7 +669,7 @@ def deserialize(self, message: bytes): ### Internal functions ### - def consistent_hash(self, value: "Union[str, Buffer]") -> int: + def consistent_hash(self, value: typing.Union[str, "Buffer"]) -> int: return _consistent_hash(value, self.ring_size) def __str__(self): diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py index e9958ab..5af1ab4 100644 --- a/channels_redis/pubsub.py +++ b/channels_redis/pubsub.py @@ -13,15 +13,16 @@ create_pool, decode_hosts, ) -from typing import TYPE_CHECKING, Dict, Union, Optional, Any, Iterable -if TYPE_CHECKING: +import typing + +if typing.TYPE_CHECKING: from redis.asyncio.client import Redis from typing_extensions import Buffer logger = logging.getLogger(__name__) -async def _async_proxy(obj: "RedisPubSubChannelLayer", name: "str", *args, **kwargs): +async def _async_proxy(obj: "RedisPubSubChannelLayer", name: str, *args, **kwargs): # Must be defined as a function and not a method due to # https://bugs.python.org/issue38364 layer = obj._get_layer() @@ -32,13 +33,15 @@ class RedisPubSubChannelLayer: def __init__( self, *args, - symmetric_encryption_keys: "Optional[Iterable[Union[str, Buffer]]]" = None, + symmetric_encryption_keys: typing.Optional[ + typing.Iterable[typing.Union[str, "Buffer"]] + ] = None, serializer_format="msgpack", **kwargs, ): self._args = args self._kwargs = kwargs - self._layers: "Dict[asyncio.AbstractEventLoop, RedisPubSubLoopLayer]" = {} + self._layers: typing.Dict[asyncio.AbstractEventLoop, RedisPubSubLoopLayer] = {} # serialization self._serializer = registry.get_serializer( serializer_format, @@ -95,11 +98,11 @@ class RedisPubSubLoopLayer: def __init__( self, - hosts: "Union[Iterable, str, bytes, None]" = None, - prefix="asgi", + hosts: typing.Union[typing.Iterable, str, bytes, None] = None, + prefix: str = "asgi", on_disconnect=None, on_reconnect=None, - channel_layer: "Optional[RedisPubSubChannelLayer]" = None, + channel_layer: typing.Optional[RedisPubSubChannelLayer] = None, **kwargs, ): self.prefix = prefix @@ -110,7 +113,7 @@ def __init__( # Each consumer gets its own *specific* channel, created with the `new_channel()` method. # This dict maps `channel_name` to a queue of messages for that channel. - self.channels: "Dict[Any, asyncio.Queue]" = {} + self.channels: typing.Dict[typing.Any, asyncio.Queue] = {} # A channel can subscribe to zero or more groups. # This dict maps `group_name` to set of channel names who are subscribed to that group. @@ -122,7 +125,7 @@ def __init__( ] def _get_shard( - self, channel_or_group_name: "Union[str, Buffer]" + self, channel_or_group_name: typing.Union[str, "Buffer"] ) -> "RedisSingleShardConnection": """ Return the shard that is used exclusively for this channel or group. 
@@ -265,17 +268,21 @@ async def flush(self): class RedisSingleShardConnection: - def __init__(self, host: "Dict[str, Any]", channel_layer: "RedisPubSubLoopLayer"): + def __init__( + self, host: typing.Dict[str, typing.Any], channel_layer: RedisPubSubLoopLayer + ): self.host = host self.channel_layer = channel_layer self._subscribed_to = set() self._lock = asyncio.Lock() - self._redis: "Optional[Redis]" = None + self._redis: typing.Optional["Redis"] = None self._pubsub = None - self._receive_task: "Optional[asyncio.Task]" = None + self._receive_task: typing.Optional[asyncio.Task] = None async def publish( - self, channel: "Union[str, bytes]", message: "Union[str, bytes, int, float]" + self, + channel: typing.Union[str, bytes], + message: typing.Union[str, bytes, int, float], ): async with self._lock: self._ensure_redis() @@ -335,7 +342,7 @@ async def _do_receiving(self): logger.exception("Unexpected exception in receive task") await asyncio.sleep(1) - def _receive_message(self, message: "Optional[Dict]"): + def _receive_message(self, message: typing.Optional[typing.Dict]): if message is not None: name = message["channel"] data = message["data"] diff --git a/channels_redis/serializers.py b/channels_redis/serializers.py index 8f497f5..a813c3c 100644 --- a/channels_redis/serializers.py +++ b/channels_redis/serializers.py @@ -3,9 +3,9 @@ import hashlib import json import random -from typing import TYPE_CHECKING, Dict, Union, Optional, Any, Iterable, Type, cast +import typing -if TYPE_CHECKING: +if typing.TYPE_CHECKING: from typing_extensions import Buffer try: @@ -24,19 +24,23 @@ class SerializerDoesNotExist(KeyError): class BaseMessageSerializer(abc.ABC): def __init__( self, - symmetric_encryption_keys: "Optional[Iterable[Union[str, Buffer]]]" = None, + symmetric_encryption_keys: typing.Optional[ + typing.Iterable[typing.Union[str, "Buffer"]] + ] = None, random_prefix_length: int = 0, - expiry: "Optional[int]" = None, + expiry: typing.Optional[int] = None, ): self.random_prefix_length = random_prefix_length self.expiry = expiry - self.crypter: "Optional[MultiFernet]" = None + self.crypter: typing.Optional["MultiFernet"] = None # Set up any encryption objects self._setup_encryption(symmetric_encryption_keys) def _setup_encryption( self, - symmetric_encryption_keys: "Optional[Union[Iterable[Union[str, Buffer]], str, bytes]]", + symmetric_encryption_keys: typing.Optional[ + typing.Union[typing.Iterable[typing.Union[str, "Buffer"]], str, bytes] + ], ): # See if we can do encryption if they asked if not symmetric_encryption_keys: @@ -47,7 +51,9 @@ def _setup_encryption( raise ValueError( "symmetric_encryption_keys must be a list of possible keys" ) - keys = cast("Iterable[Union[str, Buffer]]", symmetric_encryption_keys) + keys = typing.cast( + typing.Iterable[typing.Union[str, "Buffer"]], symmetric_encryption_keys + ) if _MultiFernet is None: raise ValueError( "Cannot run with encryption without 'cryptography' installed." @@ -55,7 +61,7 @@ def _setup_encryption( sub_fernets = [self.make_fernet(key) for key in keys] self.crypter = _MultiFernet(sub_fernets) - def make_fernet(self, key: "Union[str, Buffer]") -> "Fernet": + def make_fernet(self, key: typing.Union[str, "Buffer"]) -> "Fernet": """ Given a single encryption key, returns a Fernet instance using it. 
""" @@ -129,7 +135,7 @@ def as_bytes(self, message, *args, **kwargs) -> bytes: # code ready for a future in which msgpack may become an optional dependency -MsgPackSerializer: "Optional[Type[BaseMessageSerializer]]" = None +MsgPackSerializer: typing.Optional[typing.Type[BaseMessageSerializer]] = None try: import msgpack except ImportError as exc: @@ -153,10 +159,10 @@ class SerializersRegistry: """ def __init__(self): - self._registry: "Dict[Any, Type[BaseMessageSerializer]]" = {} + self._registry: typing.Dict[typing.Any, typing.Type[BaseMessageSerializer]] = {} def register_serializer( - self, format, serializer_class: "Type[BaseMessageSerializer]" + self, format, serializer_class: typing.Type[BaseMessageSerializer] ): """ Register a new serializer for given format @@ -174,7 +180,7 @@ def register_serializer( self._registry[format] = serializer_class - def get_serializer(self, format, *args, **kwargs) -> "BaseMessageSerializer": + def get_serializer(self, format, *args, **kwargs) -> BaseMessageSerializer: try: serializer_class = self._registry[format] except KeyError: diff --git a/channels_redis/utils.py b/channels_redis/utils.py index 49712c4..7a0d370 100644 --- a/channels_redis/utils.py +++ b/channels_redis/utils.py @@ -2,9 +2,9 @@ import types from redis import asyncio as aioredis from redis.asyncio import sentinel as redis_sentinel -from typing import TYPE_CHECKING, Dict, List, Union, Any, Iterable +import typing -if TYPE_CHECKING: +if typing.TYPE_CHECKING: from typing_extensions import Buffer from redis.asyncio.connection import ConnectionPool @@ -15,7 +15,7 @@ from .core import RedisChannelLayer -def _consistent_hash(value: "Union[str, Buffer]", ring_size: int) -> int: +def _consistent_hash(value: typing.Union[str, "Buffer"], ring_size: int) -> int: """ Maps the value to a node value between 0 and 4095 using CRC, then down to one of the ring nodes. @@ -32,7 +32,7 @@ def _consistent_hash(value: "Union[str, Buffer]", ring_size: int) -> int: def _wrap_close( - proxy: "Union[RedisPubSubChannelLayer, RedisChannelLayer]", + proxy: typing.Union["RedisPubSubChannelLayer", "RedisChannelLayer"], loop: "AbstractEventLoop", ): original_impl = loop.close @@ -60,8 +60,8 @@ async def _close_redis(connection: "Redis"): def decode_hosts( - hosts: "Union[Iterable, str, bytes, None]", -) -> "List[Dict]": + hosts: typing.Union[typing.Iterable, str, bytes, None], +) -> typing.List[typing.Dict]: """ Takes the value of the "hosts" argument and returns a list of kwargs to use for the Redis connection constructor. @@ -76,7 +76,7 @@ def decode_hosts( ) # Decode each hosts entry into a kwargs dict - result: "List[Dict]" = [] + result: typing.List[typing.Dict] = [] for entry in hosts: if isinstance(entry, dict): result.append(entry) @@ -87,7 +87,7 @@ def decode_hosts( return result -def create_pool(host: "Dict[str, Any]") -> "ConnectionPool": +def create_pool(host: typing.Dict[str, typing.Any]) -> "ConnectionPool": """ Takes the value of the "host" argument and returns a suited connection pool to the corresponding redis instance. 
From 10b4a42be566f4af3f98ccc7e8a52878591dea1a Mon Sep 17 00:00:00 2001 From: Takuma Iwaki Date: Tue, 10 Jun 2025 00:14:08 +0900 Subject: [PATCH 3/5] FIX: order import --- channels_redis/core.py | 5 ++--- channels_redis/pubsub.py | 3 +-- channels_redis/utils.py | 11 ++++++----- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/channels_redis/core.py b/channels_redis/core.py index c5613fa..9ba6060 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -4,6 +4,7 @@ import itertools import logging import time +import typing import uuid from redis import asyncio as aioredis @@ -20,11 +21,9 @@ decode_hosts, ) -import typing - if typing.TYPE_CHECKING: - from redis.asyncio.connection import ConnectionPool from redis.asyncio.client import Redis + from redis.asyncio.connection import ConnectionPool from typing_extensions import Buffer logger = logging.getLogger(__name__) diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py index 5af1ab4..0901b3b 100644 --- a/channels_redis/pubsub.py +++ b/channels_redis/pubsub.py @@ -1,6 +1,7 @@ import asyncio import functools import logging +import typing import uuid from redis import asyncio as aioredis @@ -14,8 +15,6 @@ decode_hosts, ) -import typing - if typing.TYPE_CHECKING: from redis.asyncio.client import Redis from typing_extensions import Buffer diff --git a/channels_redis/utils.py b/channels_redis/utils.py index 7a0d370..7de5d37 100644 --- a/channels_redis/utils.py +++ b/channels_redis/utils.py @@ -1,18 +1,19 @@ import binascii import types +import typing + from redis import asyncio as aioredis from redis.asyncio import sentinel as redis_sentinel -import typing if typing.TYPE_CHECKING: - from typing_extensions import Buffer + from asyncio import AbstractEventLoop - from redis.asyncio.connection import ConnectionPool from redis.asyncio.client import Redis - from asyncio import AbstractEventLoop + from redis.asyncio.connection import ConnectionPool + from typing_extensions import Buffer - from .pubsub import RedisPubSubChannelLayer from .core import RedisChannelLayer + from .pubsub import RedisPubSubChannelLayer def _consistent_hash(value: typing.Union[str, "Buffer"], ring_size: int) -> int: From 1e3e55d27fad5603133e584cc52e30b19c8cb580 Mon Sep 17 00:00:00 2001 From: Takuma Iwaki Date: Tue, 10 Jun 2025 00:22:43 +0900 Subject: [PATCH 4/5] TEST --- channels_redis/core.py | 114 ++++++++++++++++++---------------- channels_redis/pubsub.py | 88 ++++++++++++++------------ channels_redis/py.typed | 1 - channels_redis/serializers.py | 89 +++++++++++++------------- channels_redis/utils.py | 14 ++--- conftest.py | 10 --- setup.cfg | 7 +++ tests/test_core.py | 7 +++ tox.ini | 6 ++ 9 files changed, 183 insertions(+), 153 deletions(-) delete mode 100644 conftest.py diff --git a/channels_redis/core.py b/channels_redis/core.py index 9ba6060..e9c5121 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -9,8 +9,8 @@ from redis import asyncio as aioredis -from channels.exceptions import ChannelFull -from channels.layers import BaseChannelLayer +from channels.exceptions import ChannelFull # type: ignore[import-untyped] +from channels.layers import BaseChannelLayer # type: ignore[import-untyped] from .serializers import registry from .utils import ( @@ -37,13 +37,11 @@ class ChannelLock: to mitigate multi-event loop problems. 
""" - def __init__(self): - self.locks: collections.defaultdict[str, asyncio.Lock] = ( - collections.defaultdict(asyncio.Lock) - ) - self.wait_counts: collections.defaultdict[str, int] = collections.defaultdict( - int + def __init__(self) -> None: + self.locks: typing.DefaultDict[str, asyncio.Lock] = collections.defaultdict( + asyncio.Lock ) + self.wait_counts: typing.DefaultDict[str, int] = collections.defaultdict(int) async def acquire(self, channel: str) -> bool: """ @@ -58,7 +56,7 @@ def locked(self, channel: str) -> bool: """ return self.locks[channel].locked() - def release(self, channel: str): + def release(self, channel: str) -> None: """ Release the lock for the given channel. """ @@ -70,7 +68,7 @@ def release(self, channel: str): class BoundedQueue(asyncio.Queue): - def put_nowait(self, item): + def put_nowait(self, item: typing.Any) -> None: if self.full(): # see: https://github.com/django/channels_redis/issues/212 # if we actually get into this code block, it likely means that @@ -83,7 +81,7 @@ def put_nowait(self, item): class RedisLoopLayer: - def __init__(self, channel_layer: "RedisChannelLayer"): + def __init__(self, channel_layer: "RedisChannelLayer") -> None: self._lock = asyncio.Lock() self.channel_layer = channel_layer self._connections: typing.Dict[int, "Redis"] = {} @@ -95,7 +93,7 @@ def get_connection(self, index: int) -> "Redis": return self._connections[index] - async def flush(self): + async def flush(self) -> None: async with self._lock: for index in list(self._connections): connection = self._connections.pop(index) @@ -116,15 +114,15 @@ class RedisChannelLayer(BaseChannelLayer): def __init__( self, hosts=None, - prefix="asgi", + prefix: str = "asgi", expiry=60, - group_expiry=86400, + group_expiry: int = 86400, capacity=100, channel_capacity=None, symmetric_encryption_keys=None, random_prefix_length=12, serializer_format="msgpack", - ): + ) -> None: # Store basic information self.expiry = expiry self.group_expiry = group_expiry @@ -157,11 +155,11 @@ def __init__( # Event loop they are trying to receive on self.receive_event_loop: typing.Optional[asyncio.AbstractEventLoop] = None # Buffered messages by process-local channel name - self.receive_buffer: collections.defaultdict[str, BoundedQueue] = ( + self.receive_buffer: typing.DefaultDict[str, BoundedQueue] = ( collections.defaultdict(functools.partial(BoundedQueue, self.capacity)) ) # Detached channel cleanup tasks - self.receive_cleaners: typing.List[asyncio.Task] = [] + self.receive_cleaners: typing.List["asyncio.Task[typing.Any]"] = [] # Per-channel cleanup locks to prevent a receive starting and moving # a message back into the main queue before its cleanup has completed self.receive_clean_locks = ChannelLock() @@ -173,7 +171,7 @@ def create_pool(self, index: int) -> "ConnectionPool": extensions = ["groups", "flush"] - async def send(self, channel: str, message): + async def send(self, channel: str, message: typing.Any) -> None: """ Send a message onto a (general or specific) channel. """ @@ -221,7 +219,7 @@ def _backup_channel_name(self, channel: str) -> str: async def _brpop_with_clean( self, index: int, channel: str, timeout: typing.Union[int, float, bytes, str] - ): + ) -> typing.Any: """ Perform a Redis BRPOP and manage the backup processing queue. In case of cancellation, make sure the message is not lost. @@ -240,7 +238,7 @@ async def _brpop_with_clean( connection = self.connection(index) # Cancellation here doesn't matter, we're not doing anything destructive # and the script executes atomically... 
- await connection.eval(cleanup_script, 0, channel, backup_queue) + await connection.eval(cleanup_script, 0, channel, backup_queue) # type: ignore[misc] # ...and it doesn't matter here either, the message will be safe in the backup. result = await connection.bzpopmin(channel, timeout=timeout) @@ -252,7 +250,7 @@ async def _brpop_with_clean( return member - async def _clean_receive_backup(self, index: int, channel: str): + async def _clean_receive_backup(self, index: int, channel: str) -> None: """ Pop the oldest message off the channel backup queue. The result isn't interesting as it was already processed. @@ -260,7 +258,7 @@ async def _clean_receive_backup(self, index: int, channel: str): connection = self.connection(index) await connection.zpopmin(self._backup_channel_name(channel)) - async def receive(self, channel: str): + async def receive(self, channel: str) -> typing.Any: """ Receive the first message that arrives on the channel. If more than one coroutine waits on the same channel, the first waiter @@ -292,11 +290,11 @@ async def receive(self, channel: str): # Wait for our message to appear message = None while self.receive_buffer[channel].empty(): - tasks = [ - self.receive_lock.acquire(), + _tasks = [ + self.receive_lock.acquire(), # type: ignore[union-attr] self.receive_buffer[channel].get(), ] - tasks = [asyncio.ensure_future(task) for task in tasks] + tasks = [asyncio.ensure_future(task) for task in _tasks] try: done, pending = await asyncio.wait( tasks, return_when=asyncio.FIRST_COMPLETED @@ -312,7 +310,7 @@ async def receive(self, channel: str): if not task.cancel(): assert task.done() if task.result() is True: - self.receive_lock.release() + self.receive_lock.release() # type: ignore[union-attr] raise @@ -335,7 +333,7 @@ async def receive(self, channel: str): if message or exception: if token: # We will not be receving as we already have the message. - self.receive_lock.release() + self.receive_lock.release() # type: ignore[union-attr] if exception: raise exception @@ -362,7 +360,7 @@ async def receive(self, channel: str): del self.receive_buffer[channel] raise finally: - self.receive_lock.release() + self.receive_lock.release() # type: ignore[union-attr] # We know there's a message available, because there # couldn't have been any interruption between empty() and here @@ -377,14 +375,16 @@ async def receive(self, channel: str): self.receive_count -= 1 # If we were the last out, drop the receive lock if self.receive_count == 0: - assert not self.receive_lock.locked() + assert not self.receive_lock.locked() # type: ignore[union-attr] self.receive_lock = None self.receive_event_loop = None else: # Do a plain direct receive return (await self.receive_single(channel))[1] - async def receive_single(self, channel: str) -> typing.Tuple: + async def receive_single( + self, channel: str + ) -> typing.Tuple[typing.Any, typing.Any]: """ Receives a single message off of the channel and returns it. """ @@ -420,7 +420,7 @@ async def receive_single(self, channel: str) -> typing.Tuple: ) self.receive_cleaners.append(cleaner) - def _cleanup_done(cleaner: asyncio.Task): + def _cleanup_done(cleaner: "asyncio.Task") -> None: self.receive_cleaners.remove(cleaner) self.receive_clean_locks.release(channel_key) @@ -448,7 +448,7 @@ async def new_channel(self, prefix: str = "specific") -> str: ### Flush extension ### - async def flush(self): + async def flush(self) -> None: """ Deletes all messages and groups on all shards. 
""" @@ -466,11 +466,11 @@ async def flush(self): # Go through each connection and remove all with prefix for i in range(self.ring_size): connection = self.connection(i) - await connection.eval(delete_prefix, 0, self.prefix + "*") + await connection.eval(delete_prefix, 0, self.prefix + "*") # type: ignore[union-attr,misc] # Now clear the pools as well await self.close_pools() - async def close_pools(self): + async def close_pools(self) -> None: """ Close all connections in the event loop pools. """ @@ -480,7 +480,7 @@ async def close_pools(self): for layer in self._layers.values(): await layer.flush() - async def wait_received(self): + async def wait_received(self) -> None: """ Wait for all channel cleanup functions to finish. """ @@ -489,13 +489,13 @@ async def wait_received(self): ### Groups extension ### - async def group_add(self, group: str, channel: str): + async def group_add(self, group: str, channel: str) -> None: """ Adds the channel name to a group. """ # Check the inputs - assert self.valid_group_name(group), True - assert self.valid_channel_name(channel), True + assert self.valid_group_name(group), "Group name not valid" + assert self.valid_channel_name(channel), "Channel name not valid" # Get a connection to the right shard group_key = self._group_key(group) connection = self.connection(self.consistent_hash(group)) @@ -505,7 +505,7 @@ async def group_add(self, group: str, channel: str): # it at this point is guaranteed to expire before that await connection.expire(group_key, self.group_expiry) - async def group_discard(self, group: str, channel: str): + async def group_discard(self, group: str, channel: str) -> None: """ Removes the channel from the named group if it is in the group; does nothing otherwise (does not error) @@ -516,7 +516,7 @@ async def group_discard(self, group: str, channel: str): connection = self.connection(self.consistent_hash(group)) await connection.zrem(key, channel) - async def group_send(self, group: str, message): + async def group_send(self, group: str, message: typing.Any) -> None: """ Sends a message to the entire group. 
""" @@ -540,9 +540,9 @@ async def group_send(self, group: str, message): for connection_index, channel_redis_keys in connection_to_channel_keys.items(): # Discard old messages based on expiry pipe = connection.pipeline() - for key in channel_redis_keys: + for _key in channel_redis_keys: pipe.zremrangebyscore( - key, min=0, max=int(time.time()) - int(self.expiry) + _key, min=0, max=int(time.time()) - int(self.expiry) ) await pipe.execute() @@ -582,10 +582,10 @@ async def group_send(self, group: str, message): # channel_keys does not contain a single redis key more than once connection = self.connection(connection_index) - channels_over_capacity = await connection.eval( + channels_over_capacity = await connection.eval( # type: ignore[misc] group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args ) - _channels_over_capacity = -1 + _channels_over_capacity = -1.0 try: _channels_over_capacity = float(channels_over_capacity) except Exception: @@ -598,7 +598,13 @@ async def group_send(self, group: str, message): group, ) - def _map_channel_keys_to_connection(self, channel_names, message): + def _map_channel_keys_to_connection( + self, channel_names: typing.Iterable[str], message: typing.Any + ) -> typing.Tuple[ + typing.Dict[int, typing.List[str]], + typing.Dict[str, typing.Any], + typing.Dict[str, int], + ]: """ For a list of channel names, GET @@ -611,11 +617,13 @@ def _map_channel_keys_to_connection(self, channel_names, message): """ # Connection dict keyed by index to list of redis keys mapped on that index - connection_to_channel_keys = collections.defaultdict(list) + connection_to_channel_keys: typing.Dict[int, typing.List[str]] = ( + collections.defaultdict(list) + ) # Message dict maps redis key to the message that needs to be send on that key - channel_key_to_message = dict() + channel_key_to_message: typing.Dict[str, typing.Any] = dict() # Channel key mapped to its capacity - channel_key_to_capacity = dict() + channel_key_to_capacity: typing.Dict[str, int] = dict() # For each channel for channel in channel_names: @@ -623,7 +631,7 @@ def _map_channel_keys_to_connection(self, channel_names, message): if "!" in channel: channel_non_local_name = self.non_local_name(channel) # Get its redis key - channel_key = self.prefix + channel_non_local_name + channel_key: str = self.prefix + channel_non_local_name # Have we come across the same redis key? if channel_key not in channel_key_to_message: # If not, fill the corresponding dicts @@ -654,13 +662,15 @@ def _group_key(self, group: str) -> bytes: """ return f"{self.prefix}:group:{group}".encode("utf8") - def serialize(self, message) -> bytes: + ### Serialization ### + + def serialize(self, message: typing.Any) -> bytes: """ Serializes message to a byte string. """ return self._serializer.serialize(message) - def deserialize(self, message: bytes): + def deserialize(self, message: bytes) -> typing.Any: """ Deserializes from a byte string. 
""" @@ -671,7 +681,7 @@ def deserialize(self, message: bytes): def consistent_hash(self, value: typing.Union[str, "Buffer"]) -> int: return _consistent_hash(value, self.ring_size) - def __str__(self): + def __str__(self) -> str: return f"{self.__class__.__name__}(hosts={self.hosts})" ### Connection handling ### diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py index 0901b3b..ef8d8ea 100644 --- a/channels_redis/pubsub.py +++ b/channels_redis/pubsub.py @@ -3,6 +3,7 @@ import logging import typing import uuid +from copy import deepcopy from redis import asyncio as aioredis @@ -16,12 +17,14 @@ ) if typing.TYPE_CHECKING: - from redis.asyncio.client import Redis + from redis.asyncio.client import PubSub, Redis from typing_extensions import Buffer logger = logging.getLogger(__name__) -async def _async_proxy(obj: "RedisPubSubChannelLayer", name: str, *args, **kwargs): +async def _async_proxy( + obj: "RedisPubSubChannelLayer", name: str, *args, **kwargs +) -> typing.Any: # Must be defined as a function and not a method due to # https://bugs.python.org/issue38364 layer = obj._get_layer() @@ -37,7 +40,7 @@ def __init__( ] = None, serializer_format="msgpack", **kwargs, - ): + ) -> None: self._args = args self._kwargs = kwargs self._layers: typing.Dict[asyncio.AbstractEventLoop, RedisPubSubLoopLayer] = {} @@ -47,7 +50,7 @@ def __init__( symmetric_encryption_keys=symmetric_encryption_keys, ) - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> typing.Any: if name in ( "new_channel", "send", @@ -61,13 +64,13 @@ def __getattr__(self, name: str): else: return getattr(self._get_layer(), name) - def serialize(self, message) -> bytes: + def serialize(self, message: typing.Any) -> bytes: """ Serializes message to a byte string. """ return self._serializer.serialize(message) - def deserialize(self, message: bytes): + def deserialize(self, message: bytes) -> typing.Any: """ Deserializes from a byte string. """ @@ -79,10 +82,11 @@ def _get_layer(self) -> "RedisPubSubLoopLayer": try: layer = self._layers[loop] except KeyError: + kwargs = deepcopy(self._kwargs) + kwargs["channel_layer"] = self layer = RedisPubSubLoopLayer( *self._args, - **self._kwargs, - channel_layer=self, + **kwargs, ) self._layers[loop] = layer _wrap_close(self, loop) @@ -97,13 +101,13 @@ class RedisPubSubLoopLayer: def __init__( self, - hosts: typing.Union[typing.Iterable, str, bytes, None] = None, + hosts: typing.Union[typing.Iterable[typing.Any], str, bytes, None] = None, prefix: str = "asgi", on_disconnect=None, on_reconnect=None, channel_layer: typing.Optional[RedisPubSubChannelLayer] = None, **kwargs, - ): + ) -> None: self.prefix = prefix self.on_disconnect = on_disconnect @@ -112,11 +116,11 @@ def __init__( # Each consumer gets its own *specific* channel, created with the `new_channel()` method. # This dict maps `channel_name` to a queue of messages for that channel. - self.channels: typing.Dict[typing.Any, asyncio.Queue] = {} + self.channels: typing.Dict[typing.Any, asyncio.Queue[typing.Any]] = {} # A channel can subscribe to zero or more groups. # This dict maps `group_name` to set of channel names who are subscribed to that group. - self.groups = {} + self.groups: typing.Dict[str, typing.Set[typing.Any]] = {} # For each host, we create a `RedisSingleShardConnection` to manage the connection to that host. 
self._shards = [ @@ -131,7 +135,7 @@ def _get_shard( """ return self._shards[_consistent_hash(channel_or_group_name, len(self._shards))] - def _get_group_channel_name(self, group) -> str: + def _get_group_channel_name(self, group: str) -> str: """ Return the channel name used by a group. Includes '__group__' in the returned @@ -143,7 +147,7 @@ def _get_group_channel_name(self, group) -> str: """ return f"{self.prefix}__group__{group}" - async def _subscribe_to_channel(self, channel): + async def _subscribe_to_channel(self, channel: typing.Union[str, bytes]) -> None: self.channels[channel] = asyncio.Queue() shard = self._get_shard(channel) await shard.subscribe(channel) @@ -154,14 +158,16 @@ async def _subscribe_to_channel(self, channel): # Channel layer API ################################################################################ - async def send(self, channel, message): + async def send( + self, channel: typing.Union[str, bytes], message: typing.Any + ) -> None: """ Send a message onto a (general or specific) channel. """ shard = self._get_shard(channel) - await shard.publish(channel, self.channel_layer.serialize(message)) + await shard.publish(channel, self.channel_layer.serialize(message)) # type: ignore[union-attr] - async def new_channel(self, prefix="specific."): + async def new_channel(self, prefix: str = "specific.") -> str: """ Returns a new channel name that can be used by a consumer in our process as a specific channel. @@ -170,7 +176,7 @@ async def new_channel(self, prefix="specific."): await self._subscribe_to_channel(channel) return channel - async def receive(self, channel): + async def receive(self, channel: typing.Union[str, bytes]) -> typing.Any: """ Receive the first message that arrives on the channel. If more than one coroutine waits on the same channel, a random one @@ -201,13 +207,13 @@ async def receive(self, channel): # We don't re-raise here because we want the CancelledError to be the one re-raised. raise - return self.channel_layer.deserialize(message) + return self.channel_layer.deserialize(message) # type: ignore[union-attr] ################################################################################ # Groups extension ################################################################################ - async def group_add(self, group, channel): + async def group_add(self, group: str, channel: typing.Union[str, bytes]) -> None: """ Adds the channel name to a group. """ @@ -226,7 +232,9 @@ async def group_add(self, group, channel): shard = self._get_shard(group_channel) await shard.subscribe(group_channel) - async def group_discard(self, group, channel): + async def group_discard( + self, group: str, channel: typing.Union[str, bytes] + ) -> None: """ Removes the channel from a group if it is in the group; does nothing otherwise (does not error) @@ -242,19 +250,19 @@ async def group_discard(self, group, channel): shard = self._get_shard(group_channel) await shard.unsubscribe(group_channel) - async def group_send(self, group, message): + async def group_send(self, group: str, message: typing.Any) -> None: """ Send the message to all subscribers of the group. 
""" group_channel = self._get_group_channel_name(group) shard = self._get_shard(group_channel) - await shard.publish(group_channel, self.channel_layer.serialize(message)) + await shard.publish(group_channel, self.channel_layer.serialize(message)) # type: ignore[union-attr] ################################################################################ # Flush extension ################################################################################ - async def flush(self): + async def flush(self) -> None: """ Flush the layer, making it like new. It can continue to be used as if it was just created. This also closes connections, serving as a clean-up @@ -269,41 +277,41 @@ async def flush(self): class RedisSingleShardConnection: def __init__( self, host: typing.Dict[str, typing.Any], channel_layer: RedisPubSubLoopLayer - ): + ) -> None: self.host = host self.channel_layer = channel_layer - self._subscribed_to = set() + self._subscribed_to: typing.Set[typing.Union[str, bytes]] = set() self._lock = asyncio.Lock() self._redis: typing.Optional["Redis"] = None - self._pubsub = None - self._receive_task: typing.Optional[asyncio.Task] = None + self._pubsub: typing.Optional["PubSub"] = None + self._receive_task: typing.Optional["asyncio.Task[typing.Any]"] = None async def publish( self, channel: typing.Union[str, bytes], message: typing.Union[str, bytes, int, float], - ): + ) -> None: async with self._lock: self._ensure_redis() - await self._redis.publish(channel, message) + await self._redis.publish(channel, message) # type: ignore[union-attr] - async def subscribe(self, channel): + async def subscribe(self, channel: typing.Union[str, bytes]) -> None: async with self._lock: if channel not in self._subscribed_to: self._ensure_redis() self._ensure_receiver() - await self._pubsub.subscribe(channel) + await self._pubsub.subscribe(channel) # type: ignore[union-attr] self._subscribed_to.add(channel) - async def unsubscribe(self, channel): + async def unsubscribe(self, channel: typing.Union[str, bytes]) -> None: async with self._lock: if channel in self._subscribed_to: self._ensure_redis() self._ensure_receiver() - await self._pubsub.unsubscribe(channel) + await self._pubsub.unsubscribe(channel) # type: ignore[union-attr] self._subscribed_to.remove(channel) - async def flush(self): + async def flush(self) -> None: async with self._lock: if self._receive_task is not None: self._receive_task.cancel() @@ -321,7 +329,7 @@ async def flush(self): self._pubsub = None self._subscribed_to = set() - async def _do_receiving(self): + async def _do_receiving(self) -> None: while True: try: if self._pubsub and self._pubsub.subscribed: @@ -341,7 +349,9 @@ async def _do_receiving(self): logger.exception("Unexpected exception in receive task") await asyncio.sleep(1) - def _receive_message(self, message: typing.Optional[typing.Dict]): + def _receive_message( + self, message: typing.Optional[typing.Dict[typing.Any, typing.Any]] + ) -> None: if message is not None: name = message["channel"] data = message["data"] @@ -354,12 +364,12 @@ def _receive_message(self, message: typing.Optional[typing.Dict]): if channel_name in self.channel_layer.channels: self.channel_layer.channels[channel_name].put_nowait(data) - def _ensure_redis(self): + def _ensure_redis(self) -> None: if self._redis is None: pool = create_pool(self.host) self._redis = aioredis.Redis(connection_pool=pool) self._pubsub = self._redis.pubsub() - def _ensure_receiver(self): + def _ensure_receiver(self) -> None: if self._receive_task is None: self._receive_task = 
asyncio.ensure_future(self._do_receiving()) diff --git a/channels_redis/py.typed b/channels_redis/py.typed index 0519ecb..e69de29 100644 --- a/channels_redis/py.typed +++ b/channels_redis/py.typed @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/channels_redis/serializers.py b/channels_redis/serializers.py index a813c3c..e7805bb 100644 --- a/channels_redis/serializers.py +++ b/channels_redis/serializers.py @@ -11,10 +11,11 @@ try: from cryptography.fernet import Fernet, MultiFernet - _MultiFernet = MultiFernet - _Fernet = Fernet + _MultiFernet: typing.Optional[typing.Type[MultiFernet]] = MultiFernet + _Fernet: typing.Optional[typing.Type[Fernet]] = Fernet except ImportError: - _MultiFernet = _Fernet = None + _MultiFernet = None + _Fernet = None class SerializerDoesNotExist(KeyError): @@ -41,25 +42,21 @@ def _setup_encryption( symmetric_encryption_keys: typing.Optional[ typing.Union[typing.Iterable[typing.Union[str, "Buffer"]], str, bytes] ], - ): + ) -> None: # See if we can do encryption if they asked - if not symmetric_encryption_keys: + if symmetric_encryption_keys: + if isinstance(symmetric_encryption_keys, (str, bytes)): + raise ValueError( + "symmetric_encryption_keys must be a list of possible keys" + ) + if _MultiFernet is None: + raise ValueError( + "Cannot run with encryption without 'cryptography' installed." + ) + sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys] + self.crypter = _MultiFernet(sub_fernets) + else: self.crypter = None - return - - if isinstance(symmetric_encryption_keys, (str, bytes)): - raise ValueError( - "symmetric_encryption_keys must be a list of possible keys" - ) - keys = typing.cast( - typing.Iterable[typing.Union[str, "Buffer"]], symmetric_encryption_keys - ) - if _MultiFernet is None: - raise ValueError( - "Cannot run with encryption without 'cryptography' installed." - ) - sub_fernets = [self.make_fernet(key) for key in keys] - self.crypter = _MultiFernet(sub_fernets) def make_fernet(self, key: typing.Union[str, "Buffer"]) -> "Fernet": """ @@ -76,32 +73,32 @@ def make_fernet(self, key: typing.Union[str, "Buffer"]) -> "Fernet": return _Fernet(formatted_key) @abc.abstractmethod - def as_bytes(self, message, *args, **kwargs): + def as_bytes(self, message: typing.Any, *args, **kwargs) -> bytes: raise NotImplementedError @abc.abstractmethod - def from_bytes(self, message, *args, **kwargs): + def from_bytes(self, message: bytes, *args, **kwargs) -> typing.Any: raise NotImplementedError - def serialize(self, message) -> bytes: + def serialize(self, message: typing.Any) -> bytes: """ Serializes message to a byte string. """ - message = self.as_bytes(message) + msg = self.as_bytes(message) if self.crypter: - message = self.crypter.encrypt(message) + msg = self.crypter.encrypt(msg) if self.random_prefix_length > 0: # provide random prefix - message = ( + msg = ( random.getrandbits(8 * self.random_prefix_length).to_bytes( self.random_prefix_length, "big" ) - + message + + msg ) - return message + return msg - def deserialize(self, message: bytes): + def deserialize(self, message: bytes) -> typing.Any: """ Deserializes from a byte string. 
""" @@ -116,10 +113,10 @@ def deserialize(self, message: bytes): class MissingSerializer(BaseMessageSerializer): - exception: "Exception" = Exception(None) + exception: typing.Optional[Exception] = None - def __init__(self, *args, **kwargs): - raise self.exception + def __init__(self, *args, **kwargs) -> None: + raise self.exception if self.exception else NotImplementedError() class JSONSerializer(BaseMessageSerializer): @@ -127,17 +124,19 @@ class JSONSerializer(BaseMessageSerializer): # thus we must force bytes conversion # we use UTF-8 since it is the recommended encoding for interoperability # see https://docs.python.org/3/library/json.html#character-encodings - def as_bytes(self, message, *args, **kwargs) -> bytes: - message = json.dumps(message, *args, **kwargs) - return message.encode("utf-8") + def as_bytes(self, message: typing.Any, *args, **kwargs) -> bytes: + msg = json.dumps(message, *args, **kwargs) + return msg.encode("utf-8") - from_bytes = staticmethod(json.loads) + from_bytes = staticmethod(json.loads) # type: ignore[assignment] # code ready for a future in which msgpack may become an optional dependency -MsgPackSerializer: typing.Optional[typing.Type[BaseMessageSerializer]] = None +MsgPackSerializer: typing.Union[ + typing.Type[BaseMessageSerializer], typing.Type[MissingSerializer] +] try: - import msgpack + import msgpack # type: ignore[import-untyped] except ImportError as exc: class _MsgPackSerializer(MissingSerializer): @@ -147,8 +146,8 @@ class _MsgPackSerializer(MissingSerializer): else: class __MsgPackSerializer(BaseMessageSerializer): - as_bytes = staticmethod(msgpack.packb) - from_bytes = staticmethod(msgpack.unpackb) + as_bytes = staticmethod(msgpack.packb) # type: ignore[assignment] + from_bytes = staticmethod(msgpack.unpackb) # type: ignore[assignment] MsgPackSerializer = __MsgPackSerializer @@ -158,12 +157,12 @@ class SerializersRegistry: Serializers registry inspired by that of ``django.core.serializers``. 
""" - def __init__(self): + def __init__(self) -> None: self._registry: typing.Dict[typing.Any, typing.Type[BaseMessageSerializer]] = {} def register_serializer( - self, format, serializer_class: typing.Type[BaseMessageSerializer] - ): + self, format: typing.Any, serializer_class: typing.Type[BaseMessageSerializer] + ) -> None: """ Register a new serializer for given format """ @@ -180,7 +179,9 @@ def register_serializer( self._registry[format] = serializer_class - def get_serializer(self, format, *args, **kwargs) -> BaseMessageSerializer: + def get_serializer( + self, format: typing.Any, *args, **kwargs + ) -> BaseMessageSerializer: try: serializer_class = self._registry[format] except KeyError: diff --git a/channels_redis/utils.py b/channels_redis/utils.py index 7de5d37..89265f8 100644 --- a/channels_redis/utils.py +++ b/channels_redis/utils.py @@ -35,10 +35,10 @@ def _consistent_hash(value: typing.Union[str, "Buffer"], ring_size: int) -> int: def _wrap_close( proxy: typing.Union["RedisPubSubChannelLayer", "RedisChannelLayer"], loop: "AbstractEventLoop", -): +) -> None: original_impl = loop.close - def _wrapper(self, *args, **kwargs): + def _wrapper(self, *args, **kwargs) -> typing.Any: if loop in proxy._layers: layer = proxy._layers[loop] del proxy._layers[loop] @@ -47,10 +47,10 @@ def _wrapper(self, *args, **kwargs): self.close = original_impl return self.close(*args, **kwargs) - loop.close = types.MethodType(_wrapper, loop) + loop.close = types.MethodType(_wrapper, loop) # type: ignore[method-assign] -async def _close_redis(connection: "Redis"): +async def _close_redis(connection: "Redis") -> None: """ Handle compatibility with redis-py 4.x and 5.x close methods """ @@ -61,8 +61,8 @@ async def _close_redis(connection: "Redis"): def decode_hosts( - hosts: typing.Union[typing.Iterable, str, bytes, None], -) -> typing.List[typing.Dict]: + hosts: typing.Union[typing.Iterable[typing.Any], str, bytes, None], +) -> typing.List[typing.Dict[typing.Any, typing.Any]]: """ Takes the value of the "hosts" argument and returns a list of kwargs to use for the Redis connection constructor. @@ -77,7 +77,7 @@ def decode_hosts( ) # Decode each hosts entry into a kwargs dict - result: typing.List[typing.Dict] = [] + result: typing.List[typing.Dict[typing.Any, typing.Any]] = [] for entry in hosts: if isinstance(entry, dict): result.append(entry) diff --git a/conftest.py b/conftest.py deleted file mode 100644 index bae5ac4..0000000 --- a/conftest.py +++ /dev/null @@ -1,10 +0,0 @@ -import asyncio -import pytest - - -@pytest.fixture -def event_loop(request): - """Create an instance of the default event loop for each test case.""" - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() diff --git a/setup.cfg b/setup.cfg index 3888deb..b5a4092 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,3 +12,10 @@ max-line-length = 119 [isort] profile = black known_first_party = channels, asgiref, channels_redis, daphne + +[mypy] +warn_unused_ignores = True +strict = True + +[mypy-tests.*] +ignore_errors = True diff --git a/tests/test_core.py b/tests/test_core.py index 92f6923..e5bda1c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -419,6 +419,12 @@ def test_repeated_group_send_with_async_to_sync(channel_layer): pytest.fail(f"repeated async_to_sync wrapped group_send calls raised {exc}") +@pytest.mark.xfail( + reason=""" +Fails with error in redis-py: int() argument must be a string, a bytes-like +object or a real number, not 'NoneType'. 
Refs: #348 +""" +) @pytest.mark.asyncio async def test_receive_cancel(channel_layer): """ @@ -551,6 +557,7 @@ async def test_message_expiry__group_send(channel_layer): await channel_layer.receive(channel_name) +@pytest.mark.xfail(reason="Fails with timeout. Refs: #348") @pytest.mark.asyncio async def test_message_expiry__group_send__one_channel_expires_message(channel_layer): expiry = 3 diff --git a/tox.ini b/tox.ini index 7417992..c855db0 100644 --- a/tox.ini +++ b/tox.ini @@ -9,7 +9,9 @@ usedevelop = true extras = tests commands = pytest -v {posargs} + mypy --config-file=tox.ini channels_redis deps = + mypy>=1.9.0 ch30: channels>=3.0,<3.1 ch40: channels>=4.0,<4.1 chmain: https://github.com/django/channels/archive/main.tar.gz @@ -27,3 +29,7 @@ commands = flake8 channels_redis tests black --check channels_redis tests isort --check-only --diff channels_redis tests + + +[mypy] +disallow_untyped_defs = False From 0052136ff6097ddede1cb8a373a2aec939008676 Mon Sep 17 00:00:00 2001 From: iwakitakuma Date: Sat, 26 Jul 2025 23:36:56 +0900 Subject: [PATCH 5/5] CHORE: mypy deps --- tox.ini | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 83fe977..c3632a3 100644 --- a/tox.ini +++ b/tox.ini @@ -17,7 +17,8 @@ deps = redis5: redis>=5.0,<6 redis6: redis>=6.0,<7 redismain: https://github.com/redis/redis-py/archive/master.tar.gz - + mypy>=1.9.0 + [testenv:qa] skip_install=true deps =
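
A few notes on the typing work in this series, with small runnable sketches.

Sharding. The `typing.Union[str, "Buffer"]` annotation on `consistent_hash` and `_consistent_hash` is what both `_get_shard` and `_map_channel_keys_to_connection` lean on to spread channel and group names across hosts. The sketch below is a rough reconstruction of the CRC32 ring scheme in `channels_redis/utils.py`; treat the constants as illustrative rather than normative:

    import binascii
    import typing
    from collections import defaultdict


    def consistent_hash(value: typing.Union[str, bytes], ring_size: int) -> int:
        """Map a channel or group name onto a shard index in [0, ring_size)."""
        if ring_size == 1:
            # A single host always gets everything; skip the arithmetic.
            return 0
        if isinstance(value, str):
            value = value.encode("utf8")
        bigval = binascii.crc32(value) & 0xFFF
        ring_divisor = 4096 / float(ring_size)
        return int(bigval / ring_divisor)


    # Bucketing names per shard, the way group_send buckets redis keys
    # per connection index before pipelining:
    buckets: typing.DefaultDict[int, typing.List[str]] = defaultdict(list)
    for name in ("asgi__group__chat", "specific.abc!def", "broadcast"):
        buckets[consistent_hash(name, 2)].append(name)
    print(dict(buckets))

Identical names always land on the same shard, which is why the bucketing in `_map_channel_keys_to_connection` is stable across processes.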
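
Per-loop caching. The `typing.Dict[asyncio.AbstractEventLoop, RedisPubSubLoopLayer]` annotation on `_layers` and the `-> None` on `_wrap_close` describe the proxy pattern both layers share: one loop-bound layer per running event loop, with `loop.close` monkey-patched so the cached entry is evicted when its loop goes away. A toy version of that lifecycle (all names here are illustrative, not the library's):

    import asyncio
    import types
    import typing


    class PerLoopCache:
        def __init__(self) -> None:
            self._layers: typing.Dict[asyncio.AbstractEventLoop, typing.Any] = {}

        def get_layer(self) -> typing.Any:
            loop = asyncio.get_running_loop()
            if loop not in self._layers:
                self._layers[loop] = object()  # stand-in for a loop-bound layer
                self._wrap_close(loop)
            return self._layers[loop]

        def _wrap_close(self, loop: asyncio.AbstractEventLoop) -> None:
            original_close = loop.close
            layers = self._layers

            def _wrapper(self: asyncio.AbstractEventLoop, *args: typing.Any) -> typing.Any:
                # Evict the entry cached for this loop before really closing it.
                layers.pop(self, None)
                return original_close(*args)

            loop.close = types.MethodType(_wrapper, loop)  # type: ignore[method-assign]


    cache = PerLoopCache()


    async def main() -> None:
        assert cache.get_layer() is cache.get_layer()  # stable within one loop


    asyncio.run(main())
    assert not cache._layers  # the close hook evicted the loop's entry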
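
Pub/sub groups. The `typing.Union[str, bytes]` channel parameters threaded through `RedisPubSubLoopLayer` trace how names flow from the Channels API down to raw subscriptions; `group_send` in particular is nothing more than a publish to the derived `{prefix}__group__{name}` channel. A consumer-side sketch of that flow, with a hypothetical group name and a Redis assumed reachable on localhost:

    import asyncio

    from channels_redis.pubsub import RedisPubSubChannelLayer

    layer = RedisPubSubChannelLayer(hosts=[("localhost", 6379)])


    async def chat_demo() -> None:
        # new_channel() allocates a process-local specific channel and
        # subscribes the owning shard to it.
        channel = await layer.new_channel()
        await layer.group_add("chat", channel)
        await layer.group_send("chat", {"type": "chat.message", "text": "hi"})
        print(await layer.receive(channel))  # the message published to the group
        await layer.group_discard("chat", channel)


    asyncio.run(chat_demo())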
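
Encryption. The `_setup_encryption` rewrite inverts the condition rather than changing behavior: a bare `str`/`bytes` key list is still rejected and `cryptography` is still required, but the positive branch now lets mypy narrow `_MultiFernet` from `Optional[...]` before it is called. For context, a self-contained sketch of the key derivation and rotation the serializer builds on; the SHA-256 step mirrors what `make_fernet` does with arbitrary passphrases, and it needs the optional `cryptography` package:

    import base64
    import hashlib

    from cryptography.fernet import Fernet, MultiFernet


    def make_fernet(key: str) -> Fernet:
        # Fernet requires a urlsafe-base64 32-byte key; hashing first lets
        # callers pass arbitrary passphrases, as BaseMessageSerializer does.
        digest = hashlib.sha256(key.encode("utf8")).digest()
        return Fernet(base64.urlsafe_b64encode(digest))


    # MultiFernet encrypts with the first key and decrypts with any of them.
    crypter = MultiFernet([make_fernet("new-secret"), make_fernet("old-secret")])
    token = crypter.encrypt(b'{"type": "chat.message"}')
    assert crypter.decrypt(token) == b'{"type": "chat.message"}'

Because decryption tries every key, passing `symmetric_encryption_keys=["new", "old"]` rotates keys without breaking in-flight messages.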
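
Message framing. The `serialize`/`as_bytes` cleanup renames the local to `msg` purely so mypy stops seeing one name as both `typing.Any` and `bytes`; the wire framing is untouched: an optional run of random bytes goes in front of the (possibly encrypted) payload, and `deserialize` slices it back off. A minimal round-trip, with the prefix length as an assumed example value rather than the library default:

    import random

    PREFIX_LEN = 8  # illustrative; the real default lives on BaseMessageSerializer


    def add_random_prefix(payload: bytes, n: int = PREFIX_LEN) -> bytes:
        # The prefix carries no data; it only varies the ciphertext layout.
        prefix = random.getrandbits(8 * n).to_bytes(n, "big")
        return prefix + payload


    def strip_random_prefix(message: bytes, n: int = PREFIX_LEN) -> bytes:
        return message[n:]


    assert strip_random_prefix(add_random_prefix(b"payload")) == b"payload"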
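
Registry. The tighter `SerializersRegistry` annotations (`typing.Type[BaseMessageSerializer]` values keyed by an opaque `format`) are easiest to sanity-check from the consumer side. A sketch assuming the module-level `registry` instance in `channels_redis.serializers` and a hypothetical format name:

    import json
    import typing

    from channels_redis.serializers import BaseMessageSerializer, registry


    class CompactJSONSerializer(BaseMessageSerializer):
        # as_bytes/from_bytes are the two abstract hooks the registry expects.
        def as_bytes(self, message: typing.Any, *args, **kwargs) -> bytes:
            return json.dumps(message, separators=(",", ":")).encode("utf-8")

        def from_bytes(self, message: bytes, *args, **kwargs) -> typing.Any:
            return json.loads(message)


    registry.register_serializer("compact-json", CompactJSONSerializer)
    serializer = registry.get_serializer("compact-json")
    assert serializer.deserialize(serializer.serialize({"pk": 1})) == {"pk": 1}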
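
Typed distribution. The `py.typed` change truncates the marker to the empty file PEP 561 expects, and the new `[mypy]` sections (strict for the package, `ignore_errors` for tests) wire the checker into tox alongside pytest. The payoff is downstream: consumers type-check against the real signatures instead of `Any`. A small typed-usage sketch, assuming a Redis reachable on localhost:6379:

    import typing

    from channels_redis.core import RedisChannelLayer

    layer = RedisChannelLayer(hosts=[("localhost", 6379)])


    async def ping(channel: str) -> typing.Any:
        # With the py.typed marker shipped, mypy checks these calls against
        # the annotated send/receive signatures rather than treating the
        # package as untyped.
        await layer.send(channel, {"type": "ping"})
        return await layer.receive(channel)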