Skip to content

Commit

Permalink
InMemoryChannelLayer improvements, test fixes (#1976)
Browse files Browse the repository at this point in the history
  • Loading branch information
devkral authored Jul 30, 2024
1 parent e39fe13 commit e533186
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 31 deletions.
63 changes: 35 additions & 28 deletions channels/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,13 @@ def __init__(
group_expiry=86400,
capacity=100,
channel_capacity=None,
**kwargs
**kwargs,
):
super().__init__(
expiry=expiry,
capacity=capacity,
channel_capacity=channel_capacity,
**kwargs
**kwargs,
)
self.channels = {}
self.groups = {}
Expand All @@ -225,13 +225,14 @@ async def send(self, channel, message):
# name in message
assert "__asgi_channel__" not in message

queue = self.channels.setdefault(channel, asyncio.Queue())
# Are we full
if queue.qsize() >= self.capacity:
raise ChannelFull(channel)

queue = self.channels.setdefault(
channel, asyncio.Queue(maxsize=self.get_capacity(channel))
)
# Add message
await queue.put((time.time() + self.expiry, deepcopy(message)))
try:
queue.put_nowait((time.time() + self.expiry, deepcopy(message)))
except asyncio.queues.QueueFull:
raise ChannelFull(channel)

async def receive(self, channel):
"""
Expand All @@ -242,14 +243,16 @@ async def receive(self, channel):
assert self.valid_channel_name(channel)
self._clean_expired()

queue = self.channels.setdefault(channel, asyncio.Queue())
queue = self.channels.setdefault(
channel, asyncio.Queue(maxsize=self.get_capacity(channel))
)

# Do a plain direct receive
try:
_, message = await queue.get()
finally:
if queue.empty():
del self.channels[channel]
self.channels.pop(channel, None)

return message

Expand Down Expand Up @@ -279,19 +282,17 @@ def _clean_expired(self):
self._remove_from_groups(channel)
# Is the channel now empty and needs deleting?
if queue.empty():
del self.channels[channel]
self.channels.pop(channel, None)

# Group Expiration
timeout = int(time.time()) - self.group_expiry
for group in self.groups:
for channel in list(self.groups.get(group, set())):
# If join time is older than group_expiry end the group membership
if (
self.groups[group][channel]
and int(self.groups[group][channel]) < timeout
):
for channels in self.groups.values():
for name, timestamp in list(channels.items()):
# If join time is older than group_expiry
# end the group membership
if timestamp and timestamp < timeout:
# Delete from group
del self.groups[group][channel]
channels.pop(name, None)

# Flush extension

Expand All @@ -308,8 +309,7 @@ def _remove_from_groups(self, channel):
Removes a channel from all groups. Used when a message on it expires.
"""
for channels in self.groups.values():
if channel in channels:
del channels[channel]
channels.pop(channel, None)

# Groups extension

Expand All @@ -329,22 +329,29 @@ async def group_discard(self, group, channel):
assert self.valid_channel_name(channel), "Invalid channel name"
assert self.valid_group_name(group), "Invalid group name"
# Remove from group set
if group in self.groups:
if channel in self.groups[group]:
del self.groups[group][channel]
if not self.groups[group]:
del self.groups[group]
group_channels = self.groups.get(group, None)
if group_channels:
# remove channel if in group
group_channels.pop(channel, None)
# is group now empty? If yes remove it
if not group_channels:
self.groups.pop(group, None)

async def group_send(self, group, message):
# Check types
assert isinstance(message, dict), "Message is not a dict"
assert self.valid_group_name(group), "Invalid group name"
# Run clean
self._clean_expired()

# Send to each channel
for channel in self.groups.get(group, set()):
ops = []
if group in self.groups:
for channel in self.groups[group].keys():
ops.append(asyncio.create_task(self.send(channel, message)))
for send_result in asyncio.as_completed(ops):
try:
await self.send(channel, message)
await send_result
except ChannelFull:
pass

Expand Down
30 changes: 27 additions & 3 deletions tests/test_inmemorychannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,36 @@ async def test_send_receive(channel_layer):
await channel_layer.send(
"test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"}
)
await channel_layer.send(
"test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"}
)
message = await channel_layer.receive("test-channel-1")
assert message["type"] == "test.message"
assert message["text"] == "Ahoy-hoy!"
# not removed because not empty
assert "test-channel-1" in channel_layer.channels
message = await channel_layer.receive("test-channel-1")
assert message["type"] == "test.message"
assert message["text"] == "Ahoy-hoy!"
# removed because empty
assert "test-channel-1" not in channel_layer.channels


@pytest.mark.asyncio
async def test_race_empty(channel_layer):
"""
Makes sure the race is handled gracefully.
"""
receive_task = asyncio.create_task(channel_layer.receive("test-channel-1"))
await asyncio.sleep(0.1)
await channel_layer.send(
"test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"}
)
del channel_layer.channels["test-channel-1"]
await asyncio.sleep(0.1)
message = await receive_task
assert message["type"] == "test.message"
assert message["text"] == "Ahoy-hoy!"


@pytest.mark.asyncio
Expand Down Expand Up @@ -62,7 +89,6 @@ async def test_multi_send_receive(channel_layer):
"""
Tests overlapping sends and receives, and ordering.
"""
channel_layer = InMemoryChannelLayer()
await channel_layer.send("test-channel-3", {"type": "message.1"})
await channel_layer.send("test-channel-3", {"type": "message.2"})
await channel_layer.send("test-channel-3", {"type": "message.3"})
Expand All @@ -76,7 +102,6 @@ async def test_groups_basic(channel_layer):
"""
Tests basic group operation.
"""
channel_layer = InMemoryChannelLayer()
await channel_layer.group_add("test-group", "test-gr-chan-1")
await channel_layer.group_add("test-group", "test-gr-chan-2")
await channel_layer.group_add("test-group", "test-gr-chan-3")
Expand All @@ -97,7 +122,6 @@ async def test_groups_channel_full(channel_layer):
"""
Tests that group_send ignores ChannelFull
"""
channel_layer = InMemoryChannelLayer()
await channel_layer.group_add("test-group", "test-gr-chan-1")
await channel_layer.group_send("test-group", {"type": "message.1"})
await channel_layer.group_send("test-group", {"type": "message.1"})
Expand Down

0 comments on commit e533186

Please sign in to comment.