From 90411459a6099158ddbbc1a9f26951364bdfecac Mon Sep 17 00:00:00 2001 From: John Belmonte Date: Mon, 26 Oct 2020 11:11:21 +0900 Subject: [PATCH 1/3] open_memory_channel(): return a named tuple partially addresses #719 --- docs/source/reference-core.rst | 6 ++++++ newsfragments/1771.feature.rst | 5 +++++ trio/_channel.py | 17 ++++++++++++----- trio/tests/test_channel.py | 5 +++++ trio/tests/test_highlevel_serve_listeners.py | 4 ++-- 5 files changed, 30 insertions(+), 7 deletions(-) create mode 100644 newsfragments/1771.feature.rst diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index 664a8d96c2..b67bf52e78 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -1124,6 +1124,12 @@ inside a single process, and for that you can use .. autofunction:: open_memory_channel(max_buffer_size) +Assigning the send and receive channels to separate variables usually +produces the most readable code. However, in situations where the pair +is preserved-- such as a collection of memory channels-- prefer named tuple +access (``pair.send_channel``, ``pair.receive_channel``) over indexed access +(``pair[0]``, ``pair[1]``). + .. note:: If you've used the :mod:`threading` or :mod:`asyncio` modules, you may be familiar with :class:`queue.Queue` or :class:`asyncio.Queue`. In Trio, :func:`open_memory_channel` is diff --git a/newsfragments/1771.feature.rst b/newsfragments/1771.feature.rst new file mode 100644 index 0000000000..ec4b528aa0 --- /dev/null +++ b/newsfragments/1771.feature.rst @@ -0,0 +1,5 @@ +open_memory_channel() now returns a named tuple with attributes ``send_channel`` +and ``receive_channel``. This can be used to avoid indexed access of the +channel halves in some scenarios such as a collection of channels. (Note: when +dealing with a single memory channel, assigning the send and receive halves +to separate variables via destructuring is still considered more readable.) diff --git a/trio/_channel.py b/trio/_channel.py index dac7935c0c..417e530024 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -1,16 +1,24 @@ from collections import deque, OrderedDict from math import inf +from typing import NamedTuple import attr from outcome import Error, Value -from .abc import SendChannel, ReceiveChannel, Channel +from .abc import SendChannel, ReceiveChannel from ._util import generic_function, NoPublicConstructor import trio from ._core import enable_ki_protection +class MemoryChannelPair(NamedTuple): + """Named tuple of send/receive memory channels""" + + send_channel: "MemorySendChannel" + receive_channel: "MemoryReceiveChannel" + + @generic_function def open_memory_channel(max_buffer_size): """Open a channel for passing objects between tasks within a process. @@ -40,9 +48,8 @@ def open_memory_channel(max_buffer_size): see :ref:`channel-buffering` for more details. If in doubt, use 0. Returns: - A pair ``(send_channel, receive_channel)``. If you have - trouble remembering which order these go in, remember: data - flows from left → right. + A named tuple ``(send_channel, receive_channel)``. The tuple ordering is + intended to match the image of data flowing from left → right. In addition to the standard channel methods, all memory channel objects provide a ``statistics()`` method, which returns an object with the @@ -69,7 +76,7 @@ def open_memory_channel(max_buffer_size): if max_buffer_size < 0: raise ValueError("max_buffer_size must be >= 0") state = MemoryChannelState(max_buffer_size) - return ( + return MemoryChannelPair( MemorySendChannel._create(state), MemoryReceiveChannel._create(state), ) diff --git a/trio/tests/test_channel.py b/trio/tests/test_channel.py index b43466dd7d..83fd746bdc 100644 --- a/trio/tests/test_channel.py +++ b/trio/tests/test_channel.py @@ -350,3 +350,8 @@ async def do_send(s, v): assert await r.receive() == 1 with pytest.raises(trio.WouldBlock): r.receive_nowait() + + +def test_named_tuple(): + pair = open_memory_channel(0) + assert pair.send_channel, pair.receive_channel == pair diff --git a/trio/tests/test_highlevel_serve_listeners.py b/trio/tests/test_highlevel_serve_listeners.py index b028092eb9..7925a16ff4 100644 --- a/trio/tests/test_highlevel_serve_listeners.py +++ b/trio/tests/test_highlevel_serve_listeners.py @@ -19,7 +19,7 @@ class MemoryListener(trio.abc.Listener): async def connect(self): assert not self.closed client, server = memory_stream_pair() - await self.queued_streams[0].send(server) + await self.queued_streams.send_channel.send(server) return client async def accept(self): @@ -27,7 +27,7 @@ async def accept(self): assert not self.closed if self.accept_hook is not None: await self.accept_hook() - stream = await self.queued_streams[1].receive() + stream = await self.queued_streams.receive_channel.receive() self.accepted_streams.append(stream) return stream From 44a06800c2f4da3253da410549729401dc3b9aa2 Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Fri, 8 Sep 2023 11:52:23 +1000 Subject: [PATCH 2/3] Manually define MemoryChannelPair --- trio/_channel.py | 91 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 82 insertions(+), 9 deletions(-) diff --git a/trio/_channel.py b/trio/_channel.py index 2fb2b9701d..67dc731cb5 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -1,10 +1,12 @@ from __future__ import annotations from collections import OrderedDict, deque +from collections.abc import Iterable from math import inf +from operator import itemgetter from types import TracebackType from typing import Tuple # only needed for typechecking on <3.9 -from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar import attr from outcome import Error, Value @@ -22,13 +24,6 @@ T = TypeVar("T") -class MemoryChannelPair(NamedTuple, Generic[T]): - """Named tuple of send/receive memory channels""" - - send_channel: MemorySendChannel[T] - receive_channel: MemoryReceiveChannel[T] - - def _open_memory_channel( max_buffer_size: int | float, ) -> MemoryChannelPair[T]: @@ -99,7 +94,7 @@ def _open_memory_channel( if TYPE_CHECKING: # written as a class so that you can say open_memory_channel[int](5) # Need to use Tuple instead of tuple due to CI check running on 3.8 - class open_memory_channel(Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]): + class open_memory_channel(MemoryChannelPair[T]): def __new__( # type: ignore[misc] # "must return a subtype" cls, max_buffer_size: int | float ) -> MemoryChannelPair[T]: @@ -444,3 +439,81 @@ def close(self) -> None: async def aclose(self) -> None: self.close() await trio.lowlevel.checkpoint() + + +# We cannot use generic named tuples before Py 3.11, manually define it. +class MemoryChannelPair( + Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]], + Generic[T], +): + """Named tuple of send/receive memory channels.""" + + __slots__ = () + _fields = ("send_channel", "receive_channel") + + if TYPE_CHECKING: + + @property + def send_channel(self) -> MemorySendChannel[T]: + """Returns the sending channel half.""" + return self[0] + + @property + def receive_channel(self) -> MemoryReceiveChannel[T]: + """Returns the receiving channel half.""" + return self[1] + + else: # More efficient + send_channel = property(itemgetter(0), doc="Returns the sending channel half.") + receive_channel = property( + itemgetter(1), doc="Returns the receiving channel half." + ) + + def __new__( + cls, + send_channel: MemorySendChannel[T], + receive_channel: MemoryReceiveChannel[T], + ) -> Self: + """Create new instance of MemoryChannelPair(send_channel, receive_channel)""" + return tuple.__new__(cls, (send_channel, receive_channel)) # type: ignore[type-var] + + @classmethod + def _make( + cls, + iterable: Iterable[MemorySendChannel[T] | MemoryReceiveChannel[T]], + ) -> Self: + """Make a new MemoryChannelPair object from a sequence or iterable""" + send, rec = iterable + if isinstance(send, MemoryReceiveChannel) or isinstance(rec, MemorySendChannel): + raise TypeError("Channel order passed incorrectly.") + return tuple.__new__(cls, (send, rec)) # type: ignore[type-var] + + def _replace( + self, + *, + send_channel: MemorySendChannel[T] | None = None, + receive_channel: MemoryReceiveChannel[T] | None = None, + ) -> MemoryChannelPair[T]: + """Return a new MemoryChannelPair object replacing specified fields with new values""" + if send_channel is None: + send_channel = self.send_channel + if receive_channel is None: + receive_channel = self.receive_channel + return tuple.__new__( + MemoryChannelPair, + (send_channel, receive_channel), + ) # type: ignore[type-var] + + def __repr__(self) -> str: + """Return a nicely formatted representation string""" + return f"{self.__class__.__name__}(send_channel={self[0]!r}, receive_channel={self[1]!r})" + + def _asdict( + self, + ) -> OrderedDict[str, MemorySendChannel[T] | MemoryReceiveChannel[T]]: + """Return a new OrderedDict which maps field names to their values.""" + return OrderedDict(zip(self._fields, self)) + + def __getnewargs__(self) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: + """Return self as a plain tuple. Used by copy and pickle.""" + return (self[0], self[1]) From bd3d1b36d3add978faa4dfe194340c71d8b71a01 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 1 Aug 2024 00:23:34 -0500 Subject: [PATCH 3/3] Move `open_memory_channel` definition to after `MemoryChannelPair` definition --- src/trio/_channel.py | 47 ++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/src/trio/_channel.py b/src/trio/_channel.py index 3d8445bc59..1f328daea0 100644 --- a/src/trio/_channel.py +++ b/src/trio/_channel.py @@ -1,15 +1,13 @@ from __future__ import annotations from collections import OrderedDict, deque -from collections.abc import Iterable from math import inf from operator import itemgetter - from typing import ( TYPE_CHECKING, Generic, - TypeVar, Tuple, # only needed for typechecking on <3.9 + TypeVar, ) import attrs @@ -22,6 +20,7 @@ from ._util import NoPublicConstructor, final, generic_function if TYPE_CHECKING: + from collections.abc import Iterable from types import TracebackType from typing_extensions import Self @@ -94,27 +93,6 @@ def _open_memory_channel( ) -# This workaround requires python3.9+, once older python versions are not supported -# or there's a better way of achieving type-checking on a generic factory function, -# it could replace the normal function header -if TYPE_CHECKING: - # written as a class so that you can say open_memory_channel[int](5) - # Need to use Tuple instead of tuple due to CI check running on 3.8 - class open_memory_channel(MemoryChannelPair[T]): - def __new__( # type: ignore[misc] # "must return a subtype" - cls, max_buffer_size: int | float # noqa: PYI041 - ) -> MemoryChannelPair[T]: - return _open_memory_channel(max_buffer_size) - - def __init__(self, max_buffer_size: int | float): # noqa: PYI041 - ... - -else: - # apply the generic_function decorator to make open_memory_channel indexable - # so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime - open_memory_channel = generic_function(_open_memory_channel) - - @attrs.frozen class MemoryChannelStats: current_buffer_used: int @@ -522,3 +500,24 @@ def _asdict( def __getnewargs__(self) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: """Return self as a plain tuple. Used by copy and pickle.""" return (self[0], self[1]) + + +# This workaround requires python3.9+, once older python versions are not supported +# or there's a better way of achieving type-checking on a generic factory function, +# it could replace the normal function header +if TYPE_CHECKING: + # written as a class so that you can say open_memory_channel[int](5) + # Need to use Tuple instead of tuple due to CI check running on 3.8 + class open_memory_channel(MemoryChannelPair[T]): + def __new__( # type: ignore[misc] # "must return a subtype" + cls, max_buffer_size: int | float # noqa: PYI041 + ) -> MemoryChannelPair[T]: + return _open_memory_channel(max_buffer_size) + + def __init__(self, max_buffer_size: int | float): # noqa: PYI041 + ... + +else: + # apply the generic_function decorator to make open_memory_channel indexable + # so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime + open_memory_channel = generic_function(_open_memory_channel)