asm
This commit is contained in:
264
asm/venv/lib/python3.11/site-packages/zmq/tests/__init__.py
Normal file
264
asm/venv/lib/python3.11/site-packages/zmq/tests/__init__.py
Normal file
@@ -0,0 +1,264 @@
|
||||
# Copyright (c) PyZMQ Developers.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import os
|
||||
import platform
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from functools import partial
|
||||
from threading import Thread
|
||||
from typing import List
|
||||
from unittest import SkipTest, TestCase
|
||||
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
from zmq.utils import jsonapi
|
||||
|
||||
try:
|
||||
import gevent
|
||||
|
||||
from zmq import green as gzmq
|
||||
|
||||
have_gevent = True
|
||||
except ImportError:
|
||||
have_gevent = False
|
||||
|
||||
|
||||
PYPY = platform.python_implementation() == 'PyPy'
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# skip decorators (directly from unittest)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
_id = lambda x: x
|
||||
|
||||
skip_pypy = mark.skipif(PYPY, reason="Doesn't work on PyPy")
|
||||
require_zmq_4 = mark.skipif(zmq.zmq_version_info() < (4,), reason="requires zmq >= 4")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Base test class
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def term_context(ctx, timeout):
|
||||
"""Terminate a context with a timeout"""
|
||||
t = Thread(target=ctx.term)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
t.join(timeout=timeout)
|
||||
if t.is_alive():
|
||||
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
|
||||
zmq.sugar.context.Context._instance = None
|
||||
raise RuntimeError(
|
||||
"context could not terminate, open sockets likely remain in test"
|
||||
)
|
||||
|
||||
|
||||
class BaseZMQTestCase(TestCase):
|
||||
green = False
|
||||
teardown_timeout = 10
|
||||
test_timeout_seconds = int(os.environ.get("ZMQ_TEST_TIMEOUT") or 60)
|
||||
sockets: List[zmq.Socket]
|
||||
|
||||
@property
|
||||
def _is_pyzmq_test(self):
|
||||
return self.__class__.__module__.split(".", 1)[0] == __name__.split(".", 1)[0]
|
||||
|
||||
@property
|
||||
def _should_test_timeout(self):
|
||||
return (
|
||||
self._is_pyzmq_test
|
||||
and hasattr(signal, 'SIGALRM')
|
||||
and self.test_timeout_seconds
|
||||
)
|
||||
|
||||
@property
|
||||
def Context(self):
|
||||
if self.green:
|
||||
return gzmq.Context
|
||||
else:
|
||||
return zmq.Context
|
||||
|
||||
def socket(self, socket_type):
|
||||
s = self.context.socket(socket_type)
|
||||
self.sockets.append(s)
|
||||
return s
|
||||
|
||||
def _alarm_timeout(self, timeout, *args):
|
||||
raise TimeoutError(f"Test did not complete in {timeout} seconds")
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if not self._is_pyzmq_test:
|
||||
warnings.warn(
|
||||
"zmq.tests.BaseZMQTestCase is deprecated in pyzmq 25, we recommend managing your own contexts and sockets.",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
if self.green and not have_gevent:
|
||||
raise SkipTest("requires gevent")
|
||||
|
||||
self.context = self.Context.instance()
|
||||
self.sockets = []
|
||||
if self._should_test_timeout:
|
||||
# use SIGALRM to avoid test hangs
|
||||
signal.signal(
|
||||
signal.SIGALRM, partial(self._alarm_timeout, self.test_timeout_seconds)
|
||||
)
|
||||
signal.alarm(self.test_timeout_seconds)
|
||||
|
||||
def tearDown(self):
|
||||
if self._should_test_timeout:
|
||||
# cancel the timeout alarm, if there was one
|
||||
signal.alarm(0)
|
||||
contexts = {self.context}
|
||||
while self.sockets:
|
||||
sock = self.sockets.pop()
|
||||
contexts.add(sock.context) # in case additional contexts are created
|
||||
sock.close(0)
|
||||
for ctx in contexts:
|
||||
try:
|
||||
term_context(ctx, self.teardown_timeout)
|
||||
except Exception:
|
||||
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
|
||||
zmq.sugar.context.Context._instance = None
|
||||
raise
|
||||
|
||||
super().tearDown()
|
||||
|
||||
def create_bound_pair(
|
||||
self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'
|
||||
):
|
||||
"""Create a bound socket pair using a random port."""
|
||||
s1 = self.context.socket(type1)
|
||||
s1.setsockopt(zmq.LINGER, 0)
|
||||
port = s1.bind_to_random_port(interface)
|
||||
s2 = self.context.socket(type2)
|
||||
s2.setsockopt(zmq.LINGER, 0)
|
||||
s2.connect(f'{interface}:{port}')
|
||||
self.sockets.extend([s1, s2])
|
||||
return s1, s2
|
||||
|
||||
def ping_pong(self, s1, s2, msg):
|
||||
s1.send(msg)
|
||||
msg2 = s2.recv()
|
||||
s2.send(msg2)
|
||||
msg3 = s1.recv()
|
||||
return msg3
|
||||
|
||||
def ping_pong_json(self, s1, s2, o):
|
||||
if jsonapi.jsonmod is None:
|
||||
raise SkipTest("No json library")
|
||||
s1.send_json(o)
|
||||
o2 = s2.recv_json()
|
||||
s2.send_json(o2)
|
||||
o3 = s1.recv_json()
|
||||
return o3
|
||||
|
||||
def ping_pong_pyobj(self, s1, s2, o):
|
||||
s1.send_pyobj(o)
|
||||
o2 = s2.recv_pyobj()
|
||||
s2.send_pyobj(o2)
|
||||
o3 = s1.recv_pyobj()
|
||||
return o3
|
||||
|
||||
def assertRaisesErrno(self, errno, func, *args, **kwargs):
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except zmq.ZMQError as e:
|
||||
self.assertEqual(
|
||||
e.errno,
|
||||
errno,
|
||||
"wrong error raised, expected '%s' \
|
||||
got '%s'"
|
||||
% (zmq.ZMQError(errno), zmq.ZMQError(e.errno)),
|
||||
)
|
||||
else:
|
||||
self.fail("Function did not raise any error")
|
||||
|
||||
def _select_recv(self, multipart, socket, **kwargs):
|
||||
"""call recv[_multipart] in a way that raises if there is nothing to receive"""
|
||||
if zmq.zmq_version_info() >= (3, 1, 0):
|
||||
# zmq 3.1 has a bug, where poll can return false positives,
|
||||
# so we wait a little bit just in case
|
||||
# See LIBZMQ-280 on JIRA
|
||||
time.sleep(0.1)
|
||||
|
||||
r, w, x = zmq.select([socket], [], [], timeout=kwargs.pop('timeout', 5))
|
||||
assert len(r) > 0, "Should have received a message"
|
||||
kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0)
|
||||
|
||||
recv = socket.recv_multipart if multipart else socket.recv
|
||||
return recv(**kwargs)
|
||||
|
||||
def recv(self, socket, **kwargs):
|
||||
"""call recv in a way that raises if there is nothing to receive"""
|
||||
return self._select_recv(False, socket, **kwargs)
|
||||
|
||||
def recv_multipart(self, socket, **kwargs):
|
||||
"""call recv_multipart in a way that raises if there is nothing to receive"""
|
||||
return self._select_recv(True, socket, **kwargs)
|
||||
|
||||
|
||||
class PollZMQTestCase(BaseZMQTestCase):
|
||||
pass
|
||||
|
||||
|
||||
class GreenTest:
|
||||
"""Mixin for making green versions of test classes"""
|
||||
|
||||
green = True
|
||||
teardown_timeout = 10
|
||||
|
||||
def assertRaisesErrno(self, errno, func, *args, **kwargs):
|
||||
if errno == zmq.EAGAIN:
|
||||
raise SkipTest("Skipping because we're green.")
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except zmq.ZMQError:
|
||||
e = sys.exc_info()[1]
|
||||
self.assertEqual(
|
||||
e.errno,
|
||||
errno,
|
||||
"wrong error raised, expected '%s' \
|
||||
got '%s'"
|
||||
% (zmq.ZMQError(errno), zmq.ZMQError(e.errno)),
|
||||
)
|
||||
else:
|
||||
self.fail("Function did not raise any error")
|
||||
|
||||
def tearDown(self):
|
||||
if self._should_test_timeout:
|
||||
# cancel the timeout alarm, if there was one
|
||||
signal.alarm(0)
|
||||
contexts = {self.context}
|
||||
while self.sockets:
|
||||
sock = self.sockets.pop()
|
||||
contexts.add(sock.context) # in case additional contexts are created
|
||||
sock.close()
|
||||
try:
|
||||
gevent.joinall(
|
||||
[gevent.spawn(ctx.term) for ctx in contexts],
|
||||
timeout=self.teardown_timeout,
|
||||
raise_error=True,
|
||||
)
|
||||
except gevent.Timeout:
|
||||
raise RuntimeError(
|
||||
"context could not terminate, open sockets likely remain in test"
|
||||
)
|
||||
|
||||
def skip_green(self):
|
||||
raise SkipTest("Skipping because we are green")
|
||||
|
||||
|
||||
def skip_green(f):
|
||||
def skipping_test(self, *args, **kwargs):
|
||||
if self.green:
|
||||
raise SkipTest("Skipping because we are green")
|
||||
else:
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
return skipping_test
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
215
asm/venv/lib/python3.11/site-packages/zmq/tests/conftest.py
Normal file
215
asm/venv/lib/python3.11/site-packages/zmq/tests/conftest.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""pytest configuration and fixtures"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
from functools import partial
|
||||
from threading import Thread
|
||||
|
||||
try:
|
||||
import tornado
|
||||
from tornado import version_info
|
||||
except ImportError:
|
||||
tornado = None
|
||||
else:
|
||||
if version_info < (5,):
|
||||
tornado = None
|
||||
from tornado.ioloop import IOLoop
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
test_timeout_seconds = os.environ.get("ZMQ_TEST_TIMEOUT")
|
||||
teardown_timeout = 10
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(items):
|
||||
"""This function is automatically run by pytest passing all collected test
|
||||
functions.
|
||||
We use it to add asyncio marker to all async tests and assert we don't use
|
||||
test functions that are async generators which wouldn't make sense.
|
||||
It is no longer required with pytest-asyncio >= 0.17
|
||||
"""
|
||||
for item in items:
|
||||
if inspect.iscoroutinefunction(item.obj):
|
||||
item.add_marker('asyncio')
|
||||
assert not inspect.isasyncgenfunction(item.obj)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def io_loop(event_loop, request):
|
||||
"""Create tornado io_loop on current asyncio event loop"""
|
||||
if tornado is None:
|
||||
pytest.skip()
|
||||
io_loop = IOLoop.current()
|
||||
assert asyncio.get_event_loop() is event_loop
|
||||
assert io_loop.asyncio_loop is event_loop
|
||||
|
||||
def _close():
|
||||
io_loop.close(all_fds=True)
|
||||
|
||||
request.addfinalizer(_close)
|
||||
return io_loop
|
||||
|
||||
|
||||
def term_context(ctx, timeout):
|
||||
"""Terminate a context with a timeout"""
|
||||
t = Thread(target=ctx.term)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
t.join(timeout=timeout)
|
||||
if t.is_alive():
|
||||
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
|
||||
zmq.sugar.context.Context._instance = None
|
||||
raise RuntimeError(
|
||||
f"context {ctx} could not terminate, open sockets likely remain in test"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
# make sure selectors are cleared
|
||||
assert dict(zmq.asyncio._selectors) == {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sigalrm_timeout():
|
||||
"""Set timeout using SIGALRM
|
||||
|
||||
Avoids infinite hang in context.term for an unclean context,
|
||||
raising an error instead.
|
||||
"""
|
||||
if not hasattr(signal, "SIGALRM") or not test_timeout_seconds:
|
||||
return
|
||||
|
||||
def _alarm_timeout(*args):
|
||||
raise TimeoutError(f"Test did not complete in {test_timeout_seconds} seconds")
|
||||
|
||||
signal.signal(signal.SIGALRM, _alarm_timeout)
|
||||
signal.alarm(test_timeout_seconds)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Context():
|
||||
"""Context class fixture
|
||||
|
||||
Override in modules to specify a different class (e.g. zmq.green)
|
||||
"""
|
||||
return zmq.Context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def contexts(sigalrm_timeout):
|
||||
"""Fixture to track contexts used in tests
|
||||
|
||||
For cleanup purposes
|
||||
"""
|
||||
contexts = set()
|
||||
yield contexts
|
||||
for ctx in contexts:
|
||||
try:
|
||||
term_context(ctx, teardown_timeout)
|
||||
except Exception:
|
||||
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
|
||||
zmq.sugar.context.Context._instance = None
|
||||
raise
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def context(Context, contexts):
|
||||
"""Fixture for shared context"""
|
||||
ctx = Context()
|
||||
contexts.add(ctx)
|
||||
return ctx
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sockets(contexts):
|
||||
sockets = []
|
||||
yield sockets
|
||||
# ensure any tracked sockets get their contexts cleaned up
|
||||
for socket in sockets:
|
||||
contexts.add(socket.context)
|
||||
|
||||
# close sockets
|
||||
for socket in sockets:
|
||||
socket.close(linger=0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def socket(context, sockets):
|
||||
"""Fixture to create sockets, while tracking them for cleanup"""
|
||||
|
||||
def new_socket(*args, **kwargs):
|
||||
s = context.socket(*args, **kwargs)
|
||||
sockets.append(s)
|
||||
return s
|
||||
|
||||
return new_socket
|
||||
|
||||
|
||||
def assert_raises_errno(errno):
|
||||
try:
|
||||
yield
|
||||
except zmq.ZMQError as e:
|
||||
assert (
|
||||
e.errno == errno
|
||||
), f"wrong error raised, expected {zmq.ZMQError(errno)} got {zmq.ZMQError(e.errno)}"
|
||||
else:
|
||||
pytest.fail(f"Expected {zmq.ZMQError(errno)}, no error raised")
|
||||
|
||||
|
||||
def recv(socket, *, timeout=5, flags=0, multipart=False, **kwargs):
|
||||
"""call recv[_multipart] in a way that raises if there is nothing to receive"""
|
||||
if zmq.zmq_version_info() >= (3, 1, 0):
|
||||
# zmq 3.1 has a bug, where poll can return false positives,
|
||||
# so we wait a little bit just in case
|
||||
# See LIBZMQ-280 on JIRA
|
||||
time.sleep(0.1)
|
||||
|
||||
r, w, x = zmq.select([socket], [], [], timeout=timeout)
|
||||
assert r, "Should have received a message"
|
||||
kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0)
|
||||
|
||||
recv = socket.recv_multipart if multipart else socket.recv
|
||||
return recv(flags=flags, **kwargs)
|
||||
|
||||
|
||||
recv_multipart = partial(recv, multipart=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_bound_pair(socket):
|
||||
def create_bound_pair(type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'):
|
||||
"""Create a bound socket pair using a random port."""
|
||||
s1 = socket(type1)
|
||||
s1.linger = 0
|
||||
port = s1.bind_to_random_port(interface)
|
||||
s2 = socket(type2)
|
||||
s2.linger = 0
|
||||
s2.connect(f'{interface}:{port}')
|
||||
return s1, s2
|
||||
|
||||
return create_bound_pair
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bound_pair(create_bound_pair):
|
||||
return create_bound_pair()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def push_pull(create_bound_pair):
|
||||
return create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dealer_router(create_bound_pair):
|
||||
return create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
@@ -0,0 +1,23 @@
|
||||
from zmq cimport Context, Frame, Socket, libzmq
|
||||
|
||||
|
||||
cdef inline Frame c_send_recv(Socket a, Socket b, bytes to_send):
|
||||
cdef Frame msg = Frame(to_send)
|
||||
a.send(msg)
|
||||
cdef Frame recvd = b.recv(flags=0, copy=False)
|
||||
return recvd
|
||||
|
||||
|
||||
cpdef bytes send_recv_test(bytes to_send):
|
||||
cdef Context ctx = Context()
|
||||
cdef Socket a = Socket(ctx, libzmq.ZMQ_PUSH)
|
||||
cdef Socket b = Socket(ctx, libzmq.ZMQ_PULL)
|
||||
url = 'inproc://test'
|
||||
a.bind(url)
|
||||
b.connect(url)
|
||||
cdef Frame recvd_frame = c_send_recv(a, b, to_send)
|
||||
a.close()
|
||||
b.close()
|
||||
ctx.term()
|
||||
cdef bytes recvd_bytes = recvd_frame.bytes
|
||||
return recvd_bytes
|
||||
387
asm/venv/lib/python3.11/site-packages/zmq/tests/test_asyncio.py
Normal file
387
asm/venv/lib/python3.11/site-packages/zmq/tests/test_asyncio.py
Normal file
@@ -0,0 +1,387 @@
|
||||
"""Test asyncio support"""
|
||||
# Copyright (c) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from concurrent.futures import CancelledError
|
||||
from multiprocessing import Process
|
||||
|
||||
import pytest
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio as zaio
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Context(event_loop):
|
||||
return zaio.Context
|
||||
|
||||
|
||||
def test_socket_class(context):
|
||||
with context.socket(zmq.PUSH) as s:
|
||||
assert isinstance(s, zaio.Socket)
|
||||
|
||||
|
||||
def test_instance_subclass_first(context):
|
||||
actx = zmq.asyncio.Context.instance()
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.term()
|
||||
actx.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(actx) is zmq.asyncio.Context
|
||||
|
||||
|
||||
def test_instance_subclass_second(context):
|
||||
with zmq.Context.instance() as ctx:
|
||||
assert type(ctx) is zmq.Context
|
||||
with zmq.asyncio.Context.instance() as actx:
|
||||
assert type(actx) is zmq.asyncio.Context
|
||||
|
||||
|
||||
async def test_recv_multipart(context, create_bound_pair):
|
||||
a, b = create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_multipart()
|
||||
assert not f.done()
|
||||
await a.send(b"hi")
|
||||
recvd = await f
|
||||
assert recvd == [b"hi"]
|
||||
|
||||
|
||||
async def test_recv(create_bound_pair):
|
||||
a, b = create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f1 = b.recv()
|
||||
f2 = b.recv()
|
||||
assert not f1.done()
|
||||
assert not f2.done()
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
recvd = await f2
|
||||
assert f1.done()
|
||||
assert f1.result() == b"hi"
|
||||
assert recvd == b"there"
|
||||
|
||||
|
||||
@mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO")
|
||||
async def test_recv_timeout(push_pull):
|
||||
a, b = push_pull
|
||||
b.rcvtimeo = 100
|
||||
f1 = b.recv()
|
||||
b.rcvtimeo = 1000
|
||||
f2 = b.recv_multipart()
|
||||
with pytest.raises(zmq.Again):
|
||||
await f1
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
recvd = await f2
|
||||
assert f2.done()
|
||||
assert recvd == [b"hi", b"there"]
|
||||
|
||||
|
||||
@mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO")
|
||||
async def test_send_timeout(socket):
|
||||
s = socket(zmq.PUSH)
|
||||
s.sndtimeo = 100
|
||||
with pytest.raises(zmq.Again):
|
||||
await s.send(b"not going anywhere")
|
||||
|
||||
|
||||
async def test_recv_string(push_pull):
|
||||
a, b = push_pull
|
||||
f = b.recv_string()
|
||||
assert not f.done()
|
||||
msg = "πøøπ"
|
||||
await a.send_string(msg)
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
assert f.result() == msg
|
||||
assert recvd == msg
|
||||
|
||||
|
||||
async def test_recv_json(push_pull):
|
||||
a, b = push_pull
|
||||
f = b.recv_json()
|
||||
assert not f.done()
|
||||
obj = dict(a=5)
|
||||
await a.send_json(obj)
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
assert f.result() == obj
|
||||
assert recvd == obj
|
||||
|
||||
|
||||
async def test_recv_json_cancelled(push_pull):
|
||||
a, b = push_pull
|
||||
f = b.recv_json()
|
||||
assert not f.done()
|
||||
f.cancel()
|
||||
# cycle eventloop to allow cancel events to fire
|
||||
await asyncio.sleep(0)
|
||||
obj = dict(a=5)
|
||||
await a.send_json(obj)
|
||||
# CancelledError change in 3.8 https://bugs.python.org/issue32528
|
||||
if sys.version_info < (3, 8):
|
||||
with pytest.raises(CancelledError):
|
||||
recvd = await f
|
||||
else:
|
||||
with pytest.raises(asyncio.exceptions.CancelledError):
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
# give it a chance to incorrectly consume the event
|
||||
events = await b.poll(timeout=5)
|
||||
assert events
|
||||
await asyncio.sleep(0)
|
||||
# make sure cancelled recv didn't eat up event
|
||||
f = b.recv_json()
|
||||
recvd = await asyncio.wait_for(f, timeout=5)
|
||||
assert recvd == obj
|
||||
|
||||
|
||||
async def test_recv_pyobj(push_pull):
|
||||
a, b = push_pull
|
||||
f = b.recv_pyobj()
|
||||
assert not f.done()
|
||||
obj = dict(a=5)
|
||||
await a.send_pyobj(obj)
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
assert f.result() == obj
|
||||
assert recvd == obj
|
||||
|
||||
|
||||
async def test_custom_serialize(create_bound_pair):
|
||||
def serialize(msg):
|
||||
frames = []
|
||||
frames.extend(msg.get("identities", []))
|
||||
content = json.dumps(msg["content"]).encode("utf8")
|
||||
frames.append(content)
|
||||
return frames
|
||||
|
||||
def deserialize(frames):
|
||||
identities = frames[:-1]
|
||||
content = json.loads(frames[-1].decode("utf8"))
|
||||
return {
|
||||
"identities": identities,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
a, b = create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
|
||||
msg = {
|
||||
"content": {
|
||||
"a": 5,
|
||||
"b": "bee",
|
||||
}
|
||||
}
|
||||
await a.send_serialized(msg, serialize)
|
||||
recvd = await b.recv_serialized(deserialize)
|
||||
assert recvd["content"] == msg["content"]
|
||||
assert recvd["identities"]
|
||||
# bounce back, tests identities
|
||||
await b.send_serialized(recvd, serialize)
|
||||
r2 = await a.recv_serialized(deserialize)
|
||||
assert r2["content"] == msg["content"]
|
||||
assert not r2["identities"]
|
||||
|
||||
|
||||
async def test_custom_serialize_error(dealer_router):
|
||||
a, b = dealer_router
|
||||
|
||||
msg = {
|
||||
"content": {
|
||||
"a": 5,
|
||||
"b": "bee",
|
||||
}
|
||||
}
|
||||
with pytest.raises(TypeError):
|
||||
await a.send_serialized(json, json.dumps)
|
||||
|
||||
await a.send(b"not json")
|
||||
with pytest.raises(TypeError):
|
||||
await b.recv_serialized(json.loads)
|
||||
|
||||
|
||||
async def test_recv_dontwait(push_pull):
|
||||
push, pull = push_pull
|
||||
f = pull.recv(zmq.DONTWAIT)
|
||||
with pytest.raises(zmq.Again):
|
||||
await f
|
||||
await push.send(b"ping")
|
||||
await pull.poll() # ensure message will be waiting
|
||||
f = pull.recv(zmq.DONTWAIT)
|
||||
assert f.done()
|
||||
msg = await f
|
||||
assert msg == b"ping"
|
||||
|
||||
|
||||
async def test_recv_cancel(push_pull):
|
||||
a, b = push_pull
|
||||
f1 = b.recv()
|
||||
f2 = b.recv_multipart()
|
||||
assert f1.cancel()
|
||||
assert f1.done()
|
||||
assert not f2.done()
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
recvd = await f2
|
||||
assert f1.cancelled()
|
||||
assert f2.done()
|
||||
assert recvd == [b"hi", b"there"]
|
||||
|
||||
|
||||
async def test_poll(push_pull):
|
||||
a, b = push_pull
|
||||
f = b.poll(timeout=0)
|
||||
await asyncio.sleep(0)
|
||||
assert f.result() == 0
|
||||
|
||||
f = b.poll(timeout=1)
|
||||
assert not f.done()
|
||||
evt = await f
|
||||
|
||||
assert evt == 0
|
||||
|
||||
f = b.poll(timeout=1000)
|
||||
assert not f.done()
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
evt = await f
|
||||
assert evt == zmq.POLLIN
|
||||
recvd = await b.recv_multipart()
|
||||
assert recvd == [b"hi", b"there"]
|
||||
|
||||
|
||||
async def test_poll_base_socket(sockets):
|
||||
ctx = zmq.Context()
|
||||
url = "inproc://test"
|
||||
a = ctx.socket(zmq.PUSH)
|
||||
b = ctx.socket(zmq.PULL)
|
||||
sockets.extend([a, b])
|
||||
a.bind(url)
|
||||
b.connect(url)
|
||||
|
||||
poller = zaio.Poller()
|
||||
poller.register(b, zmq.POLLIN)
|
||||
|
||||
f = poller.poll(timeout=1000)
|
||||
assert not f.done()
|
||||
a.send_multipart([b"hi", b"there"])
|
||||
evt = await f
|
||||
assert evt == [(b, zmq.POLLIN)]
|
||||
recvd = b.recv_multipart()
|
||||
assert recvd == [b"hi", b"there"]
|
||||
|
||||
|
||||
async def test_poll_on_closed_socket(push_pull):
|
||||
a, b = push_pull
|
||||
|
||||
f = b.poll(timeout=1)
|
||||
b.close()
|
||||
|
||||
# The test might stall if we try to await f directly so instead just make a few
|
||||
# passes through the event loop to schedule and execute all callbacks
|
||||
for _ in range(5):
|
||||
await asyncio.sleep(0)
|
||||
if f.cancelled():
|
||||
break
|
||||
assert f.cancelled()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"),
|
||||
reason="Windows does not support polling on files",
|
||||
)
|
||||
async def test_poll_raw():
|
||||
p = zaio.Poller()
|
||||
# make a pipe
|
||||
r, w = os.pipe()
|
||||
r = os.fdopen(r, "rb")
|
||||
w = os.fdopen(w, "wb")
|
||||
|
||||
# POLLOUT
|
||||
p.register(r, zmq.POLLIN)
|
||||
p.register(w, zmq.POLLOUT)
|
||||
evts = await p.poll(timeout=1)
|
||||
evts = dict(evts)
|
||||
assert r.fileno() not in evts
|
||||
assert w.fileno() in evts
|
||||
assert evts[w.fileno()] == zmq.POLLOUT
|
||||
|
||||
# POLLIN
|
||||
p.unregister(w)
|
||||
w.write(b"x")
|
||||
w.flush()
|
||||
evts = await p.poll(timeout=1000)
|
||||
evts = dict(evts)
|
||||
assert r.fileno() in evts
|
||||
assert evts[r.fileno()] == zmq.POLLIN
|
||||
assert r.read(1) == b"x"
|
||||
r.close()
|
||||
w.close()
|
||||
|
||||
|
||||
def test_multiple_loops(push_pull):
|
||||
a, b = push_pull
|
||||
|
||||
async def test():
|
||||
await a.send(b'buf')
|
||||
msg = await b.recv()
|
||||
assert msg == b'buf'
|
||||
|
||||
for i in range(3):
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(asyncio.wait_for(test(), timeout=10))
|
||||
loop.close()
|
||||
|
||||
|
||||
def test_shadow():
|
||||
with zmq.Context() as ctx:
|
||||
s = ctx.socket(zmq.PULL)
|
||||
async_s = zaio.Socket(s)
|
||||
assert isinstance(async_s, zaio.Socket)
|
||||
assert async_s.underlying == s.underlying
|
||||
assert async_s.type == s.type
|
||||
|
||||
|
||||
async def test_poll_leak():
|
||||
ctx = zmq.asyncio.Context()
|
||||
with ctx, ctx.socket(zmq.PULL) as s:
|
||||
assert len(s._recv_futures) == 0
|
||||
for i in range(10):
|
||||
f = asyncio.ensure_future(s.poll(timeout=1000, flags=zmq.PollEvent.POLLIN))
|
||||
f.cancel()
|
||||
await asyncio.sleep(0)
|
||||
# one more sleep allows further chained cleanup
|
||||
await asyncio.sleep(0.1)
|
||||
assert len(s._recv_futures) == 0
|
||||
|
||||
|
||||
class ProcessForTeardownTest(Process):
|
||||
def run(self):
|
||||
"""Leave context, socket and event loop upon implicit disposal"""
|
||||
|
||||
actx = zaio.Context.instance()
|
||||
socket = actx.socket(zmq.PAIR)
|
||||
socket.bind_to_random_port("tcp://127.0.0.1")
|
||||
|
||||
async def never_ending_task(socket):
|
||||
await socket.recv() # never ever receive anything
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
coro = asyncio.wait_for(never_ending_task(socket), timeout=1)
|
||||
try:
|
||||
loop.run_until_complete(coro)
|
||||
except asyncio.TimeoutError:
|
||||
pass # expected timeout
|
||||
else:
|
||||
assert False, "never_ending_task was completed unexpectedly"
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
def test_process_teardown(request):
|
||||
proc = ProcessForTeardownTest()
|
||||
proc.start()
|
||||
request.addfinalizer(proc.terminate)
|
||||
proc.join(10) # starting new Python process may cost a lot
|
||||
assert proc.exitcode is not None, "process teardown hangs"
|
||||
assert proc.exitcode == 0, f"Python process died with code {proc.exitcode}"
|
||||
416
asm/venv/lib/python3.11/site-packages/zmq/tests/test_auth.py
Normal file
416
asm/venv/lib/python3.11/site-packages/zmq/tests/test_auth.py
Normal file
@@ -0,0 +1,416 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
import zmq.auth
|
||||
from zmq.tests import SkipTest, skip_pypy
|
||||
|
||||
try:
|
||||
import tornado
|
||||
except ImportError:
|
||||
tornado = None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Context(event_loop):
|
||||
return zmq.asyncio.Context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_certs(tmpdir):
|
||||
"""Create CURVE certificates for a test"""
|
||||
|
||||
# Create temporary CURVE key pairs for this test run. We create all keys in a
|
||||
# temp directory and then move them into the appropriate private or public
|
||||
# directory.
|
||||
base_dir = str(tmpdir.join("certs").mkdir())
|
||||
keys_dir = os.path.join(base_dir, "certificates")
|
||||
public_keys_dir = os.path.join(base_dir, 'public_keys')
|
||||
secret_keys_dir = os.path.join(base_dir, 'private_keys')
|
||||
|
||||
os.mkdir(keys_dir)
|
||||
os.mkdir(public_keys_dir)
|
||||
os.mkdir(secret_keys_dir)
|
||||
|
||||
server_public_file, server_secret_file = zmq.auth.create_certificates(
|
||||
keys_dir, "server"
|
||||
)
|
||||
client_public_file, client_secret_file = zmq.auth.create_certificates(
|
||||
keys_dir, "client"
|
||||
)
|
||||
for key_file in os.listdir(keys_dir):
|
||||
if key_file.endswith(".key"):
|
||||
shutil.move(
|
||||
os.path.join(keys_dir, key_file), os.path.join(public_keys_dir, '.')
|
||||
)
|
||||
|
||||
for key_file in os.listdir(keys_dir):
|
||||
if key_file.endswith(".key_secret"):
|
||||
shutil.move(
|
||||
os.path.join(keys_dir, key_file), os.path.join(secret_keys_dir, '.')
|
||||
)
|
||||
|
||||
return (public_keys_dir, secret_keys_dir)
|
||||
|
||||
|
||||
def load_certs(secret_keys_dir):
|
||||
"""Return server and client certificate keys"""
|
||||
server_secret_file = os.path.join(secret_keys_dir, "server.key_secret")
|
||||
client_secret_file = os.path.join(secret_keys_dir, "client.key_secret")
|
||||
|
||||
server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
|
||||
client_public, client_secret = zmq.auth.load_certificate(client_secret_file)
|
||||
|
||||
return server_public, server_secret, client_public, client_secret
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def public_keys_dir(create_certs):
|
||||
public_keys_dir, secret_keys_dir = create_certs
|
||||
return public_keys_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def secret_keys_dir(create_certs):
|
||||
public_keys_dir, secret_keys_dir = create_certs
|
||||
return secret_keys_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def certs(secret_keys_dir):
|
||||
return load_certs(secret_keys_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def _async_setup(request, event_loop):
|
||||
"""pytest doesn't support async setup/teardown"""
|
||||
instance = request.instance
|
||||
await instance.async_setup()
|
||||
yield
|
||||
# make sure our teardown runs before the loop closes
|
||||
instance.async_teardown()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_async_setup")
|
||||
class AuthTest:
|
||||
auth = None
|
||||
|
||||
async def async_setup(self):
|
||||
self.context = zmq.asyncio.Context()
|
||||
if zmq.zmq_version_info() < (4, 0):
|
||||
raise SkipTest("security is new in libzmq 4.0")
|
||||
try:
|
||||
zmq.curve_keypair()
|
||||
except zmq.ZMQError:
|
||||
raise SkipTest("security requires libzmq to have curve support")
|
||||
# enable debug logging while we run tests
|
||||
logging.getLogger('zmq.auth').setLevel(logging.DEBUG)
|
||||
self.auth = self.make_auth()
|
||||
await self.start_auth()
|
||||
|
||||
def async_teardown(self):
|
||||
if self.auth:
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
self.context.term()
|
||||
|
||||
def make_auth(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def start_auth(self):
|
||||
self.auth.start()
|
||||
|
||||
async def can_connect(self, server, client, timeout=1000):
|
||||
"""Check if client can connect to server using tcp transport"""
|
||||
result = False
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
msg = [b"Hello World"]
|
||||
# run poll on server twice
|
||||
# to flush spurious events
|
||||
await server.poll(100, zmq.POLLOUT)
|
||||
|
||||
if await server.poll(timeout, zmq.POLLOUT):
|
||||
try:
|
||||
await server.send_multipart(msg, zmq.NOBLOCK)
|
||||
except zmq.Again:
|
||||
warnings.warn("server set POLLOUT, but cannot send", RuntimeWarning)
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
if await client.poll(timeout):
|
||||
try:
|
||||
rcvd_msg = await client.recv_multipart(zmq.NOBLOCK)
|
||||
except zmq.Again:
|
||||
warnings.warn("client set POLLIN, but cannot recv", RuntimeWarning)
|
||||
else:
|
||||
assert rcvd_msg == msg
|
||||
result = True
|
||||
return result
|
||||
|
||||
@contextmanager
|
||||
def push_pull(self):
|
||||
with self.context.socket(zmq.PUSH) as server, self.context.socket(
|
||||
zmq.PULL
|
||||
) as client:
|
||||
server.linger = 0
|
||||
server.sndtimeo = 2000
|
||||
client.linger = 0
|
||||
client.rcvtimeo = 2000
|
||||
yield server, client
|
||||
|
||||
@contextmanager
|
||||
def curve_push_pull(self, certs, client_key="ok"):
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
with self.push_pull() as (server, client):
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
if client_key is not None:
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
if client_key == "ok":
|
||||
client.curve_serverkey = server_public
|
||||
else:
|
||||
private, public = zmq.curve_keypair()
|
||||
client.curve_serverkey = public
|
||||
yield (server, client)
|
||||
|
||||
async def test_null(self):
|
||||
"""threaded auth - NULL"""
|
||||
# A default NULL connection should always succeed, and not
|
||||
# go through our authentication infrastructure at all.
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
|
||||
# use a new context, so ZAP isn't inherited
|
||||
self.context.term()
|
||||
self.context = zmq.asyncio.Context()
|
||||
|
||||
with self.push_pull() as (server, client):
|
||||
assert await self.can_connect(server, client)
|
||||
|
||||
# By setting a domain we switch on authentication for NULL sockets,
|
||||
# though no policies are configured yet. The client connection
|
||||
# should still be allowed.
|
||||
with self.push_pull() as (server, client):
|
||||
server.zap_domain = b'global'
|
||||
assert await self.can_connect(server, client)
|
||||
|
||||
async def test_deny(self):
|
||||
# deny 127.0.0.1, connection should fail
|
||||
self.auth.deny('127.0.0.1')
|
||||
with pytest.raises(ValueError):
|
||||
self.auth.allow("127.0.0.2")
|
||||
with self.push_pull() as (server, client):
|
||||
# By setting a domain we switch on authentication for NULL sockets,
|
||||
# though no policies are configured yet.
|
||||
server.zap_domain = b'global'
|
||||
assert not await self.can_connect(server, client, timeout=100)
|
||||
|
||||
async def test_allow(self):
|
||||
# allow 127.0.0.1, connection should pass
|
||||
self.auth.allow('127.0.0.1')
|
||||
with pytest.raises(ValueError):
|
||||
self.auth.deny("127.0.0.2")
|
||||
with self.push_pull() as (server, client):
|
||||
# By setting a domain we switch on authentication for NULL sockets,
|
||||
# though no policies are configured yet.
|
||||
server.zap_domain = b'global'
|
||||
assert await self.can_connect(server, client)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"enabled, password, success",
|
||||
[
|
||||
(True, "correct", True),
|
||||
(False, "correct", False),
|
||||
(True, "incorrect", False),
|
||||
],
|
||||
)
|
||||
async def test_plain(self, enabled, password, success):
|
||||
"""threaded auth - PLAIN"""
|
||||
|
||||
# Try PLAIN authentication - without configuring server, connection should fail
|
||||
with self.push_pull() as (server, client):
|
||||
server.plain_server = True
|
||||
if password:
|
||||
client.plain_username = b'admin'
|
||||
client.plain_password = password.encode("ascii")
|
||||
if enabled:
|
||||
self.auth.configure_plain(domain='*', passwords={'admin': 'correct'})
|
||||
if success:
|
||||
assert await self.can_connect(server, client)
|
||||
else:
|
||||
assert not await self.can_connect(server, client, timeout=100)
|
||||
|
||||
# Remove authenticator and check that a normal connection works
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
with self.push_pull() as (server, client):
|
||||
assert await self.can_connect(server, client)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"client_key, location, success",
|
||||
[
|
||||
('ok', zmq.auth.CURVE_ALLOW_ANY, True),
|
||||
('ok', "public_keys", True),
|
||||
('bad', "public_keys", False),
|
||||
(None, "public_keys", False),
|
||||
],
|
||||
)
|
||||
async def test_curve(self, certs, public_keys_dir, client_key, location, success):
|
||||
"""threaded auth - CURVE"""
|
||||
self.auth.allow('127.0.0.1')
|
||||
|
||||
# Try CURVE authentication - without configuring server, connection should fail
|
||||
with self.curve_push_pull(certs, client_key=client_key) as (server, client):
|
||||
if location:
|
||||
if location == 'public_keys':
|
||||
location = public_keys_dir
|
||||
self.auth.configure_curve(domain='*', location=location)
|
||||
if success:
|
||||
assert await self.can_connect(server, client, timeout=100)
|
||||
else:
|
||||
assert not await self.can_connect(server, client, timeout=100)
|
||||
|
||||
# Remove authenticator and check that a normal connection works
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
|
||||
# Try connecting using NULL and no authentication enabled, connection should pass
|
||||
with self.push_pull() as (server, client):
|
||||
assert await self.can_connect(server, client)
|
||||
|
||||
@pytest.mark.parametrize("key", ["ok", "wrong"])
|
||||
@pytest.mark.parametrize("async_", [True, False])
|
||||
async def test_curve_callback(self, certs, key, async_):
|
||||
"""threaded auth - CURVE with callback authentication"""
|
||||
self.auth.allow('127.0.0.1')
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
|
||||
class CredentialsProvider:
|
||||
def __init__(self):
|
||||
if key == "ok":
|
||||
self.client = client_public
|
||||
else:
|
||||
self.client = server_public
|
||||
|
||||
def callback(self, domain, key):
|
||||
if key == self.client:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
async def async_callback(self, domain, key):
|
||||
await asyncio.sleep(0.1)
|
||||
if key == self.client:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
if async_:
|
||||
CredentialsProvider.callback = CredentialsProvider.async_callback
|
||||
provider = CredentialsProvider()
|
||||
self.auth.configure_curve_callback(credentials_provider=provider)
|
||||
with self.curve_push_pull(certs) as (server, client):
|
||||
if key == "ok":
|
||||
assert await self.can_connect(server, client)
|
||||
else:
|
||||
assert not await self.can_connect(server, client, timeout=200)
|
||||
|
||||
@skip_pypy
|
||||
async def test_curve_user_id(self, certs, public_keys_dir):
|
||||
"""threaded auth - CURVE"""
|
||||
self.auth.allow('127.0.0.1')
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
self.auth.configure_curve(domain='*', location=public_keys_dir)
|
||||
# reverse server-client relationship, so server is PULL
|
||||
with self.push_pull() as (client, server):
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
|
||||
assert await self.can_connect(client, server)
|
||||
|
||||
# test default user-id map
|
||||
await client.send(b'test')
|
||||
msg = await server.recv(copy=False)
|
||||
assert msg.bytes == b'test'
|
||||
try:
|
||||
user_id = msg.get('User-Id')
|
||||
except zmq.ZMQVersionError:
|
||||
pass
|
||||
else:
|
||||
assert user_id == client_public.decode("utf8")
|
||||
|
||||
# test custom user-id map
|
||||
self.auth.curve_user_id = lambda client_key: 'custom'
|
||||
|
||||
with self.context.socket(zmq.PUSH) as client2:
|
||||
client2.curve_publickey = client_public
|
||||
client2.curve_secretkey = client_secret
|
||||
client2.curve_serverkey = server_public
|
||||
assert await self.can_connect(client2, server)
|
||||
|
||||
await client2.send(b'test2')
|
||||
msg = await server.recv(copy=False)
|
||||
assert msg.bytes == b'test2'
|
||||
try:
|
||||
user_id = msg.get('User-Id')
|
||||
except zmq.ZMQVersionError:
|
||||
pass
|
||||
else:
|
||||
assert user_id == 'custom'
|
||||
|
||||
|
||||
class TestThreadAuthentication(AuthTest):
|
||||
"""Test authentication running in a thread"""
|
||||
|
||||
def make_auth(self):
|
||||
from zmq.auth.thread import ThreadAuthenticator
|
||||
|
||||
return ThreadAuthenticator(self.context)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform == 'win32' and sys.version_info < (3, 7),
|
||||
reason="flaky event loop cleanup on windows+py36",
|
||||
)
|
||||
class TestAsyncioAuthentication(AuthTest):
|
||||
"""Test authentication running in a thread"""
|
||||
|
||||
def make_auth(self):
|
||||
from zmq.auth.asyncio import AsyncioAuthenticator
|
||||
|
||||
return AsyncioAuthenticator(self.context)
|
||||
|
||||
|
||||
async def test_ioloop_authenticator(context, event_loop, io_loop):
|
||||
from tornado.ioloop import IOLoop
|
||||
|
||||
with warnings.catch_warnings():
|
||||
from zmq.auth.ioloop import IOLoopAuthenticator
|
||||
|
||||
auth = IOLoopAuthenticator(context)
|
||||
assert auth.context is context
|
||||
|
||||
loop = IOLoop(make_current=False)
|
||||
with pytest.warns(DeprecationWarning):
|
||||
auth = IOLoopAuthenticator(io_loop=loop)
|
||||
@@ -0,0 +1,303 @@
|
||||
import time
|
||||
from unittest import TestCase
|
||||
|
||||
from zmq.tests import SkipTest
|
||||
|
||||
try:
|
||||
from zmq.backend.cffi import ( # type: ignore
|
||||
IDENTITY,
|
||||
POLLIN,
|
||||
POLLOUT,
|
||||
PULL,
|
||||
PUSH,
|
||||
REP,
|
||||
REQ,
|
||||
zmq_version_info,
|
||||
)
|
||||
from zmq.backend.cffi._cffi import C, ffi
|
||||
|
||||
have_ffi_backend = True
|
||||
except ImportError:
|
||||
have_ffi_backend = False
|
||||
|
||||
|
||||
class TestCFFIBackend(TestCase):
|
||||
def setUp(self):
|
||||
if not have_ffi_backend:
|
||||
raise SkipTest('CFFI not available')
|
||||
|
||||
def test_zmq_version_info(self):
|
||||
version = zmq_version_info()
|
||||
|
||||
assert version[0] in range(2, 11)
|
||||
|
||||
def test_zmq_ctx_new_destroy(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
assert ctx != ffi.NULL
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_socket_open_close(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
socket = C.zmq_socket(ctx, PUSH)
|
||||
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket
|
||||
assert 0 == C.zmq_close(socket)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_setsockopt(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
socket = C.zmq_socket(ctx, PUSH)
|
||||
|
||||
identity = ffi.new('char[3]', b'zmq')
|
||||
ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3)
|
||||
|
||||
assert ret == 0
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket
|
||||
assert 0 == C.zmq_close(socket)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_getsockopt(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
socket = C.zmq_socket(ctx, PUSH)
|
||||
|
||||
identity = ffi.new('char[]', b'zmq')
|
||||
ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3)
|
||||
assert ret == 0
|
||||
|
||||
option_len = ffi.new('size_t*', 3)
|
||||
option = ffi.new('char[3]')
|
||||
ret = C.zmq_getsockopt(socket, IDENTITY, ffi.cast('void*', option), option_len)
|
||||
|
||||
assert ret == 0
|
||||
assert ffi.string(ffi.cast('char*', option))[0:1] == b"z"
|
||||
assert ffi.string(ffi.cast('char*', option))[1:2] == b"m"
|
||||
assert ffi.string(ffi.cast('char*', option))[2:3] == b"q"
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket
|
||||
assert 0 == C.zmq_close(socket)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_bind(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
socket = C.zmq_socket(ctx, 8)
|
||||
|
||||
assert 0 == C.zmq_bind(socket, b'tcp://*:4444')
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket
|
||||
assert 0 == C.zmq_close(socket)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_bind_connect(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
socket1 = C.zmq_socket(ctx, PUSH)
|
||||
socket2 = C.zmq_socket(ctx, PULL)
|
||||
|
||||
assert 0 == C.zmq_bind(socket1, b'tcp://*:4444')
|
||||
assert 0 == C.zmq_connect(socket2, b'tcp://127.0.0.1:4444')
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket1
|
||||
assert ffi.NULL != socket2
|
||||
assert 0 == C.zmq_close(socket1)
|
||||
assert 0 == C.zmq_close(socket2)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_msg_init_close(self):
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
|
||||
assert ffi.NULL != zmq_msg
|
||||
assert 0 == C.zmq_msg_init(zmq_msg)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
|
||||
def test_zmq_msg_init_size(self):
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
|
||||
assert ffi.NULL != zmq_msg
|
||||
assert 0 == C.zmq_msg_init_size(zmq_msg, 10)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
|
||||
def test_zmq_msg_init_data(self):
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[5]', b'Hello')
|
||||
|
||||
assert 0 == C.zmq_msg_init_data(
|
||||
zmq_msg, ffi.cast('void*', message), 5, ffi.NULL, ffi.NULL
|
||||
)
|
||||
|
||||
assert ffi.NULL != zmq_msg
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
|
||||
def test_zmq_msg_data(self):
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[]', b'Hello')
|
||||
assert 0 == C.zmq_msg_init_data(
|
||||
zmq_msg, ffi.cast('void*', message), 5, ffi.NULL, ffi.NULL
|
||||
)
|
||||
|
||||
data = C.zmq_msg_data(zmq_msg)
|
||||
|
||||
assert ffi.NULL != zmq_msg
|
||||
assert ffi.string(ffi.cast("char*", data)) == b'Hello'
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
|
||||
def test_zmq_send(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
sender = C.zmq_socket(ctx, REQ)
|
||||
receiver = C.zmq_socket(ctx, REP)
|
||||
|
||||
assert 0 == C.zmq_bind(receiver, b'tcp://*:7777')
|
||||
assert 0 == C.zmq_connect(sender, b'tcp://127.0.0.1:7777')
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[5]', b'Hello')
|
||||
|
||||
C.zmq_msg_init_data(
|
||||
zmq_msg,
|
||||
ffi.cast('void*', message),
|
||||
ffi.cast('size_t', 5),
|
||||
ffi.NULL,
|
||||
ffi.NULL,
|
||||
)
|
||||
|
||||
assert 5 == C.zmq_msg_send(zmq_msg, sender, 0)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
assert C.zmq_close(sender) == 0
|
||||
assert C.zmq_close(receiver) == 0
|
||||
assert C.zmq_ctx_destroy(ctx) == 0
|
||||
|
||||
def test_zmq_recv(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
sender = C.zmq_socket(ctx, REQ)
|
||||
receiver = C.zmq_socket(ctx, REP)
|
||||
|
||||
assert 0 == C.zmq_bind(receiver, b'tcp://*:2222')
|
||||
assert 0 == C.zmq_connect(sender, b'tcp://127.0.0.1:2222')
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[5]', b'Hello')
|
||||
|
||||
C.zmq_msg_init_data(
|
||||
zmq_msg,
|
||||
ffi.cast('void*', message),
|
||||
ffi.cast('size_t', 5),
|
||||
ffi.NULL,
|
||||
ffi.NULL,
|
||||
)
|
||||
|
||||
zmq_msg2 = ffi.new('zmq_msg_t*')
|
||||
C.zmq_msg_init(zmq_msg2)
|
||||
|
||||
assert 5 == C.zmq_msg_send(zmq_msg, sender, 0)
|
||||
assert 5 == C.zmq_msg_recv(zmq_msg2, receiver, 0)
|
||||
assert 5 == C.zmq_msg_size(zmq_msg2)
|
||||
assert (
|
||||
b"Hello"
|
||||
== ffi.buffer(C.zmq_msg_data(zmq_msg2), C.zmq_msg_size(zmq_msg2))[:]
|
||||
)
|
||||
assert C.zmq_close(sender) == 0
|
||||
assert C.zmq_close(receiver) == 0
|
||||
assert C.zmq_ctx_destroy(ctx) == 0
|
||||
|
||||
def test_zmq_poll(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
sender = C.zmq_socket(ctx, REQ)
|
||||
receiver = C.zmq_socket(ctx, REP)
|
||||
|
||||
r1 = C.zmq_bind(receiver, b'tcp://*:3333')
|
||||
r2 = C.zmq_connect(sender, b'tcp://127.0.0.1:3333')
|
||||
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[5]', b'Hello')
|
||||
|
||||
C.zmq_msg_init_data(
|
||||
zmq_msg,
|
||||
ffi.cast('void*', message),
|
||||
ffi.cast('size_t', 5),
|
||||
ffi.NULL,
|
||||
ffi.NULL,
|
||||
)
|
||||
|
||||
receiver_pollitem = ffi.new('zmq_pollitem_t*')
|
||||
receiver_pollitem.socket = receiver
|
||||
receiver_pollitem.fd = 0
|
||||
receiver_pollitem.events = POLLIN | POLLOUT
|
||||
receiver_pollitem.revents = 0
|
||||
|
||||
ret = C.zmq_poll(ffi.NULL, 0, 0)
|
||||
assert ret == 0
|
||||
|
||||
ret = C.zmq_poll(receiver_pollitem, 1, 0)
|
||||
assert ret == 0
|
||||
|
||||
ret = C.zmq_msg_send(zmq_msg, sender, 0)
|
||||
print(ffi.string(C.zmq_strerror(C.zmq_errno())))
|
||||
assert ret == 5
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
ret = C.zmq_poll(receiver_pollitem, 1, 0)
|
||||
assert ret == 1
|
||||
|
||||
assert int(receiver_pollitem.revents) & POLLIN
|
||||
assert not int(receiver_pollitem.revents) & POLLOUT
|
||||
|
||||
zmq_msg2 = ffi.new('zmq_msg_t*')
|
||||
C.zmq_msg_init(zmq_msg2)
|
||||
|
||||
ret_recv = C.zmq_msg_recv(zmq_msg2, receiver, 0)
|
||||
assert ret_recv == 5
|
||||
|
||||
assert 5 == C.zmq_msg_size(zmq_msg2)
|
||||
assert (
|
||||
b"Hello"
|
||||
== ffi.buffer(C.zmq_msg_data(zmq_msg2), C.zmq_msg_size(zmq_msg2))[:]
|
||||
)
|
||||
|
||||
sender_pollitem = ffi.new('zmq_pollitem_t*')
|
||||
sender_pollitem.socket = sender
|
||||
sender_pollitem.fd = 0
|
||||
sender_pollitem.events = POLLIN | POLLOUT
|
||||
sender_pollitem.revents = 0
|
||||
|
||||
ret = C.zmq_poll(sender_pollitem, 1, 0)
|
||||
assert ret == 0
|
||||
|
||||
zmq_msg_again = ffi.new('zmq_msg_t*')
|
||||
message_again = ffi.new('char[11]', b'Hello Again')
|
||||
|
||||
C.zmq_msg_init_data(
|
||||
zmq_msg_again,
|
||||
ffi.cast('void*', message_again),
|
||||
ffi.cast('size_t', 11),
|
||||
ffi.NULL,
|
||||
ffi.NULL,
|
||||
)
|
||||
|
||||
assert 11 == C.zmq_msg_send(zmq_msg_again, receiver, 0)
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
assert 0 <= C.zmq_poll(sender_pollitem, 1, 0)
|
||||
assert int(sender_pollitem.revents) & POLLIN
|
||||
assert 11 == C.zmq_msg_recv(zmq_msg2, sender, 0)
|
||||
assert 11 == C.zmq_msg_size(zmq_msg2)
|
||||
assert (
|
||||
b"Hello Again"
|
||||
== ffi.buffer(C.zmq_msg_data(zmq_msg2), int(C.zmq_msg_size(zmq_msg2)))[:]
|
||||
)
|
||||
assert 0 == C.zmq_close(sender)
|
||||
assert 0 == C.zmq_close(receiver)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg2)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg_again)
|
||||
@@ -0,0 +1,31 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
import zmq.constants
|
||||
|
||||
|
||||
def test_constants():
|
||||
assert zmq.POLLIN is zmq.PollEvent.POLLIN
|
||||
assert zmq.PUSH is zmq.SocketType.PUSH
|
||||
assert zmq.constants.SUBSCRIBE is zmq.SocketOption.SUBSCRIBE
|
||||
assert (
|
||||
zmq.RECONNECT_STOP_AFTER_DISCONNECT
|
||||
is zmq.constants.ReconnectStop.AFTER_DISCONNECT
|
||||
)
|
||||
|
||||
|
||||
def test_socket_options():
|
||||
assert zmq.IDENTITY is zmq.SocketOption.ROUTING_ID
|
||||
assert zmq.IDENTITY._opt_type is zmq.constants._OptType.bytes
|
||||
assert zmq.AFFINITY._opt_type is zmq.constants._OptType.int64
|
||||
assert zmq.CURVE_SERVER._opt_type is zmq.constants._OptType.int
|
||||
assert zmq.FD._opt_type is zmq.constants._OptType.fd
|
||||
|
||||
|
||||
@pytest.mark.parametrize("event_name", list(zmq.Event.__members__))
|
||||
def test_event_reprs(event_name):
|
||||
event = getattr(zmq.Event, event_name)
|
||||
assert event_name in repr(event)
|
||||
425
asm/venv/lib/python3.11/site-packages/zmq/tests/test_context.py
Normal file
425
asm/venv/lib/python3.11/site-packages/zmq/tests/test_context.py
Normal file
@@ -0,0 +1,425 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from queue import Queue
|
||||
from threading import Event, Thread
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
from zmq.tests import PYPY, BaseZMQTestCase, GreenTest, SkipTest
|
||||
|
||||
|
||||
class KwargTestSocket(zmq.Socket):
|
||||
test_kwarg_value = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.test_kwarg_value = kwargs.pop('test_kwarg', None)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class KwargTestContext(zmq.Context):
|
||||
_socket_class = KwargTestSocket
|
||||
|
||||
|
||||
class TestContext(BaseZMQTestCase):
|
||||
def test_init(self):
|
||||
c1 = self.Context()
|
||||
assert isinstance(c1, self.Context)
|
||||
c1.term()
|
||||
c2 = self.Context()
|
||||
assert isinstance(c2, self.Context)
|
||||
c2.term()
|
||||
c3 = self.Context()
|
||||
assert isinstance(c3, self.Context)
|
||||
c3.term()
|
||||
|
||||
_repr_cls = "zmq.Context"
|
||||
|
||||
def test_repr(self):
|
||||
with self.Context() as ctx:
|
||||
assert f'{self._repr_cls}()' in repr(ctx)
|
||||
assert 'closed' not in repr(ctx)
|
||||
with ctx.socket(zmq.PUSH) as push:
|
||||
assert f'{self._repr_cls}(1 socket)' in repr(ctx)
|
||||
with ctx.socket(zmq.PULL) as pull:
|
||||
assert f'{self._repr_cls}(2 sockets)' in repr(ctx)
|
||||
assert f'{self._repr_cls}()' in repr(ctx)
|
||||
assert 'closed' in repr(ctx)
|
||||
|
||||
def test_dir(self):
|
||||
ctx = self.Context()
|
||||
assert 'socket' in dir(ctx)
|
||||
if zmq.zmq_version_info() > (3,):
|
||||
assert 'IO_THREADS' in dir(ctx)
|
||||
ctx.term()
|
||||
|
||||
@mark.skipif(mock is None, reason="requires unittest.mock")
|
||||
def test_mockable(self):
|
||||
m = mock.Mock(spec=self.context)
|
||||
|
||||
def test_term(self):
|
||||
c = self.Context()
|
||||
c.term()
|
||||
assert c.closed
|
||||
|
||||
def test_context_manager(self):
|
||||
with pytest.warns(ResourceWarning):
|
||||
with self.Context() as ctx:
|
||||
s = ctx.socket(zmq.PUSH)
|
||||
# context exit destroys sockets
|
||||
assert s.closed
|
||||
assert ctx.closed
|
||||
|
||||
def test_fail_init(self):
|
||||
self.assertRaisesErrno(zmq.EINVAL, self.Context, -1)
|
||||
|
||||
def test_term_hang(self):
|
||||
rep, req = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
|
||||
req.setsockopt(zmq.LINGER, 0)
|
||||
req.send(b'hello', copy=False)
|
||||
req.close()
|
||||
rep.close()
|
||||
self.context.term()
|
||||
|
||||
def test_instance(self):
|
||||
ctx = self.Context.instance()
|
||||
c2 = self.Context.instance(io_threads=2)
|
||||
assert c2 is ctx
|
||||
c2.term()
|
||||
c3 = self.Context.instance()
|
||||
c4 = self.Context.instance()
|
||||
assert not c3 is c2
|
||||
assert not c3.closed
|
||||
assert c3 is c4
|
||||
|
||||
def test_instance_subclass_first(self):
|
||||
self.context.term()
|
||||
|
||||
class SubContext(zmq.Context):
|
||||
pass
|
||||
|
||||
sctx = SubContext.instance()
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.term()
|
||||
sctx.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(sctx) is SubContext
|
||||
|
||||
def test_instance_subclass_second(self):
|
||||
self.context.term()
|
||||
|
||||
class SubContextInherit(zmq.Context):
|
||||
pass
|
||||
|
||||
class SubContextNoInherit(zmq.Context):
|
||||
_instance = None
|
||||
|
||||
ctx = zmq.Context.instance()
|
||||
sctx = SubContextInherit.instance()
|
||||
sctx2 = SubContextNoInherit.instance()
|
||||
ctx.term()
|
||||
sctx.term()
|
||||
sctx2.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(sctx) is zmq.Context
|
||||
assert type(sctx2) is SubContextNoInherit
|
||||
|
||||
def test_instance_threadsafe(self):
|
||||
self.context.term() # clear default context
|
||||
|
||||
q = Queue()
|
||||
|
||||
# slow context initialization,
|
||||
# to ensure that we are both trying to create one at the same time
|
||||
class SlowContext(self.Context):
|
||||
def __init__(self, *a, **kw):
|
||||
time.sleep(1)
|
||||
super().__init__(*a, **kw)
|
||||
|
||||
def f():
|
||||
q.put(SlowContext.instance())
|
||||
|
||||
# call ctx.instance() in several threads at once
|
||||
N = 16
|
||||
threads = [Thread(target=f) for i in range(N)]
|
||||
[t.start() for t in threads]
|
||||
# also call it in the main thread (not first)
|
||||
ctx = SlowContext.instance()
|
||||
assert isinstance(ctx, SlowContext)
|
||||
# check that all the threads got the same context
|
||||
for i in range(N):
|
||||
thread_ctx = q.get(timeout=5)
|
||||
assert thread_ctx is ctx
|
||||
# cleanup
|
||||
ctx.term()
|
||||
[t.join(timeout=5) for t in threads]
|
||||
|
||||
def test_socket_passes_kwargs(self):
|
||||
test_kwarg_value = 'testing one two three'
|
||||
with KwargTestContext() as ctx:
|
||||
with ctx.socket(zmq.DEALER, test_kwarg=test_kwarg_value) as socket:
|
||||
assert socket.test_kwarg_value is test_kwarg_value
|
||||
|
||||
def test_socket_class_arg(self):
|
||||
class CustomSocket(zmq.Socket):
|
||||
pass
|
||||
|
||||
with self.Context() as ctx:
|
||||
with ctx.socket(zmq.PUSH, socket_class=CustomSocket) as s:
|
||||
assert isinstance(s, CustomSocket)
|
||||
|
||||
def test_many_sockets(self):
|
||||
"""opening and closing many sockets shouldn't cause problems"""
|
||||
ctx = self.Context()
|
||||
for i in range(16):
|
||||
sockets = [ctx.socket(zmq.REP) for i in range(65)]
|
||||
[s.close() for s in sockets]
|
||||
# give the reaper a chance
|
||||
time.sleep(1e-2)
|
||||
ctx.term()
|
||||
|
||||
def test_sockopts(self):
|
||||
"""setting socket options with ctx attributes"""
|
||||
ctx = self.Context()
|
||||
ctx.linger = 5
|
||||
assert ctx.linger == 5
|
||||
s = ctx.socket(zmq.REQ)
|
||||
assert s.linger == 5
|
||||
assert s.getsockopt(zmq.LINGER) == 5
|
||||
s.close()
|
||||
# check that subscribe doesn't get set on sockets that don't subscribe:
|
||||
ctx.subscribe = b''
|
||||
s = ctx.socket(zmq.REQ)
|
||||
s.close()
|
||||
|
||||
ctx.term()
|
||||
|
||||
@mark.skipif(sys.platform.startswith('win'), reason='Segfaults on Windows')
|
||||
def test_destroy(self):
|
||||
"""Context.destroy should close sockets"""
|
||||
ctx = self.Context()
|
||||
sockets = [ctx.socket(zmq.REP) for i in range(65)]
|
||||
|
||||
# close half of the sockets
|
||||
[s.close() for s in sockets[::2]]
|
||||
|
||||
ctx.destroy()
|
||||
# reaper is not instantaneous
|
||||
time.sleep(1e-2)
|
||||
for s in sockets:
|
||||
assert s.closed
|
||||
|
||||
def test_destroy_linger(self):
|
||||
"""Context.destroy should set linger on closing sockets"""
|
||||
req, rep = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
req.send(b'hi')
|
||||
time.sleep(1e-2)
|
||||
self.context.destroy(linger=0)
|
||||
# reaper is not instantaneous
|
||||
time.sleep(1e-2)
|
||||
for s in (req, rep):
|
||||
assert s.closed
|
||||
|
||||
def test_term_noclose(self):
|
||||
"""Context.term won't close sockets"""
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.REQ)
|
||||
assert not s.closed
|
||||
t = Thread(target=ctx.term)
|
||||
t.start()
|
||||
t.join(timeout=0.1)
|
||||
assert t.is_alive(), "Context should be waiting"
|
||||
s.close()
|
||||
t.join(timeout=0.1)
|
||||
assert not t.is_alive(), "Context should have closed"
|
||||
|
||||
def test_gc(self):
|
||||
"""test close&term by garbage collection alone"""
|
||||
if PYPY:
|
||||
raise SkipTest("GC doesn't work ")
|
||||
|
||||
# test credit @dln (GH #137):
|
||||
def gcf():
|
||||
def inner():
|
||||
ctx = self.Context()
|
||||
ctx.socket(zmq.PUSH)
|
||||
|
||||
# can't seem to catch these with pytest.warns(ResourceWarning)
|
||||
inner()
|
||||
gc.collect()
|
||||
|
||||
t = Thread(target=gcf)
|
||||
t.start()
|
||||
t.join(timeout=1)
|
||||
assert not t.is_alive(), "Garbage collection should have cleaned up context"
|
||||
|
||||
def test_cyclic_destroy(self):
|
||||
"""ctx.destroy should succeed when cyclic ref prevents gc"""
|
||||
|
||||
# test credit @dln (GH #137):
|
||||
class CyclicReference:
|
||||
def __init__(self, parent=None):
|
||||
self.parent = parent
|
||||
|
||||
def crash(self, sock):
|
||||
self.sock = sock
|
||||
self.child = CyclicReference(self)
|
||||
|
||||
def crash_zmq():
|
||||
ctx = self.Context()
|
||||
sock = ctx.socket(zmq.PULL)
|
||||
c = CyclicReference()
|
||||
c.crash(sock)
|
||||
ctx.destroy()
|
||||
|
||||
crash_zmq()
|
||||
|
||||
def test_term_thread(self):
|
||||
"""ctx.term should not crash active threads (#139)"""
|
||||
ctx = self.Context()
|
||||
evt = Event()
|
||||
evt.clear()
|
||||
|
||||
def block():
|
||||
s = ctx.socket(zmq.REP)
|
||||
s.bind_to_random_port('tcp://127.0.0.1')
|
||||
evt.set()
|
||||
try:
|
||||
s.recv()
|
||||
except zmq.ZMQError as e:
|
||||
assert e.errno == zmq.ETERM
|
||||
return
|
||||
finally:
|
||||
s.close()
|
||||
self.fail("recv should have been interrupted with ETERM")
|
||||
|
||||
t = Thread(target=block)
|
||||
t.start()
|
||||
|
||||
evt.wait(1)
|
||||
assert evt.is_set(), "sync event never fired"
|
||||
time.sleep(0.01)
|
||||
ctx.term()
|
||||
t.join(timeout=1)
|
||||
assert not t.is_alive(), "term should have interrupted s.recv()"
|
||||
|
||||
def test_destroy_no_sockets(self):
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.PUB)
|
||||
s.bind_to_random_port('tcp://127.0.0.1')
|
||||
s.close()
|
||||
ctx.destroy()
|
||||
assert s.closed
|
||||
assert ctx.closed
|
||||
|
||||
def test_ctx_opts(self):
|
||||
if zmq.zmq_version_info() < (3,):
|
||||
raise SkipTest("context options require libzmq 3")
|
||||
ctx = self.Context()
|
||||
ctx.set(zmq.MAX_SOCKETS, 2)
|
||||
assert ctx.get(zmq.MAX_SOCKETS) == 2
|
||||
ctx.max_sockets = 100
|
||||
assert ctx.max_sockets == 100
|
||||
assert ctx.get(zmq.MAX_SOCKETS) == 100
|
||||
|
||||
def test_copy(self):
|
||||
c1 = self.Context()
|
||||
c2 = copy.copy(c1)
|
||||
c2b = copy.deepcopy(c1)
|
||||
c3 = copy.deepcopy(c2)
|
||||
assert c2._shadow
|
||||
assert c3._shadow
|
||||
assert c1.underlying == c2.underlying
|
||||
assert c1.underlying == c3.underlying
|
||||
assert c1.underlying == c2b.underlying
|
||||
s = c3.socket(zmq.PUB)
|
||||
s.close()
|
||||
c1.term()
|
||||
|
||||
def test_shadow(self):
|
||||
ctx = self.Context()
|
||||
ctx2 = self.Context.shadow(ctx.underlying)
|
||||
assert ctx.underlying == ctx2.underlying
|
||||
s = ctx.socket(zmq.PUB)
|
||||
s.close()
|
||||
del ctx2
|
||||
assert not ctx.closed
|
||||
s = ctx.socket(zmq.PUB)
|
||||
ctx2 = self.Context.shadow(ctx)
|
||||
with ctx2.socket(zmq.PUB) as s2:
|
||||
pass
|
||||
|
||||
assert s2.closed
|
||||
assert not s.closed
|
||||
s.close()
|
||||
|
||||
ctx3 = self.Context(ctx)
|
||||
assert ctx3.underlying == ctx.underlying
|
||||
del ctx3
|
||||
assert not ctx.closed
|
||||
|
||||
ctx.term()
|
||||
self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB)
|
||||
del ctx2
|
||||
|
||||
def test_shadow_pyczmq(self):
|
||||
try:
|
||||
from pyczmq import zctx, zsocket, zstr
|
||||
except Exception:
|
||||
raise SkipTest("Requires pyczmq")
|
||||
|
||||
ctx = zctx.new()
|
||||
a = zsocket.new(ctx, zmq.PUSH)
|
||||
zsocket.bind(a, "inproc://a")
|
||||
ctx2 = self.Context.shadow_pyczmq(ctx)
|
||||
b = ctx2.socket(zmq.PULL)
|
||||
b.connect("inproc://a")
|
||||
zstr.send(a, b'hi')
|
||||
rcvd = self.recv(b)
|
||||
assert rcvd == b'hi'
|
||||
b.close()
|
||||
|
||||
@mark.skipif(sys.platform.startswith('win'), reason='No fork on Windows')
|
||||
def test_fork_instance(self):
|
||||
ctx = self.Context.instance()
|
||||
parent_ctx_id = id(ctx)
|
||||
r_fd, w_fd = os.pipe()
|
||||
reader = os.fdopen(r_fd, 'r')
|
||||
child_pid = os.fork()
|
||||
if child_pid == 0:
|
||||
ctx = self.Context.instance()
|
||||
writer = os.fdopen(w_fd, 'w')
|
||||
child_ctx_id = id(ctx)
|
||||
ctx.term()
|
||||
writer.write(str(child_ctx_id) + "\n")
|
||||
writer.flush()
|
||||
writer.close()
|
||||
os._exit(0)
|
||||
else:
|
||||
os.close(w_fd)
|
||||
|
||||
child_id_s = reader.readline()
|
||||
reader.close()
|
||||
assert child_id_s
|
||||
assert int(child_id_s) != parent_ctx_id
|
||||
ctx.term()
|
||||
|
||||
|
||||
if False: # disable green context tests
|
||||
|
||||
class TestContextGreen(GreenTest, TestContext):
|
||||
"""gevent subclass of context tests"""
|
||||
|
||||
# skip tests that use real threads:
|
||||
test_gc = GreenTest.skip_green
|
||||
test_term_thread = GreenTest.skip_green
|
||||
test_destroy_linger = GreenTest.skip_green
|
||||
_repr_cls = "zmq.green.Context"
|
||||
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
|
||||
pyximport = pytest.importorskip("pyximport")
|
||||
|
||||
HERE = os.path.dirname(__file__)
|
||||
cython_ext = os.path.join(HERE, "cython_ext.pyx")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.path.exists(cython_ext),
|
||||
reason=f"Requires cython test file {cython_ext}",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
'zmq.backend.cython' not in sys.modules, reason="Requires cython backend"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith('win'), reason="Don't try runtime Cython on Windows"
|
||||
)
|
||||
@pytest.mark.parametrize('language_level', [3, 2])
|
||||
def test_cython(language_level, request, tmpdir):
|
||||
assert 'zmq.tests.cython_ext' not in sys.modules
|
||||
|
||||
importers = pyximport.install(
|
||||
setup_args=dict(include_dirs=zmq.get_includes()),
|
||||
language_level=language_level,
|
||||
build_dir=str(tmpdir),
|
||||
)
|
||||
|
||||
cython_ext = None
|
||||
|
||||
def unimport():
|
||||
pyximport.uninstall(*importers)
|
||||
sys.modules.pop('zmq.tests.cython_ext', None)
|
||||
|
||||
request.addfinalizer(unimport)
|
||||
|
||||
# this import tests the compilation
|
||||
from . import cython_ext
|
||||
|
||||
assert hasattr(cython_ext, 'send_recv_test')
|
||||
|
||||
# call the compiled function
|
||||
# this shouldn't do much
|
||||
msg = b'my msg'
|
||||
received = cython_ext.send_recv_test(msg)
|
||||
assert received == msg
|
||||
@@ -0,0 +1,396 @@
|
||||
import threading
|
||||
|
||||
from pytest import fixture, raises
|
||||
|
||||
import zmq
|
||||
from zmq.decorators import context, socket
|
||||
from zmq.tests import BaseZMQTestCase, term_context
|
||||
|
||||
##############################################
|
||||
# Test cases for @context
|
||||
##############################################
|
||||
|
||||
|
||||
@fixture(autouse=True)
|
||||
def term_context_instance(request):
|
||||
request.addfinalizer(lambda: term_context(zmq.Context.instance(), timeout=10))
|
||||
|
||||
|
||||
def test_ctx():
|
||||
@context()
|
||||
def test(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_orig_args():
|
||||
@context()
|
||||
def f(foo, bar, ctx, baz=None):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert foo == 42
|
||||
assert bar is True
|
||||
assert baz == 'mock'
|
||||
|
||||
f(42, True, baz='mock')
|
||||
|
||||
|
||||
def test_ctx_arg_naming():
|
||||
@context('myctx')
|
||||
def test(myctx):
|
||||
assert isinstance(myctx, zmq.Context), myctx
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_args():
|
||||
@context('ctx', 5)
|
||||
def test(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_arg_kwarg():
|
||||
@context('ctx', io_threads=5)
|
||||
def test(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_kw_naming():
|
||||
@context(name='myctx')
|
||||
def test(myctx):
|
||||
assert isinstance(myctx, zmq.Context), myctx
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_kwargs():
|
||||
@context(name='ctx', io_threads=5)
|
||||
def test(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_kwargs_default():
|
||||
@context(name='ctx', io_threads=5)
|
||||
def test(ctx=None):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_keyword_miss():
|
||||
@context(name='ctx')
|
||||
def test(other_name):
|
||||
pass # the keyword ``ctx`` not found
|
||||
|
||||
with raises(TypeError):
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_multi_assign():
|
||||
@context(name='ctx')
|
||||
def test(ctx):
|
||||
pass # explosion
|
||||
|
||||
with raises(TypeError):
|
||||
test('mock')
|
||||
|
||||
|
||||
def test_ctx_reinit():
|
||||
result = {'foo': None, 'bar': None}
|
||||
|
||||
@context()
|
||||
def f(key, ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
result[key] = ctx
|
||||
|
||||
foo_t = threading.Thread(target=f, args=('foo',))
|
||||
bar_t = threading.Thread(target=f, args=('bar',))
|
||||
|
||||
foo_t.start()
|
||||
bar_t.start()
|
||||
|
||||
foo_t.join()
|
||||
bar_t.join()
|
||||
|
||||
assert result['foo'] is not None, result
|
||||
assert result['bar'] is not None, result
|
||||
assert result['foo'] is not result['bar'], result
|
||||
|
||||
|
||||
def test_ctx_multi_thread():
|
||||
@context()
|
||||
@context()
|
||||
def f(foo, bar):
|
||||
assert isinstance(foo, zmq.Context), foo
|
||||
assert isinstance(bar, zmq.Context), bar
|
||||
|
||||
assert len(set(map(id, [foo, bar]))) == 2, set(map(id, [foo, bar]))
|
||||
|
||||
threads = [threading.Thread(target=f) for i in range(8)]
|
||||
[t.start() for t in threads]
|
||||
[t.join() for t in threads]
|
||||
|
||||
|
||||
##############################################
|
||||
# Test cases for @socket
|
||||
##############################################
|
||||
|
||||
|
||||
def test_ctx_skt():
|
||||
@context()
|
||||
@socket(zmq.PUB)
|
||||
def test(ctx, skt):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
assert skt.type == zmq.PUB
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_skt_name():
|
||||
@context()
|
||||
@socket('myskt', zmq.PUB)
|
||||
def test(ctx, myskt):
|
||||
assert isinstance(myskt, zmq.Socket), myskt
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert myskt.type == zmq.PUB
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_skt_kwarg():
|
||||
@context()
|
||||
@socket(zmq.PUB, name='myskt')
|
||||
def test(ctx, myskt):
|
||||
assert isinstance(myskt, zmq.Socket), myskt
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert myskt.type == zmq.PUB
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_skt_name():
|
||||
@context('ctx')
|
||||
@socket('skt', zmq.PUB, context_name='ctx')
|
||||
def test(ctx, skt):
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert skt.type == zmq.PUB
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_skt_default_ctx():
|
||||
@socket(zmq.PUB)
|
||||
def test(skt):
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
assert skt.context is zmq.Context.instance()
|
||||
assert skt.type == zmq.PUB
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_skt_reinit():
|
||||
result = {'foo': None, 'bar': None}
|
||||
|
||||
@socket(zmq.PUB)
|
||||
def f(key, skt):
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
|
||||
result[key] = skt
|
||||
|
||||
foo_t = threading.Thread(target=f, args=('foo',))
|
||||
bar_t = threading.Thread(target=f, args=('bar',))
|
||||
|
||||
foo_t.start()
|
||||
bar_t.start()
|
||||
|
||||
foo_t.join()
|
||||
bar_t.join()
|
||||
|
||||
assert result['foo'] is not None, result
|
||||
assert result['bar'] is not None, result
|
||||
assert result['foo'] is not result['bar'], result
|
||||
|
||||
|
||||
def test_ctx_skt_reinit():
|
||||
result = {'foo': {'ctx': None, 'skt': None}, 'bar': {'ctx': None, 'skt': None}}
|
||||
|
||||
@context()
|
||||
@socket(zmq.PUB)
|
||||
def f(key, ctx, skt):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
|
||||
result[key]['ctx'] = ctx
|
||||
result[key]['skt'] = skt
|
||||
|
||||
foo_t = threading.Thread(target=f, args=('foo',))
|
||||
bar_t = threading.Thread(target=f, args=('bar',))
|
||||
|
||||
foo_t.start()
|
||||
bar_t.start()
|
||||
|
||||
foo_t.join()
|
||||
bar_t.join()
|
||||
|
||||
assert result['foo']['ctx'] is not None, result
|
||||
assert result['foo']['skt'] is not None, result
|
||||
assert result['bar']['ctx'] is not None, result
|
||||
assert result['bar']['skt'] is not None, result
|
||||
assert result['foo']['ctx'] is not result['bar']['ctx'], result
|
||||
assert result['foo']['skt'] is not result['bar']['skt'], result
|
||||
|
||||
|
||||
def test_skt_type_miss():
|
||||
@context()
|
||||
@socket('myskt')
|
||||
def f(ctx, myskt):
|
||||
pass # the socket type is missing
|
||||
|
||||
with raises(TypeError):
|
||||
f()
|
||||
|
||||
|
||||
def test_multi_skts():
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
@socket(zmq.PUSH)
|
||||
def test(pub, sub, push):
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
assert isinstance(push, zmq.Socket), push
|
||||
|
||||
assert pub.context is zmq.Context.instance()
|
||||
assert sub.context is zmq.Context.instance()
|
||||
assert push.context is zmq.Context.instance()
|
||||
|
||||
assert pub.type == zmq.PUB
|
||||
assert sub.type == zmq.SUB
|
||||
assert push.type == zmq.PUSH
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_multi_skts_single_ctx():
|
||||
@context()
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
@socket(zmq.PUSH)
|
||||
def test(ctx, pub, sub, push):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
assert isinstance(push, zmq.Socket), push
|
||||
|
||||
assert pub.context is ctx
|
||||
assert sub.context is ctx
|
||||
assert push.context is ctx
|
||||
|
||||
assert pub.type == zmq.PUB
|
||||
assert sub.type == zmq.SUB
|
||||
assert push.type == zmq.PUSH
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_multi_skts_with_name():
|
||||
@socket('foo', zmq.PUSH)
|
||||
@socket('bar', zmq.SUB)
|
||||
@socket('baz', zmq.PUB)
|
||||
def test(foo, bar, baz):
|
||||
assert isinstance(foo, zmq.Socket), foo
|
||||
assert isinstance(bar, zmq.Socket), bar
|
||||
assert isinstance(baz, zmq.Socket), baz
|
||||
|
||||
assert foo.context is zmq.Context.instance()
|
||||
assert bar.context is zmq.Context.instance()
|
||||
assert baz.context is zmq.Context.instance()
|
||||
|
||||
assert foo.type == zmq.PUSH
|
||||
assert bar.type == zmq.SUB
|
||||
assert baz.type == zmq.PUB
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_func_return():
|
||||
@context()
|
||||
def f(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
return 'something'
|
||||
|
||||
assert f() == 'something'
|
||||
|
||||
|
||||
def test_skt_multi_thread():
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
@socket(zmq.PUSH)
|
||||
def f(pub, sub, push):
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
assert isinstance(push, zmq.Socket), push
|
||||
|
||||
assert pub.context is zmq.Context.instance()
|
||||
assert sub.context is zmq.Context.instance()
|
||||
assert push.context is zmq.Context.instance()
|
||||
|
||||
assert pub.type == zmq.PUB
|
||||
assert sub.type == zmq.SUB
|
||||
assert push.type == zmq.PUSH
|
||||
|
||||
assert len(set(map(id, [pub, sub, push]))) == 3
|
||||
|
||||
threads = [threading.Thread(target=f) for i in range(8)]
|
||||
[t.start() for t in threads]
|
||||
[t.join() for t in threads]
|
||||
|
||||
|
||||
class TestMethodDecorators(BaseZMQTestCase):
|
||||
@context()
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
def multi_skts_method(self, ctx, pub, sub, foo='bar'):
|
||||
assert isinstance(self, TestMethodDecorators), self
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
assert foo == 'bar'
|
||||
|
||||
assert pub.context is ctx
|
||||
assert sub.context is ctx
|
||||
|
||||
assert pub.type == zmq.PUB
|
||||
assert sub.type == zmq.SUB
|
||||
|
||||
def test_multi_skts_method(self):
|
||||
self.multi_skts_method()
|
||||
|
||||
def test_multi_skts_method_other_args(self):
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
def f(foo, pub, sub, bar=None):
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
|
||||
assert foo == 'mock'
|
||||
assert bar == 'fake'
|
||||
|
||||
assert pub.context is zmq.Context.instance()
|
||||
assert sub.context is zmq.Context.instance()
|
||||
|
||||
assert pub.type == zmq.PUB
|
||||
assert sub.type == zmq.SUB
|
||||
|
||||
f('mock', bar='fake')
|
||||
168
asm/venv/lib/python3.11/site-packages/zmq/tests/test_device.py
Normal file
168
asm/venv/lib/python3.11/site-packages/zmq/tests/test_device.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import time
|
||||
|
||||
import zmq
|
||||
from zmq import devices
|
||||
from zmq.tests import PYPY, BaseZMQTestCase, GreenTest, SkipTest, have_gevent
|
||||
|
||||
if PYPY:
|
||||
# cleanup of shared Context doesn't work on PyPy
|
||||
devices.Device.context_factory = zmq.Context
|
||||
|
||||
|
||||
class TestDevice(BaseZMQTestCase):
|
||||
def test_device_types(self):
|
||||
for devtype in (zmq.STREAMER, zmq.FORWARDER, zmq.QUEUE):
|
||||
dev = devices.Device(devtype, zmq.PAIR, zmq.PAIR)
|
||||
assert dev.device_type == devtype
|
||||
del dev
|
||||
|
||||
def test_device_attributes(self):
|
||||
dev = devices.Device(zmq.QUEUE, zmq.SUB, zmq.PUB)
|
||||
assert dev.in_type == zmq.SUB
|
||||
assert dev.out_type == zmq.PUB
|
||||
assert dev.device_type == zmq.QUEUE
|
||||
assert dev.daemon == True
|
||||
del dev
|
||||
|
||||
def test_single_socket_forwarder_connect(self):
|
||||
if zmq.zmq_version() in ('4.1.1', '4.0.6'):
|
||||
raise SkipTest("libzmq-%s broke single-socket devices" % zmq.zmq_version())
|
||||
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
|
||||
req = self.context.socket(zmq.REQ)
|
||||
port = req.bind_to_random_port('tcp://127.0.0.1')
|
||||
dev.connect_in('tcp://127.0.0.1:%i' % port)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello'
|
||||
req.send(msg)
|
||||
assert msg == self.recv(req)
|
||||
del dev
|
||||
req.close()
|
||||
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
|
||||
req = self.context.socket(zmq.REQ)
|
||||
port = req.bind_to_random_port('tcp://127.0.0.1')
|
||||
dev.connect_out('tcp://127.0.0.1:%i' % port)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello again'
|
||||
req.send(msg)
|
||||
assert msg == self.recv(req)
|
||||
del dev
|
||||
req.close()
|
||||
|
||||
def test_single_socket_forwarder_bind(self):
|
||||
if zmq.zmq_version() in ('4.1.1', '4.0.6'):
|
||||
raise SkipTest("libzmq-%s broke single-socket devices" % zmq.zmq_version())
|
||||
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
|
||||
port = dev.bind_in_to_random_port('tcp://127.0.0.1')
|
||||
req = self.context.socket(zmq.REQ)
|
||||
req.connect('tcp://127.0.0.1:%i' % port)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello'
|
||||
req.send(msg)
|
||||
assert msg == self.recv(req)
|
||||
del dev
|
||||
req.close()
|
||||
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
|
||||
port = dev.bind_in_to_random_port('tcp://127.0.0.1')
|
||||
req = self.context.socket(zmq.REQ)
|
||||
req.connect('tcp://127.0.0.1:%i' % port)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello again'
|
||||
req.send(msg)
|
||||
assert msg == self.recv(req)
|
||||
del dev
|
||||
req.close()
|
||||
|
||||
def test_device_bind_to_random_with_args(self):
|
||||
dev = devices.ThreadDevice(zmq.PULL, zmq.PUSH, -1)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
ports = []
|
||||
min, max = 5000, 5050
|
||||
ports.extend(
|
||||
[
|
||||
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
|
||||
]
|
||||
)
|
||||
for port in ports:
|
||||
if port < min or port > max:
|
||||
self.fail('Unexpected port number: %i' % port)
|
||||
|
||||
def test_device_bind_to_random_binderror(self):
|
||||
dev = devices.ThreadDevice(zmq.PULL, zmq.PUSH, -1)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
try:
|
||||
for i in range(11):
|
||||
dev.bind_in_to_random_port(iface, min_port=10000, max_port=10010)
|
||||
except zmq.ZMQBindError as e:
|
||||
return
|
||||
else:
|
||||
self.fail('Should have failed')
|
||||
|
||||
def test_proxy(self):
|
||||
if zmq.zmq_version_info() < (3, 2):
|
||||
raise SkipTest("Proxies only in libzmq >= 3")
|
||||
dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = dev.bind_in_to_random_port(iface)
|
||||
port2 = dev.bind_out_to_random_port(iface)
|
||||
port3 = dev.bind_mon_to_random_port(iface)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello'
|
||||
push = self.context.socket(zmq.PUSH)
|
||||
push.connect("%s:%i" % (iface, port))
|
||||
pull = self.context.socket(zmq.PULL)
|
||||
pull.connect("%s:%i" % (iface, port2))
|
||||
mon = self.context.socket(zmq.PULL)
|
||||
mon.connect("%s:%i" % (iface, port3))
|
||||
push.send(msg)
|
||||
self.sockets.extend([push, pull, mon])
|
||||
assert msg == self.recv(pull)
|
||||
assert msg == self.recv(mon)
|
||||
|
||||
def test_proxy_bind_to_random_with_args(self):
|
||||
if zmq.zmq_version_info() < (3, 2):
|
||||
raise SkipTest("Proxies only in libzmq >= 3")
|
||||
dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
ports = []
|
||||
min, max = 5000, 5050
|
||||
ports.extend(
|
||||
[
|
||||
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_mon_to_random_port(iface, min_port=min, max_port=max),
|
||||
]
|
||||
)
|
||||
for port in ports:
|
||||
if port < min or port > max:
|
||||
self.fail('Unexpected port number: %i' % port)
|
||||
|
||||
|
||||
if have_gevent:
|
||||
import gevent
|
||||
|
||||
import zmq.green
|
||||
|
||||
class TestDeviceGreen(GreenTest, BaseZMQTestCase):
|
||||
def test_green_device(self):
|
||||
rep = self.context.socket(zmq.REP)
|
||||
req = self.context.socket(zmq.REQ)
|
||||
self.sockets.extend([req, rep])
|
||||
port = rep.bind_to_random_port('tcp://127.0.0.1')
|
||||
g = gevent.spawn(zmq.green.device, zmq.QUEUE, rep, rep)
|
||||
req.connect('tcp://127.0.0.1:%i' % port)
|
||||
req.send(b'hi')
|
||||
timeout = gevent.Timeout(3)
|
||||
timeout.start()
|
||||
receiver = gevent.spawn(req.recv)
|
||||
assert receiver.get(2) == b'hi'
|
||||
timeout.cancel()
|
||||
g.kill(block=True)
|
||||
@@ -0,0 +1,47 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
|
||||
|
||||
class TestDraftSockets(BaseZMQTestCase):
|
||||
def setUp(self):
|
||||
if not zmq.DRAFT_API:
|
||||
pytest.skip("draft api unavailable")
|
||||
super().setUp()
|
||||
|
||||
def test_client_server(self):
|
||||
client, server = self.create_bound_pair(zmq.CLIENT, zmq.SERVER)
|
||||
client.send(b'request')
|
||||
msg = self.recv(server, copy=False)
|
||||
assert msg.routing_id is not None
|
||||
server.send(b'reply', routing_id=msg.routing_id)
|
||||
reply = self.recv(client)
|
||||
assert reply == b'reply'
|
||||
|
||||
def test_radio_dish(self):
|
||||
dish, radio = self.create_bound_pair(zmq.DISH, zmq.RADIO)
|
||||
dish.rcvtimeo = 250
|
||||
group = 'mygroup'
|
||||
dish.join(group)
|
||||
received_count = 0
|
||||
received = set()
|
||||
sent = set()
|
||||
for i in range(10):
|
||||
msg = str(i).encode('ascii')
|
||||
sent.add(msg)
|
||||
radio.send(msg, group=group)
|
||||
try:
|
||||
recvd = dish.recv()
|
||||
except zmq.Again:
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
received.add(recvd)
|
||||
received_count += 1
|
||||
# assert that we got *something*
|
||||
assert len(received.intersection(sent)) >= 5
|
||||
@@ -0,0 +1,37 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from threading import Thread
|
||||
|
||||
import zmq
|
||||
from zmq import Again, ContextTerminated, ZMQError, strerror
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
|
||||
|
||||
class TestZMQError(BaseZMQTestCase):
|
||||
def test_strerror(self):
|
||||
"""test that strerror gets the right type."""
|
||||
for i in range(10):
|
||||
e = strerror(i)
|
||||
assert isinstance(e, str)
|
||||
|
||||
def test_zmqerror(self):
|
||||
for errno in range(10):
|
||||
e = ZMQError(errno)
|
||||
assert e.errno == errno
|
||||
assert str(e) == strerror(errno)
|
||||
|
||||
def test_again(self):
|
||||
s = self.context.socket(zmq.REP)
|
||||
self.assertRaises(Again, s.recv, zmq.NOBLOCK)
|
||||
self.assertRaisesErrno(zmq.EAGAIN, s.recv, zmq.NOBLOCK)
|
||||
s.close()
|
||||
|
||||
def atest_ctxterm(self):
|
||||
s = self.context.socket(zmq.REP)
|
||||
t = Thread(target=self.context.term)
|
||||
t.start()
|
||||
self.assertRaises(ContextTerminated, s.recv, zmq.NOBLOCK)
|
||||
self.assertRaisesErrno(zmq.TERM, s.recv, zmq.NOBLOCK)
|
||||
s.close()
|
||||
t.join()
|
||||
26
asm/venv/lib/python3.11/site-packages/zmq/tests/test_etc.py
Normal file
26
asm/venv/lib/python3.11/site-packages/zmq/tests/test_etc.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# Copyright (c) PyZMQ Developers.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
|
||||
only_bundled = mark.skipif(not hasattr(zmq, '_libzmq'), reason="bundled libzmq")
|
||||
|
||||
|
||||
@mark.skipif('zmq.zmq_version_info() < (4, 1)')
|
||||
def test_has():
|
||||
assert not zmq.has('something weird')
|
||||
|
||||
|
||||
@only_bundled
|
||||
def test_has_curve():
|
||||
"""bundled libzmq has curve support"""
|
||||
assert zmq.has('curve')
|
||||
|
||||
|
||||
@only_bundled
|
||||
def test_has_ipc():
|
||||
"""bundled libzmq has ipc support"""
|
||||
assert zmq.has('ipc')
|
||||
34
asm/venv/lib/python3.11/site-packages/zmq/tests/test_ext.py
Normal file
34
asm/venv/lib/python3.11/site-packages/zmq/tests/test_ext.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""tests for extending pyzmq"""
|
||||
|
||||
import zmq
|
||||
|
||||
|
||||
class CustomSocket(zmq.Socket):
|
||||
custom_attr: int
|
||||
|
||||
def __init__(self, context, socket_type, custom_attr: int = 0):
|
||||
super().__init__(context, socket_type)
|
||||
self.custom_attr = custom_attr
|
||||
|
||||
|
||||
class CustomContext(zmq.Context):
|
||||
extra_arg: str
|
||||
_socket_class = CustomSocket
|
||||
|
||||
def __init__(self, extra_arg: str = 'x'):
|
||||
super().__init__()
|
||||
self.extra_arg = extra_arg
|
||||
|
||||
|
||||
def test_custom_context():
|
||||
ctx = CustomContext('s')
|
||||
assert isinstance(ctx, CustomContext)
|
||||
|
||||
assert ctx.extra_arg == 's'
|
||||
s = ctx.socket(zmq.PUSH, custom_attr=10)
|
||||
assert isinstance(s, CustomSocket)
|
||||
assert s.custom_attr == 10
|
||||
assert s.context is ctx
|
||||
assert s.type == zmq.PUSH
|
||||
s.close()
|
||||
ctx.term()
|
||||
354
asm/venv/lib/python3.11/site-packages/zmq/tests/test_future.py
Normal file
354
asm/venv/lib/python3.11/site-packages/zmq/tests/test_future.py
Normal file
@@ -0,0 +1,354 @@
|
||||
# Copyright (c) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
gen = pytest.importorskip('tornado.gen')
|
||||
|
||||
from tornado.ioloop import IOLoop
|
||||
|
||||
import zmq
|
||||
from zmq.eventloop import future
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
|
||||
|
||||
class TestFutureSocket(BaseZMQTestCase):
|
||||
Context = future.Context
|
||||
|
||||
def setUp(self):
|
||||
self.loop = IOLoop(make_current=False)
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
if self.loop:
|
||||
self.loop.close(all_fds=True)
|
||||
|
||||
def test_socket_class(self):
|
||||
s = self.context.socket(zmq.PUSH)
|
||||
assert isinstance(s, future.Socket)
|
||||
s.close()
|
||||
|
||||
def test_instance_subclass_first(self):
|
||||
actx = self.Context.instance()
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.term()
|
||||
actx.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(actx) is self.Context
|
||||
|
||||
def test_instance_subclass_second(self):
|
||||
ctx = zmq.Context.instance()
|
||||
actx = self.Context.instance()
|
||||
ctx.term()
|
||||
actx.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(actx) is self.Context
|
||||
|
||||
def test_recv_multipart(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_multipart()
|
||||
assert not f.done()
|
||||
await a.send(b"hi")
|
||||
recvd = await f
|
||||
assert recvd == [b'hi']
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f1 = b.recv()
|
||||
f2 = b.recv()
|
||||
assert not f1.done()
|
||||
assert not f2.done()
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
recvd = await f2
|
||||
assert f1.done()
|
||||
assert f1.result() == b'hi'
|
||||
assert recvd == b'there'
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_cancel(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f1 = b.recv()
|
||||
f2 = b.recv_multipart()
|
||||
assert f1.cancel()
|
||||
assert f1.done()
|
||||
assert not f2.done()
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
recvd = await f2
|
||||
assert f1.cancelled()
|
||||
assert f2.done()
|
||||
assert recvd == [b'hi', b'there']
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
@pytest.mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
|
||||
def test_recv_timeout(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
b.rcvtimeo = 100
|
||||
f1 = b.recv()
|
||||
b.rcvtimeo = 1000
|
||||
f2 = b.recv_multipart()
|
||||
with pytest.raises(zmq.Again):
|
||||
await f1
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
recvd = await f2
|
||||
assert f2.done()
|
||||
assert recvd == [b'hi', b'there']
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
@pytest.mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
|
||||
def test_send_timeout(self):
|
||||
async def test():
|
||||
s = self.socket(zmq.PUSH)
|
||||
s.sndtimeo = 100
|
||||
with pytest.raises(zmq.Again):
|
||||
await s.send(b"not going anywhere")
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_send_noblock(self):
|
||||
async def test():
|
||||
s = self.socket(zmq.PUSH)
|
||||
with pytest.raises(zmq.Again):
|
||||
await s.send(b"not going anywhere", flags=zmq.NOBLOCK)
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_send_multipart_noblock(self):
|
||||
async def test():
|
||||
s = self.socket(zmq.PUSH)
|
||||
with pytest.raises(zmq.Again):
|
||||
await s.send_multipart([b"not going anywhere"], flags=zmq.NOBLOCK)
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_string(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_string()
|
||||
assert not f.done()
|
||||
msg = 'πøøπ'
|
||||
await a.send_string(msg)
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
assert f.result() == msg
|
||||
assert recvd == msg
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_json(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_json()
|
||||
assert not f.done()
|
||||
obj = dict(a=5)
|
||||
await a.send_json(obj)
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
assert f.result() == obj
|
||||
assert recvd == obj
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_json_cancelled(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_json()
|
||||
assert not f.done()
|
||||
f.cancel()
|
||||
# cycle eventloop to allow cancel events to fire
|
||||
await gen.sleep(0)
|
||||
obj = dict(a=5)
|
||||
await a.send_json(obj)
|
||||
with pytest.raises(future.CancelledError):
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
# give it a chance to incorrectly consume the event
|
||||
events = await b.poll(timeout=5)
|
||||
assert events
|
||||
await gen.sleep(0)
|
||||
# make sure cancelled recv didn't eat up event
|
||||
recvd = await gen.with_timeout(timedelta(seconds=5), b.recv_json())
|
||||
assert recvd == obj
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_pyobj(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_pyobj()
|
||||
assert not f.done()
|
||||
obj = dict(a=5)
|
||||
await a.send_pyobj(obj)
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
assert f.result() == obj
|
||||
assert recvd == obj
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_custom_serialize(self):
|
||||
def serialize(msg):
|
||||
frames = []
|
||||
frames.extend(msg.get('identities', []))
|
||||
content = json.dumps(msg['content']).encode('utf8')
|
||||
frames.append(content)
|
||||
return frames
|
||||
|
||||
def deserialize(frames):
|
||||
identities = frames[:-1]
|
||||
content = json.loads(frames[-1].decode('utf8'))
|
||||
return {
|
||||
'identities': identities,
|
||||
'content': content,
|
||||
}
|
||||
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
|
||||
msg = {
|
||||
'content': {
|
||||
'a': 5,
|
||||
'b': 'bee',
|
||||
}
|
||||
}
|
||||
await a.send_serialized(msg, serialize)
|
||||
recvd = await b.recv_serialized(deserialize)
|
||||
assert recvd['content'] == msg['content']
|
||||
assert recvd['identities']
|
||||
# bounce back, tests identities
|
||||
await b.send_serialized(recvd, serialize)
|
||||
r2 = await a.recv_serialized(deserialize)
|
||||
assert r2['content'] == msg['content']
|
||||
assert not r2['identities']
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_custom_serialize_error(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
|
||||
msg = {
|
||||
'content': {
|
||||
'a': 5,
|
||||
'b': 'bee',
|
||||
}
|
||||
}
|
||||
with pytest.raises(TypeError):
|
||||
await a.send_serialized(json, json.dumps)
|
||||
|
||||
await a.send(b"not json")
|
||||
with pytest.raises(TypeError):
|
||||
await b.recv_serialized(json.loads)
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_poll(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.poll(timeout=0)
|
||||
assert f.done()
|
||||
assert f.result() == 0
|
||||
|
||||
f = b.poll(timeout=1)
|
||||
assert not f.done()
|
||||
evt = await f
|
||||
assert evt == 0
|
||||
|
||||
f = b.poll(timeout=1000)
|
||||
assert not f.done()
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
evt = await f
|
||||
assert evt == zmq.POLLIN
|
||||
recvd = await b.recv_multipart()
|
||||
assert recvd == [b'hi', b'there']
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith('win'), reason='Windows unsupported socket type'
|
||||
)
|
||||
def test_poll_base_socket(self):
|
||||
async def test():
|
||||
ctx = zmq.Context()
|
||||
url = 'inproc://test'
|
||||
a = ctx.socket(zmq.PUSH)
|
||||
b = ctx.socket(zmq.PULL)
|
||||
self.sockets.extend([a, b])
|
||||
a.bind(url)
|
||||
b.connect(url)
|
||||
|
||||
poller = future.Poller()
|
||||
poller.register(b, zmq.POLLIN)
|
||||
|
||||
f = poller.poll(timeout=1000)
|
||||
assert not f.done()
|
||||
a.send_multipart([b'hi', b'there'])
|
||||
evt = await f
|
||||
assert evt == [(b, zmq.POLLIN)]
|
||||
recvd = b.recv_multipart()
|
||||
assert recvd == [b'hi', b'there']
|
||||
a.close()
|
||||
b.close()
|
||||
ctx.term()
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_close_all_fds(self):
|
||||
s = self.socket(zmq.PUB)
|
||||
|
||||
async def attach():
|
||||
s._get_loop()
|
||||
|
||||
self.loop.run_sync(attach)
|
||||
self.loop.close(all_fds=True)
|
||||
self.loop = None # avoid second close later
|
||||
assert s.closed
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith('win'),
|
||||
reason='Windows does not support polling on files',
|
||||
)
|
||||
def test_poll_raw(self):
|
||||
async def test():
|
||||
p = future.Poller()
|
||||
# make a pipe
|
||||
r, w = os.pipe()
|
||||
r = os.fdopen(r, 'rb')
|
||||
w = os.fdopen(w, 'wb')
|
||||
|
||||
# POLLOUT
|
||||
p.register(r, zmq.POLLIN)
|
||||
p.register(w, zmq.POLLOUT)
|
||||
evts = await p.poll(timeout=1)
|
||||
evts = dict(evts)
|
||||
assert r.fileno() not in evts
|
||||
assert w.fileno() in evts
|
||||
assert evts[w.fileno()] == zmq.POLLOUT
|
||||
|
||||
# POLLIN
|
||||
p.unregister(w)
|
||||
w.write(b'x')
|
||||
w.flush()
|
||||
evts = await p.poll(timeout=1000)
|
||||
evts = dict(evts)
|
||||
assert r.fileno() in evts
|
||||
assert evts[r.fileno()] == zmq.POLLIN
|
||||
assert r.read(1) == b'x'
|
||||
r.close()
|
||||
w.close()
|
||||
|
||||
self.loop.run_sync(test)
|
||||
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Test Imports - the quickest test to ensure that we haven't
|
||||
introduced version-incompatible syntax errors.
|
||||
"""
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_toplevel():
|
||||
"""test toplevel import"""
|
||||
import zmq
|
||||
|
||||
|
||||
def test_core():
|
||||
"""test core imports"""
|
||||
from zmq import (
|
||||
Context,
|
||||
Frame,
|
||||
Poller,
|
||||
Socket,
|
||||
constants,
|
||||
device,
|
||||
proxy,
|
||||
pyzmq_version,
|
||||
pyzmq_version_info,
|
||||
zmq_version,
|
||||
zmq_version_info,
|
||||
)
|
||||
|
||||
|
||||
def test_devices():
|
||||
"""test device imports"""
|
||||
import zmq.devices
|
||||
from zmq.devices import basedevice, monitoredqueue, monitoredqueuedevice
|
||||
|
||||
|
||||
def test_log():
|
||||
"""test log imports"""
|
||||
import zmq.log
|
||||
from zmq.log import handlers
|
||||
|
||||
|
||||
def test_eventloop():
|
||||
"""test eventloop imports"""
|
||||
pytest.importorskip("tornado")
|
||||
import zmq.eventloop
|
||||
from zmq.eventloop import ioloop, zmqstream
|
||||
|
||||
|
||||
def test_utils():
|
||||
"""test util imports"""
|
||||
import zmq.utils
|
||||
from zmq.utils import jsonapi, strtypes
|
||||
|
||||
|
||||
def test_ssh():
|
||||
"""test ssh imports"""
|
||||
from zmq.ssh import tunnel
|
||||
|
||||
|
||||
def test_decorators():
|
||||
"""test decorators imports"""
|
||||
from zmq.decorators import context, socket
|
||||
|
||||
|
||||
def test_zmq_all():
|
||||
import zmq
|
||||
|
||||
for name in zmq.__all__:
|
||||
assert hasattr(zmq, name)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pkgname", ["zmq", "zmq.green"])
|
||||
@pytest.mark.parametrize(
|
||||
"attr",
|
||||
[
|
||||
"RCVTIMEO",
|
||||
"PUSH",
|
||||
"zmq_version_info",
|
||||
"SocketOption",
|
||||
"device",
|
||||
"Socket",
|
||||
"Context",
|
||||
],
|
||||
)
|
||||
def test_all_exports(pkgname, attr):
|
||||
import zmq
|
||||
|
||||
subpkg = pytest.importorskip(pkgname)
|
||||
for name in zmq.__all__:
|
||||
assert hasattr(subpkg, name)
|
||||
|
||||
assert attr in subpkg.__all__
|
||||
if attr not in ("Socket", "Context", "device"):
|
||||
assert getattr(subpkg, attr) is getattr(zmq, attr)
|
||||
@@ -0,0 +1,33 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
import zmq
|
||||
|
||||
|
||||
class TestIncludes(TestCase):
|
||||
def test_get_includes(self):
|
||||
from os.path import basename
|
||||
|
||||
includes = zmq.get_includes()
|
||||
assert isinstance(includes, list)
|
||||
assert len(includes) >= 2
|
||||
parent = includes[0]
|
||||
assert isinstance(parent, str)
|
||||
utilsdir = includes[1]
|
||||
assert isinstance(utilsdir, str)
|
||||
utils = basename(utilsdir)
|
||||
assert utils == "utils"
|
||||
|
||||
def test_get_library_dirs(self):
|
||||
from os.path import basename
|
||||
|
||||
libdirs = zmq.get_library_dirs()
|
||||
assert isinstance(libdirs, list)
|
||||
assert len(libdirs) == 1
|
||||
parent = libdirs[0]
|
||||
assert isinstance(parent, str)
|
||||
libdir = basename(parent)
|
||||
assert libdir == "zmq"
|
||||
@@ -0,0 +1,33 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
import tornado.ioloop
|
||||
except ImportError:
|
||||
_tornado = False
|
||||
else:
|
||||
_tornado = True
|
||||
|
||||
|
||||
def setup():
|
||||
if not _tornado:
|
||||
pytest.skip("requires tornado")
|
||||
|
||||
|
||||
def test_ioloop():
|
||||
# may have been imported before,
|
||||
# can't capture the warning
|
||||
from zmq.eventloop import ioloop
|
||||
|
||||
assert ioloop.IOLoop is tornado.ioloop.IOLoop
|
||||
assert ioloop.ZMQIOLoop is ioloop.IOLoop
|
||||
|
||||
|
||||
def test_ioloop_install():
|
||||
from zmq.eventloop import ioloop
|
||||
|
||||
with pytest.warns(DeprecationWarning):
|
||||
ioloop.install()
|
||||
193
asm/venv/lib/python3.11/site-packages/zmq/tests/test_log.py
Normal file
193
asm/venv/lib/python3.11/site-packages/zmq/tests/test_log.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import zmq
|
||||
from zmq.log import handlers
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
|
||||
|
||||
class TestPubLog(BaseZMQTestCase):
|
||||
iface = 'inproc://zmqlog'
|
||||
topic = 'zmq'
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
# print dir(self)
|
||||
logger = logging.getLogger('zmqtest')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
return logger
|
||||
|
||||
def connect_handler(self, topic=None):
|
||||
topic = self.topic if topic is None else topic
|
||||
logger = self.logger
|
||||
pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
handler = handlers.PUBHandler(pub)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
handler.root_topic = topic
|
||||
logger.addHandler(handler)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, topic.encode())
|
||||
time.sleep(0.1)
|
||||
return logger, handler, sub
|
||||
|
||||
def test_init_iface(self):
|
||||
logger = self.logger
|
||||
ctx = self.context
|
||||
handler = handlers.PUBHandler(self.iface)
|
||||
assert not handler.ctx is ctx
|
||||
self.sockets.append(handler.socket)
|
||||
# handler.ctx.term()
|
||||
handler = handlers.PUBHandler(self.iface, self.context)
|
||||
self.sockets.append(handler.socket)
|
||||
assert handler.ctx is ctx
|
||||
handler.setLevel(logging.DEBUG)
|
||||
handler.root_topic = self.topic
|
||||
logger.addHandler(handler)
|
||||
sub = ctx.socket(zmq.SUB)
|
||||
self.sockets.append(sub)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, self.topic.encode())
|
||||
sub.connect(self.iface)
|
||||
import time
|
||||
|
||||
time.sleep(0.25)
|
||||
msg1 = 'message'
|
||||
logger.info(msg1)
|
||||
|
||||
(topic, msg2) = sub.recv_multipart()
|
||||
assert topic == b'zmq.INFO'
|
||||
assert msg2 == (msg1 + "\n").encode("utf8")
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_init_socket(self):
|
||||
pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
logger = self.logger
|
||||
handler = handlers.PUBHandler(pub)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
handler.root_topic = self.topic
|
||||
logger.addHandler(handler)
|
||||
|
||||
assert handler.socket is pub
|
||||
assert handler.ctx is pub.context
|
||||
assert handler.ctx is self.context
|
||||
sub.setsockopt(zmq.SUBSCRIBE, self.topic.encode())
|
||||
import time
|
||||
|
||||
time.sleep(0.1)
|
||||
msg1 = 'message'
|
||||
logger.info(msg1)
|
||||
|
||||
(topic, msg2) = sub.recv_multipart()
|
||||
assert topic == b'zmq.INFO'
|
||||
assert msg2 == (msg1 + "\n").encode("utf8")
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_root_topic(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
handler.socket.bind(self.iface)
|
||||
sub2 = sub.context.socket(zmq.SUB)
|
||||
self.sockets.append(sub2)
|
||||
sub2.connect(self.iface)
|
||||
sub2.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
handler.root_topic = b'twoonly'
|
||||
msg1 = 'ignored'
|
||||
logger.info(msg1)
|
||||
self.assertRaisesErrno(zmq.EAGAIN, sub.recv, zmq.NOBLOCK)
|
||||
topic, msg2 = sub2.recv_multipart()
|
||||
assert topic == b'twoonly.INFO'
|
||||
assert msg2 == (msg1 + '\n').encode()
|
||||
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_blank_root_topic(self):
|
||||
logger, handler, sub_everything = self.connect_handler()
|
||||
sub_everything.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
handler.socket.bind(self.iface)
|
||||
sub_only_info = sub_everything.context.socket(zmq.SUB)
|
||||
self.sockets.append(sub_only_info)
|
||||
sub_only_info.connect(self.iface)
|
||||
sub_only_info.setsockopt(zmq.SUBSCRIBE, b'INFO')
|
||||
handler.setRootTopic(b'')
|
||||
msg_debug = 'debug_message'
|
||||
logger.debug(msg_debug)
|
||||
self.assertRaisesErrno(zmq.EAGAIN, sub_only_info.recv, zmq.NOBLOCK)
|
||||
topic, msg_debug_response = sub_everything.recv_multipart()
|
||||
assert topic == b'DEBUG'
|
||||
msg_info = 'info_message'
|
||||
logger.info(msg_info)
|
||||
topic, msg_info_response_everything = sub_everything.recv_multipart()
|
||||
assert topic == b'INFO'
|
||||
topic, msg_info_response_onlyinfo = sub_only_info.recv_multipart()
|
||||
assert topic == b'INFO'
|
||||
assert msg_info_response_everything == msg_info_response_onlyinfo
|
||||
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_unicode_message(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
base_topic = (self.topic + '.INFO').encode()
|
||||
for msg, expected in [
|
||||
('hello', [base_topic, b'hello\n']),
|
||||
('héllo', [base_topic, 'héllo\n'.encode()]),
|
||||
('tøpic::héllo', [base_topic + '.tøpic'.encode(), 'héllo\n'.encode()]),
|
||||
]:
|
||||
logger.info(msg)
|
||||
received = sub.recv_multipart()
|
||||
assert received == expected
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_set_info_formatter_via_property(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
handler.formatters[logging.INFO] = logging.Formatter("%(message)s UNITTEST\n")
|
||||
handler.socket.bind(self.iface)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, handler.root_topic.encode())
|
||||
logger.info('info message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
assert msg == b'info message UNITTEST\n'
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_custom_global_formatter(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
formatter = logging.Formatter("UNITTEST %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
handler.socket.bind(self.iface)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, handler.root_topic.encode())
|
||||
logger.info('info message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
assert msg == b'UNITTEST info message'
|
||||
logger.debug('debug message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
assert msg == b'UNITTEST debug message'
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_custom_debug_formatter(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
formatter = logging.Formatter("UNITTEST DEBUG %(message)s")
|
||||
handler.setFormatter(formatter, logging.DEBUG)
|
||||
handler.socket.bind(self.iface)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, handler.root_topic.encode())
|
||||
logger.info('info message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
assert msg == b'info message\n'
|
||||
logger.debug('debug message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
assert msg == b'UNITTEST DEBUG debug message'
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_custom_message_type(self):
|
||||
class Message:
|
||||
def __init__(self, msg: str):
|
||||
self.msg = msg
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.msg
|
||||
|
||||
logger, handler, sub = self.connect_handler()
|
||||
msg = "hello"
|
||||
logger.info(Message(msg))
|
||||
topic, received = sub.recv_multipart()
|
||||
assert topic == b'zmq.INFO'
|
||||
assert received == b'hello\n'
|
||||
logger.removeHandler(handler)
|
||||
370
asm/venv/lib/python3.11/site-packages/zmq/tests/test_message.py
Normal file
370
asm/venv/lib/python3.11/site-packages/zmq/tests/test_message.py
Normal file
@@ -0,0 +1,370 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import sys
|
||||
|
||||
try:
|
||||
from sys import getrefcount
|
||||
except ImportError:
|
||||
grc = None
|
||||
else:
|
||||
grc = getrefcount
|
||||
|
||||
import time
|
||||
|
||||
import zmq
|
||||
from zmq.tests import PYPY, BaseZMQTestCase, SkipTest, skip_pypy
|
||||
|
||||
# some useful constants:
|
||||
|
||||
x = b'x'
|
||||
|
||||
if grc:
|
||||
rc0 = grc(x)
|
||||
v = memoryview(x)
|
||||
view_rc = grc(x) - rc0
|
||||
|
||||
|
||||
def await_gc(obj, rc):
|
||||
"""wait for refcount on an object to drop to an expected value
|
||||
|
||||
Necessary because of the zero-copy gc thread,
|
||||
which can take some time to receive its DECREF message.
|
||||
"""
|
||||
# count refs for this function
|
||||
if sys.version_info < (3, 11):
|
||||
my_refs = 2
|
||||
else:
|
||||
my_refs = 1
|
||||
for i in range(50):
|
||||
# rc + 2 because of the refs in this function
|
||||
if grc(obj) <= rc + my_refs:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
|
||||
|
||||
class TestFrame(BaseZMQTestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
for i in range(3):
|
||||
gc.collect()
|
||||
|
||||
@skip_pypy
|
||||
def test_above_30(self):
|
||||
"""Message above 30 bytes are never copied by 0MQ."""
|
||||
for i in range(5, 16): # 32, 64,..., 65536
|
||||
s = (2**i) * x
|
||||
rc = grc(s)
|
||||
m = zmq.Frame(s, copy=False)
|
||||
assert grc(s) == rc + 2
|
||||
del m
|
||||
await_gc(s, rc)
|
||||
assert grc(s) == rc
|
||||
del s
|
||||
|
||||
def test_str(self):
|
||||
"""Test the str representations of the Frames."""
|
||||
for i in range(16):
|
||||
s = (2**i) * x
|
||||
m = zmq.Frame(s)
|
||||
m_str = str(m)
|
||||
m_str_b = m_str.encode()
|
||||
assert s == m_str_b
|
||||
|
||||
def test_bytes(self):
|
||||
"""Test the Frame.bytes property."""
|
||||
for i in range(1, 16):
|
||||
s = (2**i) * x
|
||||
m = zmq.Frame(s)
|
||||
b = m.bytes
|
||||
assert s == m.bytes
|
||||
if not PYPY:
|
||||
# check that it copies
|
||||
assert b is not s
|
||||
# check that it copies only once
|
||||
assert b is m.bytes
|
||||
|
||||
def test_unicode(self):
|
||||
"""Test the unicode representations of the Frames."""
|
||||
s = 'asdf'
|
||||
self.assertRaises(TypeError, zmq.Frame, s)
|
||||
for i in range(16):
|
||||
s = (2**i) * '§'
|
||||
m = zmq.Frame(s.encode('utf8'))
|
||||
assert s == m.bytes.decode('utf8')
|
||||
|
||||
def test_len(self):
|
||||
"""Test the len of the Frames."""
|
||||
for i in range(16):
|
||||
s = (2**i) * x
|
||||
m = zmq.Frame(s)
|
||||
assert len(s) == len(m)
|
||||
|
||||
@skip_pypy
|
||||
def test_lifecycle1(self):
|
||||
"""Run through a ref counting cycle with a copy."""
|
||||
for i in range(5, 16): # 32, 64,..., 65536
|
||||
s = (2**i) * x
|
||||
rc = rc_0 = grc(s)
|
||||
m = zmq.Frame(s, copy=False)
|
||||
rc += 2
|
||||
assert grc(s) == rc
|
||||
m2 = copy.copy(m)
|
||||
rc += 1
|
||||
assert grc(s) == rc
|
||||
# no increase in refcount for accessing buffer
|
||||
# which references m2 directly
|
||||
buf = m2.buffer
|
||||
assert grc(s) == rc
|
||||
|
||||
assert s == str(m).encode()
|
||||
assert s == bytes(m2)
|
||||
assert s == m.bytes
|
||||
assert s == bytes(buf)
|
||||
# assert s is str(m)
|
||||
# assert s is str(m2)
|
||||
del m2
|
||||
assert grc(s) == rc
|
||||
# buf holds direct reference to m2 which holds
|
||||
del buf
|
||||
rc -= 1
|
||||
assert grc(s) == rc
|
||||
del m
|
||||
rc -= 2
|
||||
await_gc(s, rc)
|
||||
assert grc(s) == rc
|
||||
assert rc == rc_0
|
||||
del s
|
||||
|
||||
@skip_pypy
|
||||
def test_lifecycle2(self):
|
||||
"""Run through a different ref counting cycle with a copy."""
|
||||
for i in range(5, 16): # 32, 64,..., 65536
|
||||
s = (2**i) * x
|
||||
rc = rc_0 = grc(s)
|
||||
m = zmq.Frame(s, copy=False)
|
||||
rc += 2
|
||||
assert grc(s) == rc
|
||||
m2 = copy.copy(m)
|
||||
rc += 1
|
||||
assert grc(s) == rc
|
||||
# no increase in refcount for accessing buffer
|
||||
# which references m directly
|
||||
buf = m.buffer
|
||||
assert grc(s) == rc
|
||||
assert s == str(m).encode()
|
||||
assert s == bytes(m2)
|
||||
assert s == m2.bytes
|
||||
assert s == m.bytes
|
||||
assert s == bytes(buf)
|
||||
# assert s is str(m)
|
||||
# assert s is str(m2)
|
||||
del buf
|
||||
assert grc(s) == rc
|
||||
del m
|
||||
rc -= 1
|
||||
assert grc(s) == rc
|
||||
del m2
|
||||
rc -= 2
|
||||
await_gc(s, rc)
|
||||
assert grc(s) == rc
|
||||
assert rc == rc_0
|
||||
del s
|
||||
|
||||
def test_tracker(self):
|
||||
m = zmq.Frame(b'asdf', copy=False, track=True)
|
||||
assert not m.tracker.done
|
||||
pm = zmq.MessageTracker(m)
|
||||
assert not pm.done
|
||||
del m
|
||||
for i in range(3):
|
||||
gc.collect()
|
||||
for i in range(10):
|
||||
if pm.done:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
assert pm.done
|
||||
|
||||
def test_no_tracker(self):
|
||||
m = zmq.Frame(b'asdf', track=False)
|
||||
assert m.tracker == None
|
||||
m2 = copy.copy(m)
|
||||
assert m2.tracker == None
|
||||
self.assertRaises(ValueError, zmq.MessageTracker, m)
|
||||
|
||||
def test_multi_tracker(self):
|
||||
m = zmq.Frame(b'asdf', copy=False, track=True)
|
||||
m2 = zmq.Frame(b'whoda', copy=False, track=True)
|
||||
mt = zmq.MessageTracker(m, m2)
|
||||
assert not m.tracker.done
|
||||
assert not mt.done
|
||||
self.assertRaises(zmq.NotDone, mt.wait, 0.1)
|
||||
del m
|
||||
for i in range(3):
|
||||
gc.collect()
|
||||
self.assertRaises(zmq.NotDone, mt.wait, 0.1)
|
||||
assert not mt.done
|
||||
del m2
|
||||
for i in range(3):
|
||||
gc.collect()
|
||||
assert mt.wait(0.1) is None
|
||||
assert mt.done
|
||||
|
||||
def test_buffer_in(self):
|
||||
"""test using a buffer as input"""
|
||||
ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
|
||||
zmq.Frame(memoryview(ins))
|
||||
|
||||
def test_bad_buffer_in(self):
|
||||
"""test using a bad object"""
|
||||
self.assertRaises(TypeError, zmq.Frame, 5)
|
||||
self.assertRaises(TypeError, zmq.Frame, object())
|
||||
|
||||
def test_buffer_out(self):
|
||||
"""receiving buffered output"""
|
||||
ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
|
||||
m = zmq.Frame(ins)
|
||||
outb = m.buffer
|
||||
assert isinstance(outb, memoryview)
|
||||
assert outb is m.buffer
|
||||
assert m.buffer is m.buffer
|
||||
|
||||
def test_memoryview_shape(self):
|
||||
"""memoryview shape info"""
|
||||
data = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
|
||||
n = len(data)
|
||||
f = zmq.Frame(data)
|
||||
view1 = f.buffer
|
||||
assert view1.ndim == 1
|
||||
assert view1.shape == (n,)
|
||||
assert view1.tobytes() == data
|
||||
view2 = memoryview(f)
|
||||
assert view2.ndim == 1
|
||||
assert view2.shape == (n,)
|
||||
assert view2.tobytes() == data
|
||||
|
||||
def test_multisend(self):
|
||||
"""ensure that a message remains intact after multiple sends"""
|
||||
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
s = b"message"
|
||||
m = zmq.Frame(s)
|
||||
assert s == m.bytes
|
||||
|
||||
a.send(m, copy=False)
|
||||
time.sleep(0.1)
|
||||
assert s == m.bytes
|
||||
a.send(m, copy=False)
|
||||
time.sleep(0.1)
|
||||
assert s == m.bytes
|
||||
a.send(m, copy=True)
|
||||
time.sleep(0.1)
|
||||
assert s == m.bytes
|
||||
a.send(m, copy=True)
|
||||
time.sleep(0.1)
|
||||
assert s == m.bytes
|
||||
for i in range(4):
|
||||
r = b.recv()
|
||||
assert s == r
|
||||
assert s == m.bytes
|
||||
|
||||
def test_memoryview(self):
|
||||
"""test messages from memoryview"""
|
||||
s = b'carrotjuice'
|
||||
memoryview(s)
|
||||
m = zmq.Frame(s)
|
||||
buf = m.buffer
|
||||
s2 = buf.tobytes()
|
||||
assert s2 == s
|
||||
assert m.bytes == s
|
||||
|
||||
def test_noncopying_recv(self):
|
||||
"""check for clobbering message buffers"""
|
||||
null = b'\0' * 64
|
||||
sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
for i in range(32):
|
||||
# try a few times
|
||||
sb.send(null, copy=False)
|
||||
m = sa.recv(copy=False)
|
||||
mb = m.bytes
|
||||
# buf = memoryview(m)
|
||||
buf = m.buffer
|
||||
del m
|
||||
for i in range(5):
|
||||
ff = b'\xff' * (40 + i * 10)
|
||||
sb.send(ff, copy=False)
|
||||
m2 = sa.recv(copy=False)
|
||||
b = buf.tobytes()
|
||||
assert b == null
|
||||
assert mb == null
|
||||
assert m2.bytes == ff
|
||||
assert type(m2.bytes) is bytes
|
||||
|
||||
def test_noncopying_memoryview(self):
|
||||
"""test non-copying memmoryview messages"""
|
||||
null = b'\0' * 64
|
||||
sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
for i in range(32):
|
||||
# try a few times
|
||||
sb.send(memoryview(null), copy=False)
|
||||
m = sa.recv(copy=False)
|
||||
buf = memoryview(m)
|
||||
for i in range(5):
|
||||
ff = b'\xff' * (40 + i * 10)
|
||||
sb.send(memoryview(ff), copy=False)
|
||||
m2 = sa.recv(copy=False)
|
||||
buf2 = memoryview(m2)
|
||||
assert buf.tobytes() == null
|
||||
assert not buf.readonly
|
||||
assert buf2.tobytes() == ff
|
||||
assert not buf2.readonly
|
||||
assert type(buf) is memoryview
|
||||
|
||||
def test_buffer_numpy(self):
|
||||
"""test non-copying numpy array messages"""
|
||||
try:
|
||||
import numpy
|
||||
from numpy.testing import assert_array_equal
|
||||
except ImportError:
|
||||
raise SkipTest("requires numpy")
|
||||
rand = numpy.random.randint
|
||||
shapes = [rand(2, 5) for i in range(5)]
|
||||
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
dtypes = [int, float, '>i4', 'B']
|
||||
for i in range(1, len(shapes) + 1):
|
||||
shape = shapes[:i]
|
||||
for dt in dtypes:
|
||||
A = numpy.empty(shape, dtype=dt)
|
||||
a.send(A, copy=False)
|
||||
msg = b.recv(copy=False)
|
||||
|
||||
B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
|
||||
assert_array_equal(A, B)
|
||||
|
||||
A = numpy.empty(shape, dtype=[('a', int), ('b', float), ('c', 'a32')])
|
||||
A['a'] = 1024
|
||||
A['b'] = 1e9
|
||||
A['c'] = 'hello there'
|
||||
a.send(A, copy=False)
|
||||
msg = b.recv(copy=False)
|
||||
|
||||
B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
|
||||
assert_array_equal(A, B)
|
||||
|
||||
@skip_pypy
|
||||
def test_frame_more(self):
|
||||
"""test Frame.more attribute"""
|
||||
frame = zmq.Frame(b"hello")
|
||||
assert not frame.more
|
||||
sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
sa.send_multipart([b'hi', b'there'])
|
||||
frame = self.recv(sb, copy=False)
|
||||
assert frame.more
|
||||
if zmq.zmq_version_info()[0] >= 3 and not PYPY:
|
||||
assert frame.get(zmq.MORE)
|
||||
frame = self.recv(sb, copy=False)
|
||||
assert not frame.more
|
||||
if zmq.zmq_version_info()[0] >= 3 and not PYPY:
|
||||
assert not frame.get(zmq.MORE)
|
||||
@@ -0,0 +1,94 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from zmq.tests import require_zmq_4
|
||||
from zmq.utils.monitor import recv_monitor_message
|
||||
|
||||
pytestmark = require_zmq_4
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(params=["zmq", "asyncio"])
|
||||
def Context(request, event_loop):
|
||||
if request.param == "asyncio":
|
||||
return zmq.asyncio.Context
|
||||
else:
|
||||
return zmq.Context
|
||||
|
||||
|
||||
async def test_monitor(context, socket):
|
||||
"""Test monitoring interface for sockets."""
|
||||
s_rep = socket(zmq.REP)
|
||||
s_req = socket(zmq.REQ)
|
||||
s_req.bind("tcp://127.0.0.1:6666")
|
||||
# try monitoring the REP socket
|
||||
s_rep.monitor(
|
||||
"inproc://monitor.rep",
|
||||
zmq.EVENT_CONNECT_DELAYED | zmq.EVENT_CONNECTED | zmq.EVENT_MONITOR_STOPPED,
|
||||
)
|
||||
# create listening socket for monitor
|
||||
s_event = socket(zmq.PAIR)
|
||||
s_event.connect("inproc://monitor.rep")
|
||||
s_event.linger = 0
|
||||
# test receive event for connect event
|
||||
s_rep.connect("tcp://127.0.0.1:6666")
|
||||
m = recv_monitor_message(s_event)
|
||||
if isinstance(context, zmq.asyncio.Context):
|
||||
m = await m
|
||||
if m['event'] == zmq.EVENT_CONNECT_DELAYED:
|
||||
assert m['endpoint'] == b"tcp://127.0.0.1:6666"
|
||||
# test receive event for connected event
|
||||
m = recv_monitor_message(s_event)
|
||||
if isinstance(context, zmq.asyncio.Context):
|
||||
m = await m
|
||||
assert m['event'] == zmq.EVENT_CONNECTED
|
||||
assert m['endpoint'] == b"tcp://127.0.0.1:6666"
|
||||
|
||||
# test monitor can be disabled.
|
||||
s_rep.disable_monitor()
|
||||
m = recv_monitor_message(s_event)
|
||||
if isinstance(context, zmq.asyncio.Context):
|
||||
m = await m
|
||||
assert m['event'] == zmq.EVENT_MONITOR_STOPPED
|
||||
|
||||
|
||||
async def test_monitor_repeat(context, socket, sockets):
|
||||
s = socket(zmq.PULL)
|
||||
m = s.get_monitor_socket()
|
||||
sockets.append(m)
|
||||
m2 = s.get_monitor_socket()
|
||||
assert m is m2
|
||||
s.disable_monitor()
|
||||
evt = recv_monitor_message(m)
|
||||
if isinstance(context, zmq.asyncio.Context):
|
||||
evt = await evt
|
||||
assert evt['event'] == zmq.EVENT_MONITOR_STOPPED
|
||||
m.close()
|
||||
s.close()
|
||||
|
||||
|
||||
async def test_monitor_connected(context, socket, sockets):
|
||||
"""Test connected monitoring socket."""
|
||||
s_rep = socket(zmq.REP)
|
||||
s_req = socket(zmq.REQ)
|
||||
s_req.bind("tcp://127.0.0.1:6667")
|
||||
# try monitoring the REP socket
|
||||
# create listening socket for monitor
|
||||
s_event = s_rep.get_monitor_socket()
|
||||
s_event.linger = 0
|
||||
sockets.append(s_event)
|
||||
# test receive event for connect event
|
||||
s_rep.connect("tcp://127.0.0.1:6667")
|
||||
m = recv_monitor_message(s_event)
|
||||
if isinstance(context, zmq.asyncio.Context):
|
||||
m = await m
|
||||
if m['event'] == zmq.EVENT_CONNECT_DELAYED:
|
||||
assert m['endpoint'] == b"tcp://127.0.0.1:6667"
|
||||
# test receive event for connected event
|
||||
m = recv_monitor_message(s_event)
|
||||
if isinstance(context, zmq.asyncio.Context):
|
||||
m = await m
|
||||
assert m['event'] == zmq.EVENT_CONNECTED
|
||||
assert m['endpoint'] == b"tcp://127.0.0.1:6667"
|
||||
235
asm/venv/lib/python3.11/site-packages/zmq/tests/test_monqueue.py
Normal file
235
asm/venv/lib/python3.11/site-packages/zmq/tests/test_monqueue.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
import zmq
|
||||
from zmq import devices
|
||||
from zmq.tests import PYPY, BaseZMQTestCase
|
||||
|
||||
if PYPY or zmq.zmq_version_info() >= (4, 1):
|
||||
# cleanup of shared Context doesn't work on PyPy
|
||||
# there also seems to be a bug in cleanup in libzmq-4.1 (zeromq/libzmq#1052)
|
||||
devices.Device.context_factory = zmq.Context
|
||||
|
||||
|
||||
class TestMonitoredQueue(BaseZMQTestCase):
|
||||
def build_device(self, mon_sub=b"", in_prefix=b'in', out_prefix=b'out'):
|
||||
self.device = devices.ThreadMonitoredQueue(
|
||||
zmq.PAIR, zmq.PAIR, zmq.PUB, in_prefix, out_prefix
|
||||
)
|
||||
alice = self.context.socket(zmq.PAIR)
|
||||
bob = self.context.socket(zmq.PAIR)
|
||||
mon = self.context.socket(zmq.SUB)
|
||||
|
||||
aport = alice.bind_to_random_port('tcp://127.0.0.1')
|
||||
bport = bob.bind_to_random_port('tcp://127.0.0.1')
|
||||
mport = mon.bind_to_random_port('tcp://127.0.0.1')
|
||||
mon.setsockopt(zmq.SUBSCRIBE, mon_sub)
|
||||
|
||||
self.device.connect_in("tcp://127.0.0.1:%i" % aport)
|
||||
self.device.connect_out("tcp://127.0.0.1:%i" % bport)
|
||||
self.device.connect_mon("tcp://127.0.0.1:%i" % mport)
|
||||
self.device.start()
|
||||
time.sleep(0.2)
|
||||
try:
|
||||
# this is currently necessary to ensure no dropped monitor messages
|
||||
# see LIBZMQ-248 for more info
|
||||
mon.recv_multipart(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
pass
|
||||
self.sockets.extend([alice, bob, mon])
|
||||
return alice, bob, mon
|
||||
|
||||
def teardown_device(self):
|
||||
# spawn term in a background thread
|
||||
for i in range(50):
|
||||
# wait for device._context to be populated
|
||||
context = getattr(self.device, "_context", None)
|
||||
if context is not None:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
if context is not None:
|
||||
t = threading.Thread(target=self.device._context.term, daemon=True)
|
||||
t.start()
|
||||
|
||||
for socket in self.sockets:
|
||||
socket.close()
|
||||
|
||||
if context is not None:
|
||||
t.join(timeout=5)
|
||||
|
||||
self.device.join(timeout=5)
|
||||
|
||||
def test_reply(self):
|
||||
alice, bob, mon = self.build_device()
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices == bobs
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
assert alices == bobs
|
||||
self.teardown_device()
|
||||
|
||||
def test_queue(self):
|
||||
alice, bob, mon = self.build_device()
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
alices2 = b"hello again".split()
|
||||
alice.send_multipart(alices2)
|
||||
alices3 = b"hello again and again".split()
|
||||
alice.send_multipart(alices3)
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices == bobs
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices2 == bobs
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices3 == bobs
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
assert alices == bobs
|
||||
self.teardown_device()
|
||||
|
||||
def test_monitor(self):
|
||||
alice, bob, mon = self.build_device()
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
alices2 = b"hello again".split()
|
||||
alice.send_multipart(alices2)
|
||||
alices3 = b"hello again and again".split()
|
||||
alice.send_multipart(alices3)
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices == bobs
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'in'] + bobs == mons
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices2 == bobs
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices3 == bobs
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'in'] + alices2 == mons
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
assert alices == bobs
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'in'] + alices3 == mons
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'out'] + bobs == mons
|
||||
self.teardown_device()
|
||||
|
||||
def test_prefix(self):
|
||||
alice, bob, mon = self.build_device(b"", b'foo', b'bar')
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
alices2 = b"hello again".split()
|
||||
alice.send_multipart(alices2)
|
||||
alices3 = b"hello again and again".split()
|
||||
alice.send_multipart(alices3)
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices == bobs
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'foo'] + bobs == mons
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices2 == bobs
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices3 == bobs
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'foo'] + alices2 == mons
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
assert alices == bobs
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'foo'] + alices3 == mons
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'bar'] + bobs == mons
|
||||
self.teardown_device()
|
||||
|
||||
def test_monitor_subscribe(self):
|
||||
alice, bob, mon = self.build_device(b"out")
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
alices2 = b"hello again".split()
|
||||
alice.send_multipart(alices2)
|
||||
alices3 = b"hello again and again".split()
|
||||
alice.send_multipart(alices3)
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices == bobs
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices2 == bobs
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices3 == bobs
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
assert alices == bobs
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'out'] + bobs == mons
|
||||
self.teardown_device()
|
||||
|
||||
def test_router_router(self):
|
||||
"""test router-router MQ devices"""
|
||||
dev = devices.ThreadMonitoredQueue(
|
||||
zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out'
|
||||
)
|
||||
self.device = dev
|
||||
dev.setsockopt_in(zmq.LINGER, 0)
|
||||
dev.setsockopt_out(zmq.LINGER, 0)
|
||||
dev.setsockopt_mon(zmq.LINGER, 0)
|
||||
|
||||
porta = dev.bind_in_to_random_port('tcp://127.0.0.1')
|
||||
portb = dev.bind_out_to_random_port('tcp://127.0.0.1')
|
||||
a = self.context.socket(zmq.DEALER)
|
||||
a.identity = b'a'
|
||||
b = self.context.socket(zmq.DEALER)
|
||||
b.identity = b'b'
|
||||
self.sockets.extend([a, b])
|
||||
|
||||
a.connect('tcp://127.0.0.1:%i' % porta)
|
||||
b.connect('tcp://127.0.0.1:%i' % portb)
|
||||
dev.start()
|
||||
time.sleep(1)
|
||||
if zmq.zmq_version_info() >= (3, 1, 0):
|
||||
# flush erroneous poll state, due to LIBZMQ-280
|
||||
ping_msg = [b'ping', b'pong']
|
||||
for s in (a, b):
|
||||
s.send_multipart(ping_msg)
|
||||
try:
|
||||
s.recv(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
pass
|
||||
msg = [b'hello', b'there']
|
||||
a.send_multipart([b'b'] + msg)
|
||||
bmsg = self.recv_multipart(b)
|
||||
assert bmsg == [b'a'] + msg
|
||||
b.send_multipart(bmsg)
|
||||
amsg = self.recv_multipart(a)
|
||||
assert amsg == [b'b'] + msg
|
||||
self.teardown_device()
|
||||
|
||||
def test_default_mq_args(self):
|
||||
self.device = dev = devices.ThreadMonitoredQueue(
|
||||
zmq.ROUTER, zmq.DEALER, zmq.PUB
|
||||
)
|
||||
dev.setsockopt_in(zmq.LINGER, 0)
|
||||
dev.setsockopt_out(zmq.LINGER, 0)
|
||||
dev.setsockopt_mon(zmq.LINGER, 0)
|
||||
# this will raise if default args are wrong
|
||||
dev.start()
|
||||
self.teardown_device()
|
||||
|
||||
def test_mq_check_prefix(self):
|
||||
ins = self.context.socket(zmq.ROUTER)
|
||||
outs = self.context.socket(zmq.DEALER)
|
||||
mons = self.context.socket(zmq.PUB)
|
||||
self.sockets.extend([ins, outs, mons])
|
||||
|
||||
ins = 'in'
|
||||
outs = 'out'
|
||||
self.assertRaises(TypeError, devices.monitoredqueue, ins, outs, mons)
|
||||
@@ -0,0 +1,34 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
|
||||
|
||||
|
||||
class TestMultipart(BaseZMQTestCase):
|
||||
def test_router_dealer(self):
|
||||
router, dealer = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
|
||||
|
||||
msg1 = b'message1'
|
||||
dealer.send(msg1)
|
||||
self.recv(router)
|
||||
more = router.rcvmore
|
||||
assert more == True
|
||||
msg2 = self.recv(router)
|
||||
assert msg1 == msg2
|
||||
more = router.rcvmore
|
||||
assert more == False
|
||||
|
||||
def test_basic_multipart(self):
|
||||
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
msg = [b'hi', b'there', b'b']
|
||||
a.send_multipart(msg)
|
||||
recvd = b.recv_multipart()
|
||||
assert msg == recvd
|
||||
|
||||
|
||||
if have_gevent:
|
||||
|
||||
class TestMultipartGreen(GreenTest, TestMultipart):
|
||||
pass
|
||||
73
asm/venv/lib/python3.11/site-packages/zmq/tests/test_mypy.py
Normal file
73
asm/venv/lib/python3.11/site-packages/zmq/tests/test_mypy.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Test our typing with mypy
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
from subprocess import PIPE, STDOUT, Popen
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
|
||||
pytest.importorskip("mypy")
|
||||
|
||||
zmq_dir = os.path.dirname(zmq.__file__)
|
||||
|
||||
|
||||
def resolve_repo_dir(path):
|
||||
"""Resolve a dir in the repo
|
||||
|
||||
Resolved relative to zmq dir
|
||||
|
||||
fallback on CWD (e.g. test run from repo, zmq installed, not -e)
|
||||
"""
|
||||
resolved_path = os.path.join(os.path.dirname(zmq_dir), path)
|
||||
|
||||
# fallback on CWD
|
||||
if not os.path.exists(resolved_path):
|
||||
resolved_path = path
|
||||
return resolved_path
|
||||
|
||||
|
||||
examples_dir = resolve_repo_dir("examples")
|
||||
mypy_dir = resolve_repo_dir("mypy_tests")
|
||||
|
||||
|
||||
def run_mypy(*mypy_args):
|
||||
"""Run mypy for a path
|
||||
|
||||
Captures output and reports it on errors
|
||||
"""
|
||||
p = Popen(
|
||||
[sys.executable, "-m", "mypy"] + list(mypy_args), stdout=PIPE, stderr=STDOUT
|
||||
)
|
||||
o, _ = p.communicate()
|
||||
out = o.decode("utf8", "replace")
|
||||
print(out)
|
||||
assert p.returncode == 0, out
|
||||
|
||||
|
||||
if os.path.exists(examples_dir):
|
||||
examples = [
|
||||
d
|
||||
for d in os.listdir(examples_dir)
|
||||
if os.path.isdir(os.path.join(examples_dir, d))
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.path.exists(examples_dir), reason="only test from examples directory"
|
||||
)
|
||||
@pytest.mark.parametrize("example", examples)
|
||||
def test_mypy_example(example):
|
||||
example_dir = os.path.join(examples_dir, example)
|
||||
run_mypy("--disallow-untyped-calls", example_dir)
|
||||
|
||||
|
||||
if os.path.exists(mypy_dir):
|
||||
mypy_tests = [p for p in os.listdir(mypy_dir) if p.endswith(".py")]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("filename", mypy_tests)
|
||||
def test_mypy(filename):
|
||||
run_mypy("--disallow-untyped-calls", os.path.join(mypy_dir, filename))
|
||||
52
asm/venv/lib/python3.11/site-packages/zmq/tests/test_pair.py
Normal file
52
asm/venv/lib/python3.11/site-packages/zmq/tests/test_pair.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
|
||||
|
||||
x = b' '
|
||||
|
||||
|
||||
class TestPair(BaseZMQTestCase):
|
||||
def test_basic(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
msg1 = b'message1'
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
assert msg1 == msg2
|
||||
|
||||
def test_multiple(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
for i in range(10):
|
||||
msg = i * x
|
||||
s1.send(msg)
|
||||
|
||||
for i in range(10):
|
||||
msg = i * x
|
||||
s2.send(msg)
|
||||
|
||||
for i in range(10):
|
||||
msg = s1.recv()
|
||||
assert msg == i * x
|
||||
|
||||
for i in range(10):
|
||||
msg = s2.recv()
|
||||
assert msg == i * x
|
||||
|
||||
def test_json(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
o = dict(a=10, b=list(range(10)))
|
||||
self.ping_pong_json(s1, s2, o)
|
||||
|
||||
def test_pyobj(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
o = dict(a=10, b=range(10))
|
||||
self.ping_pong_pyobj(s1, s2, o)
|
||||
|
||||
|
||||
if have_gevent:
|
||||
|
||||
class TestReqRepGreen(GreenTest, TestPair):
|
||||
pass
|
||||
238
asm/venv/lib/python3.11/site-packages/zmq/tests/test_poll.py
Normal file
238
asm/venv/lib/python3.11/site-packages/zmq/tests/test_poll.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
from zmq.tests import GreenTest, PollZMQTestCase, have_gevent
|
||||
|
||||
|
||||
def wait():
|
||||
time.sleep(0.25)
|
||||
|
||||
|
||||
class TestPoll(PollZMQTestCase):
|
||||
Poller = zmq.Poller
|
||||
|
||||
def test_pair(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
# Sleep to allow sockets to connect.
|
||||
wait()
|
||||
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
|
||||
poller.register(s2, zmq.POLLIN | zmq.POLLOUT)
|
||||
# Poll result should contain both sockets
|
||||
socks = dict(poller.poll())
|
||||
# Now make sure that both are send ready.
|
||||
assert socks[s1] == zmq.POLLOUT
|
||||
assert socks[s2] == zmq.POLLOUT
|
||||
# Now do a send on both, wait and test for zmq.POLLOUT|zmq.POLLIN
|
||||
s1.send(b'msg1')
|
||||
s2.send(b'msg2')
|
||||
wait()
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s1] == zmq.POLLOUT | zmq.POLLIN
|
||||
assert socks[s2] == zmq.POLLOUT | zmq.POLLIN
|
||||
# Make sure that both are in POLLOUT after recv.
|
||||
s1.recv()
|
||||
s2.recv()
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s1] == zmq.POLLOUT
|
||||
assert socks[s2] == zmq.POLLOUT
|
||||
|
||||
poller.unregister(s1)
|
||||
poller.unregister(s2)
|
||||
|
||||
def test_reqrep(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REP, zmq.REQ)
|
||||
|
||||
# Sleep to allow sockets to connect.
|
||||
wait()
|
||||
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
|
||||
poller.register(s2, zmq.POLLIN | zmq.POLLOUT)
|
||||
|
||||
# Make sure that s1 is in state 0 and s2 is in POLLOUT
|
||||
socks = dict(poller.poll())
|
||||
assert s1 not in socks
|
||||
assert socks[s2] == zmq.POLLOUT
|
||||
|
||||
# Make sure that s2 goes immediately into state 0 after send.
|
||||
s2.send(b'msg1')
|
||||
socks = dict(poller.poll())
|
||||
assert s2 not in socks
|
||||
|
||||
# Make sure that s1 goes into POLLIN state after a time.sleep().
|
||||
time.sleep(0.5)
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s1] == zmq.POLLIN
|
||||
|
||||
# Make sure that s1 goes into POLLOUT after recv.
|
||||
s1.recv()
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s1] == zmq.POLLOUT
|
||||
|
||||
# Make sure s1 goes into state 0 after send.
|
||||
s1.send(b'msg2')
|
||||
socks = dict(poller.poll())
|
||||
assert s1 not in socks
|
||||
|
||||
# Wait and then see that s2 is in POLLIN.
|
||||
time.sleep(0.5)
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s2] == zmq.POLLIN
|
||||
|
||||
# Make sure that s2 is in POLLOUT after recv.
|
||||
s2.recv()
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s2] == zmq.POLLOUT
|
||||
|
||||
poller.unregister(s1)
|
||||
poller.unregister(s2)
|
||||
|
||||
def test_no_events(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
|
||||
poller.register(s2, 0)
|
||||
assert s1 in poller
|
||||
assert s2 not in poller
|
||||
poller.register(s1, 0)
|
||||
assert s1 not in poller
|
||||
|
||||
def test_pubsub(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
s2.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
|
||||
# Sleep to allow sockets to connect.
|
||||
wait()
|
||||
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
|
||||
poller.register(s2, zmq.POLLIN)
|
||||
|
||||
# Now make sure that both are send ready.
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s1] == zmq.POLLOUT
|
||||
assert s2 not in socks
|
||||
# Make sure that s1 stays in POLLOUT after a send.
|
||||
s1.send(b'msg1')
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s1] == zmq.POLLOUT
|
||||
|
||||
# Make sure that s2 is POLLIN after waiting.
|
||||
wait()
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s2] == zmq.POLLIN
|
||||
|
||||
# Make sure that s2 goes into 0 after recv.
|
||||
s2.recv()
|
||||
socks = dict(poller.poll())
|
||||
assert s2 not in socks
|
||||
|
||||
poller.unregister(s1)
|
||||
poller.unregister(s2)
|
||||
|
||||
@mark.skipif(sys.platform.startswith('win'), reason='Windows')
|
||||
def test_raw(self):
|
||||
r, w = os.pipe()
|
||||
r = os.fdopen(r, 'rb')
|
||||
w = os.fdopen(w, 'wb')
|
||||
p = self.Poller()
|
||||
p.register(r, zmq.POLLIN)
|
||||
socks = dict(p.poll(1))
|
||||
assert socks == {}
|
||||
w.write(b'x')
|
||||
w.flush()
|
||||
socks = dict(p.poll(1))
|
||||
assert socks == {r.fileno(): zmq.POLLIN}
|
||||
w.close()
|
||||
r.close()
|
||||
|
||||
@mark.flaky(reruns=3)
|
||||
def test_timeout(self):
|
||||
"""make sure Poller.poll timeout has the right units (milliseconds)."""
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN)
|
||||
tic = time.perf_counter()
|
||||
poller.poll(0.005)
|
||||
toc = time.perf_counter()
|
||||
toc - tic < 0.5
|
||||
tic = time.perf_counter()
|
||||
poller.poll(50)
|
||||
toc = time.perf_counter()
|
||||
assert toc - tic < 0.5
|
||||
assert toc - tic > 0.01
|
||||
tic = time.perf_counter()
|
||||
poller.poll(500)
|
||||
toc = time.perf_counter()
|
||||
assert toc - tic < 1
|
||||
assert toc - tic > 0.1
|
||||
|
||||
|
||||
class TestSelect(PollZMQTestCase):
|
||||
def test_pair(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
# Sleep to allow sockets to connect.
|
||||
wait()
|
||||
|
||||
rlist, wlist, xlist = zmq.select([s1, s2], [s1, s2], [s1, s2])
|
||||
assert s1 in wlist
|
||||
assert s2 in wlist
|
||||
assert s1 not in rlist
|
||||
assert s2 not in rlist
|
||||
|
||||
@mark.flaky(reruns=3)
|
||||
def test_timeout(self):
|
||||
"""make sure select timeout has the right units (seconds)."""
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
tic = time.perf_counter()
|
||||
r, w, x = zmq.select([s1, s2], [], [], 0.005)
|
||||
toc = time.perf_counter()
|
||||
assert toc - tic < 1
|
||||
assert toc - tic > 0.001
|
||||
tic = time.perf_counter()
|
||||
r, w, x = zmq.select([s1, s2], [], [], 0.25)
|
||||
toc = time.perf_counter()
|
||||
assert toc - tic < 1
|
||||
assert toc - tic > 0.1
|
||||
|
||||
|
||||
if have_gevent:
|
||||
import gevent
|
||||
|
||||
from zmq import green as gzmq
|
||||
|
||||
class TestPollGreen(GreenTest, TestPoll):
|
||||
Poller = gzmq.Poller
|
||||
|
||||
def test_wakeup(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
poller = self.Poller()
|
||||
poller.register(s2, zmq.POLLIN)
|
||||
|
||||
tic = time.perf_counter()
|
||||
r = gevent.spawn(lambda: poller.poll(10000))
|
||||
s = gevent.spawn(lambda: s1.send(b'msg1'))
|
||||
r.join()
|
||||
toc = time.perf_counter()
|
||||
assert toc - tic < 1
|
||||
|
||||
def test_socket_poll(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
tic = time.perf_counter()
|
||||
r = gevent.spawn(lambda: s2.poll(10000))
|
||||
s = gevent.spawn(lambda: s1.send(b'msg1'))
|
||||
r.join()
|
||||
toc = time.perf_counter()
|
||||
assert toc - tic < 1
|
||||
@@ -0,0 +1,95 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import struct
|
||||
import time
|
||||
|
||||
import zmq
|
||||
from zmq import devices
|
||||
from zmq.tests import PYPY, BaseZMQTestCase, SkipTest
|
||||
|
||||
if PYPY:
|
||||
# cleanup of shared Context doesn't work on PyPy
|
||||
devices.Device.context_factory = zmq.Context
|
||||
|
||||
|
||||
class TestProxySteerable(BaseZMQTestCase):
|
||||
def test_proxy_steerable(self):
|
||||
if zmq.zmq_version_info() < (4, 1):
|
||||
raise SkipTest("Steerable Proxies only in libzmq >= 4.1")
|
||||
dev = devices.ThreadProxySteerable(zmq.PULL, zmq.PUSH, zmq.PUSH, zmq.PAIR)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = dev.bind_in_to_random_port(iface)
|
||||
port2 = dev.bind_out_to_random_port(iface)
|
||||
port3 = dev.bind_mon_to_random_port(iface)
|
||||
port4 = dev.bind_ctrl_to_random_port(iface)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello'
|
||||
push = self.context.socket(zmq.PUSH)
|
||||
push.connect("%s:%i" % (iface, port))
|
||||
pull = self.context.socket(zmq.PULL)
|
||||
pull.connect("%s:%i" % (iface, port2))
|
||||
mon = self.context.socket(zmq.PULL)
|
||||
mon.connect("%s:%i" % (iface, port3))
|
||||
ctrl = self.context.socket(zmq.PAIR)
|
||||
ctrl.connect("%s:%i" % (iface, port4))
|
||||
push.send(msg)
|
||||
self.sockets.extend([push, pull, mon, ctrl])
|
||||
assert msg == self.recv(pull)
|
||||
assert msg == self.recv(mon)
|
||||
ctrl.send(b'TERMINATE')
|
||||
dev.join()
|
||||
|
||||
def test_proxy_steerable_bind_to_random_with_args(self):
|
||||
if zmq.zmq_version_info() < (4, 1):
|
||||
raise SkipTest("Steerable Proxies only in libzmq >= 4.1")
|
||||
dev = devices.ThreadProxySteerable(zmq.PULL, zmq.PUSH, zmq.PUSH, zmq.PAIR)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
ports = []
|
||||
min, max = 5000, 5050
|
||||
ports.extend(
|
||||
[
|
||||
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_mon_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_ctrl_to_random_port(iface, min_port=min, max_port=max),
|
||||
]
|
||||
)
|
||||
for port in ports:
|
||||
if port < min or port > max:
|
||||
self.fail('Unexpected port number: %i' % port)
|
||||
|
||||
def test_proxy_steerable_statistics(self):
|
||||
if zmq.zmq_version_info() < (4, 3):
|
||||
raise SkipTest("STATISTICS only in libzmq >= 4.3")
|
||||
dev = devices.ThreadProxySteerable(zmq.PULL, zmq.PUSH, zmq.PUSH, zmq.PAIR)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = dev.bind_in_to_random_port(iface)
|
||||
port2 = dev.bind_out_to_random_port(iface)
|
||||
port3 = dev.bind_mon_to_random_port(iface)
|
||||
port4 = dev.bind_ctrl_to_random_port(iface)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello'
|
||||
push = self.context.socket(zmq.PUSH)
|
||||
push.connect("%s:%i" % (iface, port))
|
||||
pull = self.context.socket(zmq.PULL)
|
||||
pull.connect("%s:%i" % (iface, port2))
|
||||
mon = self.context.socket(zmq.PULL)
|
||||
mon.connect("%s:%i" % (iface, port3))
|
||||
ctrl = self.context.socket(zmq.PAIR)
|
||||
ctrl.connect("%s:%i" % (iface, port4))
|
||||
push.send(msg)
|
||||
self.sockets.extend([push, pull, mon, ctrl])
|
||||
assert msg == self.recv(pull)
|
||||
assert msg == self.recv(mon)
|
||||
ctrl.send(b'STATISTICS')
|
||||
stats = self.recv_multipart(ctrl)
|
||||
stats_int = [struct.unpack("=Q", x)[0] for x in stats]
|
||||
assert 1 == stats_int[0]
|
||||
assert len(msg) == stats_int[1]
|
||||
assert 1 == stats_int[6]
|
||||
assert len(msg) == stats_int[7]
|
||||
ctrl.send(b'TERMINATE')
|
||||
dev.join()
|
||||
@@ -0,0 +1,40 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import time
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
|
||||
|
||||
|
||||
class TestPubSub(BaseZMQTestCase):
|
||||
pass
|
||||
|
||||
# We are disabling this test while an issue is being resolved.
|
||||
def test_basic(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
s2.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
time.sleep(0.1)
|
||||
msg1 = b'message'
|
||||
s1.send(msg1)
|
||||
msg2 = s2.recv() # This is blocking!
|
||||
assert msg1 == msg2
|
||||
|
||||
def test_topic(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
s2.setsockopt(zmq.SUBSCRIBE, b'x')
|
||||
time.sleep(0.1)
|
||||
msg1 = b'message'
|
||||
s1.send(msg1)
|
||||
self.assertRaisesErrno(zmq.EAGAIN, s2.recv, zmq.NOBLOCK)
|
||||
msg1 = b'xmessage'
|
||||
s1.send(msg1)
|
||||
msg2 = s2.recv()
|
||||
assert msg1 == msg2
|
||||
|
||||
|
||||
if have_gevent:
|
||||
|
||||
class TestPubSubGreen(GreenTest, TestPubSub):
|
||||
pass
|
||||
@@ -0,0 +1,61 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
|
||||
|
||||
|
||||
class TestReqRep(BaseZMQTestCase):
|
||||
def test_basic(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
|
||||
msg1 = b'message 1'
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
assert msg1 == msg2
|
||||
|
||||
def test_multiple(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
|
||||
for i in range(10):
|
||||
msg1 = i * b' '
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
assert msg1 == msg2
|
||||
|
||||
def test_bad_send_recv(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
|
||||
if zmq.zmq_version() != '2.1.8':
|
||||
# this doesn't work on 2.1.8
|
||||
for copy in (True, False):
|
||||
self.assertRaisesErrno(zmq.EFSM, s1.recv, copy=copy)
|
||||
self.assertRaisesErrno(zmq.EFSM, s2.send, b'asdf', copy=copy)
|
||||
|
||||
# I have to have this or we die on an Abort trap.
|
||||
msg1 = b'asdf'
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
assert msg1 == msg2
|
||||
|
||||
def test_json(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
o = dict(a=10, b=list(range(10)))
|
||||
self.ping_pong_json(s1, s2, o)
|
||||
|
||||
def test_pyobj(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
o = dict(a=10, b=range(10))
|
||||
self.ping_pong_pyobj(s1, s2, o)
|
||||
|
||||
def test_large_msg(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
msg1 = 10000 * b'X'
|
||||
|
||||
for i in range(10):
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
assert msg1 == msg2
|
||||
|
||||
|
||||
if have_gevent:
|
||||
|
||||
class TestReqRepGreen(GreenTest, TestReqRep):
|
||||
pass
|
||||
@@ -0,0 +1,94 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import signal
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, SkipTest
|
||||
|
||||
# Partially based on EINTRBaseTest from CPython 3.5 eintr_tester
|
||||
|
||||
|
||||
class TestEINTRSysCall(BaseZMQTestCase):
|
||||
"""Base class for EINTR tests."""
|
||||
|
||||
# delay for initial signal delivery
|
||||
signal_delay = 0.1
|
||||
# timeout for tests. Must be > signal_delay
|
||||
timeout = 0.25
|
||||
timeout_ms = int(timeout * 1e3)
|
||||
|
||||
def alarm(self, t=None):
|
||||
"""start a timer to fire only once
|
||||
|
||||
like signal.alarm, but with better resolution than integer seconds.
|
||||
"""
|
||||
if not hasattr(signal, 'setitimer'):
|
||||
raise SkipTest('EINTR tests require setitimer')
|
||||
if t is None:
|
||||
t = self.signal_delay
|
||||
self.timer_fired = False
|
||||
self.orig_handler = signal.signal(signal.SIGALRM, self.stop_timer)
|
||||
# signal_period ignored, since only one timer event is allowed to fire
|
||||
signal.setitimer(signal.ITIMER_REAL, t, 1000)
|
||||
|
||||
def stop_timer(self, *args):
|
||||
self.timer_fired = True
|
||||
signal.setitimer(signal.ITIMER_REAL, 0, 0)
|
||||
signal.signal(signal.SIGALRM, self.orig_handler)
|
||||
|
||||
@mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
|
||||
def test_retry_recv(self):
|
||||
pull = self.socket(zmq.PULL)
|
||||
pull.rcvtimeo = self.timeout_ms
|
||||
self.alarm()
|
||||
self.assertRaises(zmq.Again, pull.recv)
|
||||
assert self.timer_fired
|
||||
|
||||
@mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
|
||||
def test_retry_send(self):
|
||||
push = self.socket(zmq.PUSH)
|
||||
push.sndtimeo = self.timeout_ms
|
||||
self.alarm()
|
||||
self.assertRaises(zmq.Again, push.send, b'buf')
|
||||
assert self.timer_fired
|
||||
|
||||
@mark.flaky(reruns=3)
|
||||
def test_retry_poll(self):
|
||||
x, y = self.create_bound_pair()
|
||||
poller = zmq.Poller()
|
||||
poller.register(x, zmq.POLLIN)
|
||||
self.alarm()
|
||||
|
||||
def send():
|
||||
time.sleep(2 * self.signal_delay)
|
||||
y.send(b'ping')
|
||||
|
||||
t = Thread(target=send)
|
||||
t.start()
|
||||
evts = dict(poller.poll(2 * self.timeout_ms))
|
||||
t.join()
|
||||
assert x in evts
|
||||
assert self.timer_fired
|
||||
x.recv()
|
||||
|
||||
def test_retry_term(self):
|
||||
push = self.socket(zmq.PUSH)
|
||||
push.linger = self.timeout_ms
|
||||
push.connect('tcp://127.0.0.1:5555')
|
||||
push.send(b'ping')
|
||||
time.sleep(0.1)
|
||||
self.alarm()
|
||||
self.context.destroy()
|
||||
assert self.timer_fired
|
||||
assert self.context.closed
|
||||
|
||||
def test_retry_getsockopt(self):
|
||||
raise SkipTest("TODO: find a way to interrupt getsockopt")
|
||||
|
||||
def test_retry_setsockopt(self):
|
||||
raise SkipTest("TODO: find a way to interrupt setsockopt")
|
||||
238
asm/venv/lib/python3.11/site-packages/zmq/tests/test_security.py
Normal file
238
asm/venv/lib/python3.11/site-packages/zmq/tests/test_security.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Test libzmq security (libzmq >= 3.3.0)"""
|
||||
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
import zmq
|
||||
from zmq.tests import PYPY, BaseZMQTestCase, SkipTest
|
||||
from zmq.utils import z85
|
||||
|
||||
USER = b"admin"
|
||||
PASS = b"password"
|
||||
|
||||
|
||||
class TestSecurity(BaseZMQTestCase):
|
||||
def setUp(self):
|
||||
if zmq.zmq_version_info() < (4, 0):
|
||||
raise SkipTest("security is new in libzmq 4.0")
|
||||
try:
|
||||
zmq.curve_keypair()
|
||||
except zmq.ZMQError:
|
||||
raise SkipTest("security requires libzmq to be built with CURVE support")
|
||||
super().setUp()
|
||||
|
||||
def zap_handler(self):
|
||||
socket = self.context.socket(zmq.REP)
|
||||
socket.bind("inproc://zeromq.zap.01")
|
||||
try:
|
||||
msg = self.recv_multipart(socket)
|
||||
|
||||
version, sequence, domain, address, identity, mechanism = msg[:6]
|
||||
if mechanism == b'PLAIN':
|
||||
username, password = msg[6:]
|
||||
elif mechanism == b'CURVE':
|
||||
msg[6]
|
||||
|
||||
assert version == b"1.0"
|
||||
assert identity == b"IDENT"
|
||||
reply = [version, sequence]
|
||||
if (
|
||||
mechanism == b'CURVE'
|
||||
or (mechanism == b'PLAIN' and username == USER and password == PASS)
|
||||
or (mechanism == b'NULL')
|
||||
):
|
||||
reply.extend(
|
||||
[
|
||||
b"200",
|
||||
b"OK",
|
||||
b"anonymous",
|
||||
b"\5Hello\0\0\0\5World",
|
||||
]
|
||||
)
|
||||
else:
|
||||
reply.extend(
|
||||
[
|
||||
b"400",
|
||||
b"Invalid username or password",
|
||||
b"",
|
||||
b"",
|
||||
]
|
||||
)
|
||||
socket.send_multipart(reply)
|
||||
finally:
|
||||
socket.close()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def zap(self):
|
||||
self.start_zap()
|
||||
time.sleep(0.5) # allow time for the Thread to start
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.stop_zap()
|
||||
|
||||
def start_zap(self):
|
||||
self.zap_thread = Thread(target=self.zap_handler)
|
||||
self.zap_thread.start()
|
||||
|
||||
def stop_zap(self):
|
||||
self.zap_thread.join()
|
||||
|
||||
def bounce(self, server, client, test_metadata=True):
|
||||
msg = [os.urandom(64), os.urandom(64)]
|
||||
client.send_multipart(msg)
|
||||
frames = self.recv_multipart(server, copy=False)
|
||||
recvd = list(map(lambda x: x.bytes, frames))
|
||||
|
||||
try:
|
||||
if test_metadata and not PYPY:
|
||||
for frame in frames:
|
||||
assert frame.get('User-Id') == 'anonymous'
|
||||
assert frame.get('Hello') == 'World'
|
||||
assert frame['Socket-Type'] == 'DEALER'
|
||||
except zmq.ZMQVersionError:
|
||||
pass
|
||||
|
||||
assert recvd == msg
|
||||
server.send_multipart(recvd)
|
||||
msg2 = self.recv_multipart(client)
|
||||
assert msg2 == msg
|
||||
|
||||
def test_null(self):
|
||||
"""test NULL (default) security"""
|
||||
server = self.socket(zmq.DEALER)
|
||||
client = self.socket(zmq.DEALER)
|
||||
assert client.MECHANISM == zmq.NULL
|
||||
assert server.mechanism == zmq.NULL
|
||||
assert client.plain_server == 0
|
||||
assert server.plain_server == 0
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
self.bounce(server, client, False)
|
||||
|
||||
def test_plain(self):
|
||||
"""test PLAIN authentication"""
|
||||
server = self.socket(zmq.DEALER)
|
||||
server.identity = b'IDENT'
|
||||
client = self.socket(zmq.DEALER)
|
||||
assert client.plain_username == b''
|
||||
assert client.plain_password == b''
|
||||
client.plain_username = USER
|
||||
client.plain_password = PASS
|
||||
assert client.getsockopt(zmq.PLAIN_USERNAME) == USER
|
||||
assert client.getsockopt(zmq.PLAIN_PASSWORD) == PASS
|
||||
assert client.plain_server == 0
|
||||
assert server.plain_server == 0
|
||||
server.plain_server = True
|
||||
assert server.mechanism == zmq.PLAIN
|
||||
assert client.mechanism == zmq.PLAIN
|
||||
|
||||
assert not client.plain_server
|
||||
assert server.plain_server
|
||||
|
||||
with self.zap():
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
self.bounce(server, client)
|
||||
|
||||
def skip_plain_inauth(self):
|
||||
"""test PLAIN failed authentication"""
|
||||
server = self.socket(zmq.DEALER)
|
||||
server.identity = b'IDENT'
|
||||
client = self.socket(zmq.DEALER)
|
||||
self.sockets.extend([server, client])
|
||||
client.plain_username = USER
|
||||
client.plain_password = b'incorrect'
|
||||
server.plain_server = True
|
||||
assert server.mechanism == zmq.PLAIN
|
||||
assert client.mechanism == zmq.PLAIN
|
||||
|
||||
with self.zap():
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
client.send(b'ping')
|
||||
server.rcvtimeo = 250
|
||||
self.assertRaisesErrno(zmq.EAGAIN, server.recv)
|
||||
|
||||
def test_keypair(self):
|
||||
"""test curve_keypair"""
|
||||
try:
|
||||
public, secret = zmq.curve_keypair()
|
||||
except zmq.ZMQError:
|
||||
raise SkipTest("CURVE unsupported")
|
||||
|
||||
assert type(secret) == bytes
|
||||
assert type(public) == bytes
|
||||
assert len(secret) == 40
|
||||
assert len(public) == 40
|
||||
|
||||
# verify that it is indeed Z85
|
||||
bsecret, bpublic = (z85.decode(key) for key in (public, secret))
|
||||
assert type(bsecret) == bytes
|
||||
assert type(bpublic) == bytes
|
||||
assert len(bsecret) == 32
|
||||
assert len(bpublic) == 32
|
||||
|
||||
def test_curve_public(self):
|
||||
"""test curve_public"""
|
||||
try:
|
||||
public, secret = zmq.curve_keypair()
|
||||
except zmq.ZMQError:
|
||||
raise SkipTest("CURVE unsupported")
|
||||
if zmq.zmq_version_info() < (4, 2):
|
||||
raise SkipTest("curve_public is new in libzmq 4.2")
|
||||
|
||||
derived_public = zmq.curve_public(secret)
|
||||
|
||||
assert type(derived_public) == bytes
|
||||
assert len(derived_public) == 40
|
||||
|
||||
# verify that it is indeed Z85
|
||||
bpublic = z85.decode(derived_public)
|
||||
assert type(bpublic) == bytes
|
||||
assert len(bpublic) == 32
|
||||
|
||||
# verify that it is equal to the known public key
|
||||
assert derived_public == public
|
||||
|
||||
def test_curve(self):
|
||||
"""test CURVE encryption"""
|
||||
server = self.socket(zmq.DEALER)
|
||||
server.identity = b'IDENT'
|
||||
client = self.socket(zmq.DEALER)
|
||||
self.sockets.extend([server, client])
|
||||
try:
|
||||
server.curve_server = True
|
||||
except zmq.ZMQError as e:
|
||||
# will raise EINVAL if no CURVE support
|
||||
if e.errno == zmq.EINVAL:
|
||||
raise SkipTest("CURVE unsupported")
|
||||
|
||||
server_public, server_secret = zmq.curve_keypair()
|
||||
client_public, client_secret = zmq.curve_keypair()
|
||||
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_publickey = server_public
|
||||
client.curve_serverkey = server_public
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
|
||||
assert server.mechanism == zmq.CURVE
|
||||
assert client.mechanism == zmq.CURVE
|
||||
|
||||
assert server.get(zmq.CURVE_SERVER) == True
|
||||
assert client.get(zmq.CURVE_SERVER) == False
|
||||
|
||||
with self.zap():
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
self.bounce(server, client)
|
||||
690
asm/venv/lib/python3.11/site-packages/zmq/tests/test_socket.py
Normal file
690
asm/venv/lib/python3.11/site-packages/zmq/tests/test_socket.py
Normal file
@@ -0,0 +1,690 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import copy
|
||||
import errno
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, GreenTest, SkipTest, have_gevent, skip_pypy
|
||||
|
||||
pypy = platform.python_implementation().lower() == 'pypy'
|
||||
windows = platform.platform().lower().startswith('windows')
|
||||
on_ci = bool(os.environ.get('CI'))
|
||||
|
||||
# polling on windows is slow
|
||||
POLL_TIMEOUT = 1000 if windows else 100
|
||||
|
||||
|
||||
class TestSocket(BaseZMQTestCase):
|
||||
def test_create(self):
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.PUB)
|
||||
# Superluminal protocol not yet implemented
|
||||
self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.bind, 'ftl://a')
|
||||
self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.connect, 'ftl://a')
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.bind, 'tcp://')
|
||||
s.close()
|
||||
ctx.term()
|
||||
|
||||
def test_context_manager(self):
|
||||
url = 'inproc://a'
|
||||
with self.Context() as ctx:
|
||||
with ctx.socket(zmq.PUSH) as a:
|
||||
a.bind(url)
|
||||
with ctx.socket(zmq.PULL) as b:
|
||||
b.connect(url)
|
||||
msg = b'hi'
|
||||
a.send(msg)
|
||||
rcvd = self.recv(b)
|
||||
assert rcvd == msg
|
||||
assert b.closed == True
|
||||
assert a.closed == True
|
||||
assert ctx.closed == True
|
||||
|
||||
def test_connectbind_context_managers(self):
|
||||
url = 'inproc://a'
|
||||
msg = b'hi'
|
||||
with self.Context() as ctx:
|
||||
# Test connect() context manager
|
||||
with ctx.socket(zmq.PUSH) as a, ctx.socket(zmq.PULL) as b:
|
||||
a.bind(url)
|
||||
connect_context = b.connect(url)
|
||||
assert f'connect={url!r}' in repr(connect_context)
|
||||
with connect_context:
|
||||
a.send(msg)
|
||||
rcvd = self.recv(b)
|
||||
assert rcvd == msg
|
||||
# b should now be disconnected, so sending and receiving don't work
|
||||
with pytest.raises(zmq.Again):
|
||||
a.send(msg, flags=zmq.DONTWAIT)
|
||||
with pytest.raises(zmq.Again):
|
||||
b.recv(flags=zmq.DONTWAIT)
|
||||
a.unbind(url)
|
||||
# Test bind() context manager
|
||||
with ctx.socket(zmq.PUSH) as a, ctx.socket(zmq.PULL) as b:
|
||||
# unbind() just stops accepting of new connections, so we have to disconnect to test that
|
||||
# unbind happened.
|
||||
bind_context = a.bind(url)
|
||||
assert f'bind={url!r}' in repr(bind_context)
|
||||
with bind_context:
|
||||
b.connect(url)
|
||||
a.send(msg)
|
||||
rcvd = self.recv(b)
|
||||
assert rcvd == msg
|
||||
b.disconnect(url)
|
||||
b.connect(url)
|
||||
# Since a is unbound from url, b is not connected to anything
|
||||
with pytest.raises(zmq.Again):
|
||||
a.send(msg, flags=zmq.DONTWAIT)
|
||||
with pytest.raises(zmq.Again):
|
||||
b.recv(flags=zmq.DONTWAIT)
|
||||
|
||||
_repr_cls = "zmq.Socket"
|
||||
|
||||
def test_repr(self):
|
||||
with self.context.socket(zmq.PUSH) as s:
|
||||
assert f'{self._repr_cls}(zmq.PUSH)' in repr(s)
|
||||
assert 'closed' not in repr(s)
|
||||
assert f'{self._repr_cls}(zmq.PUSH)' in repr(s)
|
||||
assert 'closed' in repr(s)
|
||||
|
||||
def test_dir(self):
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.PUB)
|
||||
assert 'send' in dir(s)
|
||||
assert 'IDENTITY' in dir(s)
|
||||
assert 'AFFINITY' in dir(s)
|
||||
assert 'FD' in dir(s)
|
||||
s.close()
|
||||
ctx.term()
|
||||
|
||||
@mark.skipif(mock is None, reason="requires unittest.mock")
|
||||
def test_mockable(self):
|
||||
s = self.socket(zmq.SUB)
|
||||
m = mock.Mock(spec=s)
|
||||
s.close()
|
||||
|
||||
def test_bind_unicode(self):
|
||||
s = self.socket(zmq.PUB)
|
||||
p = s.bind_to_random_port("tcp://*")
|
||||
|
||||
def test_connect_unicode(self):
|
||||
s = self.socket(zmq.PUB)
|
||||
s.connect("tcp://127.0.0.1:5555")
|
||||
|
||||
def test_bind_to_random_port(self):
|
||||
# Check that bind_to_random_port do not hide useful exception
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.PUB)
|
||||
# Invalid format
|
||||
try:
|
||||
s.bind_to_random_port('tcp:*')
|
||||
except zmq.ZMQError as e:
|
||||
assert e.errno == zmq.EINVAL
|
||||
# Invalid protocol
|
||||
try:
|
||||
s.bind_to_random_port('rand://*')
|
||||
except zmq.ZMQError as e:
|
||||
assert e.errno == zmq.EPROTONOSUPPORT
|
||||
|
||||
s.close()
|
||||
ctx.term()
|
||||
|
||||
def test_bind_connect_addr_error(self):
|
||||
with self.socket(zmq.PUSH) as s:
|
||||
url = "tcp://1.2.3.4.5:1234567"
|
||||
with pytest.raises(zmq.ZMQError) as exc:
|
||||
s.bind(url)
|
||||
assert url in str(exc.value)
|
||||
|
||||
url = "noproc://no/such/file"
|
||||
with pytest.raises(zmq.ZMQError) as exc:
|
||||
s.connect(url)
|
||||
assert url in str(exc.value)
|
||||
|
||||
def test_identity(self):
|
||||
s = self.context.socket(zmq.PULL)
|
||||
self.sockets.append(s)
|
||||
ident = b'identity\0\0'
|
||||
s.identity = ident
|
||||
assert s.get(zmq.IDENTITY) == ident
|
||||
|
||||
def test_unicode_sockopts(self):
|
||||
"""test setting/getting sockopts with unicode strings"""
|
||||
topic = "tést"
|
||||
p, s = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
assert s.send_unicode == s.send_unicode
|
||||
assert p.recv_unicode == p.recv_unicode
|
||||
self.assertRaises(TypeError, s.setsockopt, zmq.SUBSCRIBE, topic)
|
||||
self.assertRaises(TypeError, s.setsockopt, zmq.IDENTITY, topic)
|
||||
s.setsockopt_unicode(zmq.IDENTITY, topic, 'utf16')
|
||||
self.assertRaises(TypeError, s.setsockopt, zmq.AFFINITY, topic)
|
||||
s.setsockopt_unicode(zmq.SUBSCRIBE, topic)
|
||||
self.assertRaises(TypeError, s.getsockopt_unicode, zmq.AFFINITY)
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.getsockopt_unicode, zmq.SUBSCRIBE)
|
||||
|
||||
identb = s.getsockopt(zmq.IDENTITY)
|
||||
identu = identb.decode('utf16')
|
||||
identu2 = s.getsockopt_unicode(zmq.IDENTITY, 'utf16')
|
||||
assert identu == identu2
|
||||
time.sleep(0.1) # wait for connection/subscription
|
||||
p.send_unicode(topic, zmq.SNDMORE)
|
||||
p.send_unicode(topic * 2, encoding='latin-1')
|
||||
assert topic == s.recv_unicode()
|
||||
assert topic * 2 == s.recv_unicode(encoding='latin-1')
|
||||
|
||||
def test_int_sockopts(self):
|
||||
"test integer sockopts"
|
||||
v = zmq.zmq_version_info()
|
||||
if v < (3, 0):
|
||||
default_hwm = 0
|
||||
else:
|
||||
default_hwm = 1000
|
||||
p, s = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
p.setsockopt(zmq.LINGER, 0)
|
||||
assert p.getsockopt(zmq.LINGER) == 0
|
||||
p.setsockopt(zmq.LINGER, -1)
|
||||
assert p.getsockopt(zmq.LINGER) == -1
|
||||
assert p.hwm == default_hwm
|
||||
p.hwm = 11
|
||||
assert p.hwm == 11
|
||||
# p.setsockopt(zmq.EVENTS, zmq.POLLIN)
|
||||
assert p.getsockopt(zmq.EVENTS) == zmq.POLLOUT
|
||||
self.assertRaisesErrno(zmq.EINVAL, p.setsockopt, zmq.EVENTS, 2**7 - 1)
|
||||
assert p.getsockopt(zmq.TYPE) == p.socket_type
|
||||
assert p.getsockopt(zmq.TYPE) == zmq.PUB
|
||||
assert s.getsockopt(zmq.TYPE) == s.socket_type
|
||||
assert s.getsockopt(zmq.TYPE) == zmq.SUB
|
||||
|
||||
# check for overflow / wrong type:
|
||||
errors = []
|
||||
backref = {}
|
||||
constants = zmq.constants
|
||||
for name in constants.__all__:
|
||||
value = getattr(constants, name)
|
||||
if isinstance(value, int):
|
||||
backref[value] = name
|
||||
for opt in zmq.constants.SocketOption:
|
||||
if opt._opt_type not in {
|
||||
zmq.constants._OptType.int,
|
||||
zmq.constants._OptType.int64,
|
||||
}:
|
||||
continue
|
||||
if opt.name.startswith(
|
||||
(
|
||||
'HWM',
|
||||
'ROUTER',
|
||||
'XPUB',
|
||||
'TCP',
|
||||
'FAIL',
|
||||
'REQ_',
|
||||
'CURVE_',
|
||||
'PROBE_ROUTER',
|
||||
'IPC_FILTER',
|
||||
'GSSAPI',
|
||||
'STREAM_',
|
||||
'VMCI_BUFFER_SIZE',
|
||||
'VMCI_BUFFER_MIN_SIZE',
|
||||
'VMCI_BUFFER_MAX_SIZE',
|
||||
'VMCI_CONNECT_TIMEOUT',
|
||||
'BLOCKY',
|
||||
'IN_BATCH_SIZE',
|
||||
'OUT_BATCH_SIZE',
|
||||
'WSS_TRUST_SYSTEM',
|
||||
'ONLY_FIRST_SUBSCRIBE',
|
||||
'PRIORITY',
|
||||
'RECONNECT_STOP',
|
||||
)
|
||||
):
|
||||
# some sockopts are write-only
|
||||
continue
|
||||
try:
|
||||
n = p.getsockopt(opt)
|
||||
except zmq.ZMQError as e:
|
||||
errors.append(f"getsockopt({opt!r}) raised {e}.")
|
||||
else:
|
||||
if n > 2**31:
|
||||
errors.append(
|
||||
f"getsockopt({opt!r}) returned a ridiculous value."
|
||||
" It is probably the wrong type."
|
||||
)
|
||||
if errors:
|
||||
self.fail('\n'.join([''] + errors))
|
||||
|
||||
def test_bad_sockopts(self):
|
||||
"""Test that appropriate errors are raised on bad socket options"""
|
||||
s = self.context.socket(zmq.PUB)
|
||||
self.sockets.append(s)
|
||||
s.setsockopt(zmq.LINGER, 0)
|
||||
# unrecognized int sockopts pass through to libzmq, and should raise EINVAL
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, 9999, 5)
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.getsockopt, 9999)
|
||||
# but only int sockopts are allowed through this way, otherwise raise a TypeError
|
||||
self.assertRaises(TypeError, s.setsockopt, 9999, b"5")
|
||||
# some sockopts are valid in general, but not on every socket:
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, zmq.SUBSCRIBE, b'hi')
|
||||
|
||||
def test_sockopt_roundtrip(self):
|
||||
"test set/getsockopt roundtrip."
|
||||
p = self.context.socket(zmq.PUB)
|
||||
self.sockets.append(p)
|
||||
p.setsockopt(zmq.LINGER, 11)
|
||||
assert p.getsockopt(zmq.LINGER) == 11
|
||||
|
||||
def test_send_unicode(self):
|
||||
"test sending unicode objects"
|
||||
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
self.sockets.extend([a, b])
|
||||
u = "çπ§"
|
||||
self.assertRaises(TypeError, a.send, u, copy=False)
|
||||
self.assertRaises(TypeError, a.send, u, copy=True)
|
||||
a.send_unicode(u)
|
||||
s = b.recv()
|
||||
assert s == u.encode('utf8')
|
||||
assert s.decode('utf8') == u
|
||||
a.send_unicode(u, encoding='utf16')
|
||||
s = b.recv_unicode(encoding='utf16')
|
||||
assert s == u
|
||||
|
||||
def test_send_multipart_check_type(self):
|
||||
"check type on all frames in send_multipart"
|
||||
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
self.sockets.extend([a, b])
|
||||
self.assertRaises(TypeError, a.send_multipart, [b'a', 5])
|
||||
a.send_multipart([b'b'])
|
||||
rcvd = self.recv_multipart(b)
|
||||
assert rcvd == [b'b']
|
||||
|
||||
@skip_pypy
|
||||
def test_tracker(self):
|
||||
"test the MessageTracker object for tracking when zmq is done with a buffer"
|
||||
addr = 'tcp://127.0.0.1'
|
||||
# get a port:
|
||||
sock = socket.socket()
|
||||
sock.bind(('127.0.0.1', 0))
|
||||
port = sock.getsockname()[1]
|
||||
iface = "%s:%i" % (addr, port)
|
||||
sock.close()
|
||||
time.sleep(0.1)
|
||||
|
||||
a = self.context.socket(zmq.PUSH)
|
||||
b = self.context.socket(zmq.PULL)
|
||||
self.sockets.extend([a, b])
|
||||
a.connect(iface)
|
||||
time.sleep(0.1)
|
||||
p1 = a.send(b'something', copy=False, track=True)
|
||||
assert isinstance(p1, zmq.MessageTracker)
|
||||
assert p1 is zmq._FINISHED_TRACKER
|
||||
# small message, should start done
|
||||
assert p1.done
|
||||
|
||||
# disable zero-copy threshold
|
||||
a.copy_threshold = 0
|
||||
|
||||
p2 = a.send_multipart([b'something', b'else'], copy=False, track=True)
|
||||
assert isinstance(p2, zmq.MessageTracker)
|
||||
assert not p2.done
|
||||
|
||||
b.bind(iface)
|
||||
msg = self.recv_multipart(b)
|
||||
for i in range(10):
|
||||
if p1.done:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
assert p1.done == True
|
||||
assert msg == [b'something']
|
||||
msg = self.recv_multipart(b)
|
||||
for i in range(10):
|
||||
if p2.done:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
assert p2.done == True
|
||||
assert msg == [b'something', b'else']
|
||||
m = zmq.Frame(b"again", copy=False, track=True)
|
||||
assert m.tracker.done == False
|
||||
p1 = a.send(m, copy=False)
|
||||
p2 = a.send(m, copy=False)
|
||||
assert m.tracker.done == False
|
||||
assert p1.done == False
|
||||
assert p2.done == False
|
||||
msg = self.recv_multipart(b)
|
||||
assert m.tracker.done == False
|
||||
assert msg == [b'again']
|
||||
msg = self.recv_multipart(b)
|
||||
assert m.tracker.done == False
|
||||
assert msg == [b'again']
|
||||
assert p1.done == False
|
||||
assert p2.done == False
|
||||
m.tracker
|
||||
del m
|
||||
for i in range(10):
|
||||
if p1.done:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
assert p1.done == True
|
||||
assert p2.done == True
|
||||
m = zmq.Frame(b'something', track=False)
|
||||
self.assertRaises(ValueError, a.send, m, copy=False, track=True)
|
||||
|
||||
def test_close(self):
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.PUB)
|
||||
s.close()
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.bind, b'')
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.connect, b'')
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.setsockopt, zmq.SUBSCRIBE, b'')
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.send, b'asdf')
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.recv)
|
||||
ctx.term()
|
||||
|
||||
def test_attr(self):
|
||||
"""set setting/getting sockopts as attributes"""
|
||||
s = self.context.socket(zmq.DEALER)
|
||||
self.sockets.append(s)
|
||||
linger = 10
|
||||
s.linger = linger
|
||||
assert linger == s.linger
|
||||
assert linger == s.getsockopt(zmq.LINGER)
|
||||
assert s.fd == s.getsockopt(zmq.FD)
|
||||
|
||||
def test_bad_attr(self):
|
||||
s = self.context.socket(zmq.DEALER)
|
||||
self.sockets.append(s)
|
||||
try:
|
||||
s.apple = 'foo'
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self.fail("bad setattr should have raised AttributeError")
|
||||
try:
|
||||
s.apple
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self.fail("bad getattr should have raised AttributeError")
|
||||
|
||||
def test_subclass(self):
|
||||
"""subclasses can assign attributes"""
|
||||
|
||||
class S(zmq.Socket):
|
||||
a = None
|
||||
|
||||
def __init__(self, *a, **kw):
|
||||
self.a = -1
|
||||
super().__init__(*a, **kw)
|
||||
|
||||
s = S(self.context, zmq.REP)
|
||||
self.sockets.append(s)
|
||||
assert s.a == -1
|
||||
s.a = 1
|
||||
assert s.a == 1
|
||||
a = s.a
|
||||
assert a == 1
|
||||
|
||||
def test_recv_multipart(self):
|
||||
a, b = self.create_bound_pair()
|
||||
msg = b'hi'
|
||||
for i in range(3):
|
||||
a.send(msg)
|
||||
time.sleep(0.1)
|
||||
for i in range(3):
|
||||
assert self.recv_multipart(b) == [msg]
|
||||
|
||||
def test_close_after_destroy(self):
|
||||
"""s.close() after ctx.destroy() should be fine"""
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.REP)
|
||||
ctx.destroy()
|
||||
# reaper is not instantaneous
|
||||
time.sleep(1e-2)
|
||||
s.close()
|
||||
assert s.closed
|
||||
|
||||
def test_poll(self):
|
||||
a, b = self.create_bound_pair()
|
||||
time.time()
|
||||
evt = a.poll(POLL_TIMEOUT)
|
||||
assert evt == 0
|
||||
evt = a.poll(POLL_TIMEOUT, zmq.POLLOUT)
|
||||
assert evt == zmq.POLLOUT
|
||||
msg = b'hi'
|
||||
a.send(msg)
|
||||
evt = b.poll(POLL_TIMEOUT)
|
||||
assert evt == zmq.POLLIN
|
||||
msg2 = self.recv(b)
|
||||
evt = b.poll(POLL_TIMEOUT)
|
||||
assert evt == 0
|
||||
assert msg2 == msg
|
||||
|
||||
def test_ipc_path_max_length(self):
|
||||
"""IPC_PATH_MAX_LEN is a sensible value"""
|
||||
if zmq.IPC_PATH_MAX_LEN == 0:
|
||||
raise SkipTest("IPC_PATH_MAX_LEN undefined")
|
||||
|
||||
msg = "Surprising value for IPC_PATH_MAX_LEN: %s" % zmq.IPC_PATH_MAX_LEN
|
||||
assert zmq.IPC_PATH_MAX_LEN > 30, msg
|
||||
assert zmq.IPC_PATH_MAX_LEN < 1025, msg
|
||||
|
||||
def test_ipc_path_max_length_msg(self):
|
||||
if zmq.IPC_PATH_MAX_LEN == 0:
|
||||
raise SkipTest("IPC_PATH_MAX_LEN undefined")
|
||||
|
||||
s = self.context.socket(zmq.PUB)
|
||||
self.sockets.append(s)
|
||||
try:
|
||||
s.bind('ipc://{}'.format('a' * (zmq.IPC_PATH_MAX_LEN + 1)))
|
||||
except zmq.ZMQError as e:
|
||||
assert str(zmq.IPC_PATH_MAX_LEN) in e.strerror
|
||||
|
||||
@mark.skipif(windows, reason="ipc not supported on Windows.")
|
||||
def test_ipc_path_no_such_file_or_directory_message(self):
|
||||
"""Display the ipc path in case of an ENOENT exception"""
|
||||
s = self.context.socket(zmq.PUB)
|
||||
self.sockets.append(s)
|
||||
invalid_path = '/foo/bar'
|
||||
with pytest.raises(zmq.ZMQError) as error:
|
||||
s.bind(f'ipc://{invalid_path}')
|
||||
assert error.value.errno == errno.ENOENT
|
||||
error_message = str(error.value)
|
||||
assert invalid_path in error_message
|
||||
assert "no such file or directory" in error_message.lower()
|
||||
|
||||
def test_hwm(self):
|
||||
zmq3 = zmq.zmq_version_info()[0] >= 3
|
||||
for stype in (zmq.PUB, zmq.ROUTER, zmq.SUB, zmq.REQ, zmq.DEALER):
|
||||
s = self.context.socket(stype)
|
||||
s.hwm = 100
|
||||
assert s.hwm == 100
|
||||
if zmq3:
|
||||
try:
|
||||
assert s.sndhwm == 100
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
assert s.rcvhwm == 100
|
||||
except AttributeError:
|
||||
pass
|
||||
s.close()
|
||||
|
||||
def test_copy(self):
|
||||
s = self.socket(zmq.PUB)
|
||||
scopy = copy.copy(s)
|
||||
sdcopy = copy.deepcopy(s)
|
||||
assert scopy._shadow
|
||||
assert sdcopy._shadow
|
||||
assert s.underlying == scopy.underlying
|
||||
assert s.underlying == sdcopy.underlying
|
||||
s.close()
|
||||
|
||||
def test_send_buffer(self):
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
for buffer_type in (memoryview, bytearray):
|
||||
rawbytes = str(buffer_type).encode('ascii')
|
||||
msg = buffer_type(rawbytes)
|
||||
a.send(msg)
|
||||
recvd = b.recv()
|
||||
assert recvd == rawbytes
|
||||
|
||||
def test_shadow(self):
|
||||
p = self.socket(zmq.PUSH)
|
||||
p.bind("tcp://127.0.0.1:5555")
|
||||
p2 = zmq.Socket.shadow(p.underlying)
|
||||
assert p.underlying == p2.underlying
|
||||
s = self.socket(zmq.PULL)
|
||||
s2 = zmq.Socket.shadow(s)
|
||||
assert s2._shadow_obj is s
|
||||
assert s.underlying != p.underlying
|
||||
assert s2.underlying == s.underlying
|
||||
s3 = zmq.Socket(s)
|
||||
assert s3._shadow_obj is s
|
||||
assert s3.underlying == s.underlying
|
||||
s2.connect("tcp://127.0.0.1:5555")
|
||||
sent = b'hi'
|
||||
p2.send(sent)
|
||||
rcvd = self.recv(s2)
|
||||
assert rcvd == sent
|
||||
|
||||
def test_shadow_pyczmq(self):
|
||||
try:
|
||||
from pyczmq import zctx, zsocket
|
||||
except Exception:
|
||||
raise SkipTest("Requires pyczmq")
|
||||
|
||||
ctx = zctx.new()
|
||||
ca = zsocket.new(ctx, zmq.PUSH)
|
||||
cb = zsocket.new(ctx, zmq.PULL)
|
||||
a = zmq.Socket.shadow(ca)
|
||||
b = zmq.Socket.shadow(cb)
|
||||
a.bind("inproc://a")
|
||||
b.connect("inproc://a")
|
||||
a.send(b'hi')
|
||||
rcvd = self.recv(b)
|
||||
assert rcvd == b'hi'
|
||||
|
||||
def test_subscribe_method(self):
|
||||
pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
sub.subscribe('prefix')
|
||||
sub.subscribe = 'c'
|
||||
p = zmq.Poller()
|
||||
p.register(sub, zmq.POLLIN)
|
||||
# wait for subscription handshake
|
||||
for i in range(100):
|
||||
pub.send(b'canary')
|
||||
events = p.poll(250)
|
||||
if events:
|
||||
break
|
||||
self.recv(sub)
|
||||
pub.send(b'prefixmessage')
|
||||
msg = self.recv(sub)
|
||||
assert msg == b'prefixmessage'
|
||||
sub.unsubscribe('prefix')
|
||||
pub.send(b'prefixmessage')
|
||||
events = p.poll(1000)
|
||||
assert events == []
|
||||
|
||||
# CI often can't handle how much memory PyPy uses on this test
|
||||
@mark.skipif(
|
||||
(pypy and on_ci) or (sys.maxsize < 2**32) or (windows),
|
||||
reason="only run on 64b and not on CI.",
|
||||
)
|
||||
@mark.large
|
||||
def test_large_send(self):
|
||||
c = os.urandom(1)
|
||||
N = 2**31 + 1
|
||||
try:
|
||||
buf = c * N
|
||||
except MemoryError as e:
|
||||
raise SkipTest("Not enough memory: %s" % e)
|
||||
a, b = self.create_bound_pair()
|
||||
try:
|
||||
a.send(buf, copy=False)
|
||||
rcvd = b.recv(copy=False)
|
||||
except MemoryError as e:
|
||||
raise SkipTest("Not enough memory: %s" % e)
|
||||
# sample the front and back of the received message
|
||||
# without checking the whole content
|
||||
byte = ord(c)
|
||||
view = memoryview(rcvd)
|
||||
assert len(view) == N
|
||||
assert view[0] == byte
|
||||
assert view[-1] == byte
|
||||
|
||||
def test_custom_serialize(self):
|
||||
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
|
||||
def serialize(msg):
|
||||
frames = []
|
||||
frames.extend(msg.get('identities', []))
|
||||
content = json.dumps(msg['content']).encode('utf8')
|
||||
frames.append(content)
|
||||
return frames
|
||||
|
||||
def deserialize(frames):
|
||||
identities = frames[:-1]
|
||||
content = json.loads(frames[-1].decode('utf8'))
|
||||
return {
|
||||
'identities': identities,
|
||||
'content': content,
|
||||
}
|
||||
|
||||
msg = {
|
||||
'content': {
|
||||
'a': 5,
|
||||
'b': 'bee',
|
||||
}
|
||||
}
|
||||
a.send_serialized(msg, serialize)
|
||||
recvd = b.recv_serialized(deserialize)
|
||||
assert recvd['content'] == msg['content']
|
||||
assert recvd['identities']
|
||||
# bounce back, tests identities
|
||||
b.send_serialized(recvd, serialize)
|
||||
r2 = a.recv_serialized(deserialize)
|
||||
assert r2['content'] == msg['content']
|
||||
assert not r2['identities']
|
||||
|
||||
|
||||
if have_gevent and not windows:
|
||||
import gevent
|
||||
|
||||
class TestSocketGreen(GreenTest, TestSocket):
|
||||
test_bad_attr = GreenTest.skip_green
|
||||
test_close_after_destroy = GreenTest.skip_green
|
||||
_repr_cls = "zmq.green.Socket"
|
||||
|
||||
def test_timeout(self):
|
||||
a, b = self.create_bound_pair()
|
||||
g = gevent.spawn_later(0.5, lambda: a.send(b'hi'))
|
||||
timeout = gevent.Timeout(0.1)
|
||||
timeout.start()
|
||||
self.assertRaises(gevent.Timeout, b.recv)
|
||||
g.kill()
|
||||
|
||||
@mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
|
||||
def test_warn_set_timeo(self):
|
||||
s = self.context.socket(zmq.REQ)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
s.rcvtimeo = 5
|
||||
s.close()
|
||||
assert len(w) == 1
|
||||
assert w[0].category == UserWarning
|
||||
|
||||
@mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
|
||||
def test_warn_get_timeo(self):
|
||||
s = self.context.socket(zmq.REQ)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
s.sndtimeo
|
||||
s.close()
|
||||
assert len(w) == 1
|
||||
assert w[0].category == UserWarning
|
||||
@@ -0,0 +1,9 @@
|
||||
from zmq.ssh.tunnel import select_random_ports
|
||||
|
||||
|
||||
def test_random_ports():
|
||||
for i in range(4096):
|
||||
ports = select_random_ports(10)
|
||||
assert len(ports) == 10
|
||||
for p in ports:
|
||||
assert ports.count(p) == 1
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
import zmq
|
||||
from zmq.sugar import version
|
||||
|
||||
|
||||
class TestVersion(TestCase):
|
||||
def test_pyzmq_version(self):
|
||||
vs = zmq.pyzmq_version()
|
||||
vs2 = zmq.__version__
|
||||
assert isinstance(vs, str)
|
||||
if zmq.__revision__:
|
||||
assert vs == '@'.join(vs2, zmq.__revision__)
|
||||
else:
|
||||
assert vs == vs2
|
||||
if version.VERSION_EXTRA:
|
||||
assert version.VERSION_EXTRA in vs
|
||||
assert version.VERSION_EXTRA in vs2
|
||||
|
||||
def test_pyzmq_version_info(self):
|
||||
info = zmq.pyzmq_version_info()
|
||||
assert isinstance(info, tuple)
|
||||
for n in info[:3]:
|
||||
assert isinstance(n, int)
|
||||
if version.VERSION_EXTRA:
|
||||
assert len(info) == 4
|
||||
assert info[-1] == float('inf')
|
||||
else:
|
||||
assert len(info) == 3
|
||||
|
||||
def test_zmq_version_info(self):
|
||||
info = zmq.zmq_version_info()
|
||||
assert isinstance(info, tuple)
|
||||
for n in info[:3]:
|
||||
assert isinstance(n, int)
|
||||
|
||||
def test_zmq_version(self):
|
||||
v = zmq.zmq_version()
|
||||
assert isinstance(v, str)
|
||||
@@ -0,0 +1,58 @@
|
||||
import sys
|
||||
import time
|
||||
from functools import wraps
|
||||
|
||||
from pytest import mark
|
||||
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
from zmq.utils.win32 import allow_interrupt
|
||||
|
||||
|
||||
def count_calls(f):
|
||||
@wraps(f)
|
||||
def _(*args, **kwds):
|
||||
try:
|
||||
return f(*args, **kwds)
|
||||
finally:
|
||||
_.__calls__ += 1
|
||||
|
||||
_.__calls__ = 0
|
||||
return _
|
||||
|
||||
|
||||
@mark.new_console
|
||||
class TestWindowsConsoleControlHandler(BaseZMQTestCase):
|
||||
@mark.new_console
|
||||
@mark.skipif(not sys.platform.startswith('win'), reason='Windows only test')
|
||||
def test_handler(self):
|
||||
@count_calls
|
||||
def interrupt_polling():
|
||||
print('Caught CTRL-C!')
|
||||
|
||||
from ctypes import windll
|
||||
from ctypes.wintypes import BOOL, DWORD
|
||||
|
||||
kernel32 = windll.LoadLibrary('kernel32')
|
||||
|
||||
# <http://msdn.microsoft.com/en-us/library/ms683155.aspx>
|
||||
GenerateConsoleCtrlEvent = kernel32.GenerateConsoleCtrlEvent
|
||||
GenerateConsoleCtrlEvent.argtypes = (DWORD, DWORD)
|
||||
GenerateConsoleCtrlEvent.restype = BOOL
|
||||
|
||||
# Simulate CTRL-C event while handler is active.
|
||||
try:
|
||||
with allow_interrupt(interrupt_polling) as context:
|
||||
result = GenerateConsoleCtrlEvent(0, 0)
|
||||
# Sleep so that we give time to the handler to
|
||||
# capture the Ctrl-C event.
|
||||
time.sleep(0.5)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
else:
|
||||
if result == 0:
|
||||
raise OSError()
|
||||
else:
|
||||
self.fail('Expecting `KeyboardInterrupt` exception!')
|
||||
|
||||
# Make sure our handler was called.
|
||||
assert interrupt_polling.__calls__ == 1
|
||||
65
asm/venv/lib/python3.11/site-packages/zmq/tests/test_z85.py
Normal file
65
asm/venv/lib/python3.11/site-packages/zmq/tests/test_z85.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Test Z85 encoding
|
||||
|
||||
confirm values and roundtrip with test values from the reference implementation.
|
||||
"""
|
||||
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
from zmq.utils import z85
|
||||
|
||||
|
||||
class TestZ85(TestCase):
|
||||
def test_client_public(self):
|
||||
client_public = (
|
||||
b"\xBB\x88\x47\x1D\x65\xE2\x65\x9B"
|
||||
b"\x30\xC5\x5A\x53\x21\xCE\xBB\x5A"
|
||||
b"\xAB\x2B\x70\xA3\x98\x64\x5C\x26"
|
||||
b"\xDC\xA2\xB2\xFC\xB4\x3F\xC5\x18"
|
||||
)
|
||||
encoded = z85.encode(client_public)
|
||||
|
||||
assert encoded == b"Yne@$w-vo<fVvi]a<NY6T1ed:M$fCG*[IaLV{hID"
|
||||
decoded = z85.decode(encoded)
|
||||
assert decoded == client_public
|
||||
|
||||
def test_client_secret(self):
|
||||
client_secret = (
|
||||
b"\x7B\xB8\x64\xB4\x89\xAF\xA3\x67"
|
||||
b"\x1F\xBE\x69\x10\x1F\x94\xB3\x89"
|
||||
b"\x72\xF2\x48\x16\xDF\xB0\x1B\x51"
|
||||
b"\x65\x6B\x3F\xEC\x8D\xFD\x08\x88"
|
||||
)
|
||||
encoded = z85.encode(client_secret)
|
||||
|
||||
assert encoded == b"D:)Q[IlAW!ahhC2ac:9*A}h:p?([4%wOTJ%JR%cs"
|
||||
decoded = z85.decode(encoded)
|
||||
assert decoded == client_secret
|
||||
|
||||
def test_server_public(self):
|
||||
server_public = (
|
||||
b"\x54\xFC\xBA\x24\xE9\x32\x49\x96"
|
||||
b"\x93\x16\xFB\x61\x7C\x87\x2B\xB0"
|
||||
b"\xC1\xD1\xFF\x14\x80\x04\x27\xC5"
|
||||
b"\x94\xCB\xFA\xCF\x1B\xC2\xD6\x52"
|
||||
)
|
||||
encoded = z85.encode(server_public)
|
||||
|
||||
assert encoded == b"rq:rM>}U?@Lns47E1%kR.o@n%FcmmsL/@{H8]yf7"
|
||||
decoded = z85.decode(encoded)
|
||||
assert decoded == server_public
|
||||
|
||||
def test_server_secret(self):
|
||||
server_secret = (
|
||||
b"\x8E\x0B\xDD\x69\x76\x28\xB9\x1D"
|
||||
b"\x8F\x24\x55\x87\xEE\x95\xC5\xB0"
|
||||
b"\x4D\x48\x96\x3F\x79\x25\x98\x77"
|
||||
b"\xB4\x9C\xD9\x06\x3A\xEA\xD3\xB7"
|
||||
)
|
||||
encoded = z85.encode(server_secret)
|
||||
|
||||
assert encoded == b"JTKVSB%%)wK0E.X)V>+}o?pNmC{O&4W4b!Ni{Lh6"
|
||||
decoded = z85.decode(encoded)
|
||||
assert decoded == server_secret
|
||||
@@ -0,0 +1,159 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
try:
|
||||
import tornado
|
||||
|
||||
from zmq.eventloop import zmqstream
|
||||
except ImportError:
|
||||
tornado = None # type: ignore
|
||||
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("io_loop")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def push_pull(socket):
|
||||
push = zmqstream.ZMQStream(socket(zmq.PUSH))
|
||||
pull = zmqstream.ZMQStream(socket(zmq.PULL))
|
||||
port = push.bind_to_random_port('tcp://127.0.0.1')
|
||||
pull.connect('tcp://127.0.0.1:%i' % port)
|
||||
return (push, pull)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def push(push_pull):
|
||||
push, pull = push_pull
|
||||
return push
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pull(push_pull):
|
||||
push, pull = push_pull
|
||||
return pull
|
||||
|
||||
|
||||
async def test_callable_check(pull):
|
||||
"""Ensure callable check works."""
|
||||
|
||||
pull.on_send(lambda *args: None)
|
||||
pull.on_recv(lambda *args: None)
|
||||
with pytest.raises(AssertionError):
|
||||
pull.on_recv(1)
|
||||
with pytest.raises(AssertionError):
|
||||
pull.on_send(1)
|
||||
with pytest.raises(AssertionError):
|
||||
pull.on_recv(zmq)
|
||||
|
||||
|
||||
async def test_on_recv_basic(push, pull):
|
||||
sent = [b'basic']
|
||||
push.send_multipart(sent)
|
||||
f = asyncio.Future()
|
||||
|
||||
def callback(msg):
|
||||
f.set_result(msg)
|
||||
|
||||
pull.on_recv(callback)
|
||||
recvd = await asyncio.wait_for(f, timeout=5)
|
||||
assert recvd == sent
|
||||
|
||||
|
||||
async def test_on_recv_wake(push, pull):
|
||||
sent = [b'wake']
|
||||
|
||||
f = asyncio.Future()
|
||||
pull.on_recv(f.set_result)
|
||||
await asyncio.sleep(0.5)
|
||||
push.send_multipart(sent)
|
||||
recvd = await asyncio.wait_for(f, timeout=5)
|
||||
assert recvd == sent
|
||||
|
||||
|
||||
async def test_on_recv_async(push, pull):
|
||||
if tornado.version_info < (5,):
|
||||
pytest.skip()
|
||||
sent = [b'wake']
|
||||
|
||||
f = asyncio.Future()
|
||||
|
||||
async def callback(msg):
|
||||
await asyncio.sleep(0.1)
|
||||
f.set_result(msg)
|
||||
|
||||
pull.on_recv(callback)
|
||||
await asyncio.sleep(0.5)
|
||||
push.send_multipart(sent)
|
||||
recvd = await asyncio.wait_for(f, timeout=5)
|
||||
assert recvd == sent
|
||||
|
||||
|
||||
async def test_on_recv_async_error(push, pull, caplog):
|
||||
sent = [b'wake']
|
||||
|
||||
f = asyncio.Future()
|
||||
|
||||
async def callback(msg):
|
||||
f.set_result(msg)
|
||||
1 / 0
|
||||
|
||||
pull.on_recv(callback)
|
||||
await asyncio.sleep(0.1)
|
||||
with caplog.at_level(logging.ERROR, logger=zmqstream.gen_log.name):
|
||||
push.send_multipart(sent)
|
||||
recvd = await asyncio.wait_for(f, timeout=5)
|
||||
assert recvd == sent
|
||||
# logging error takes a tick later
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
messages = [
|
||||
x.message
|
||||
for x in caplog.get_records("call")
|
||||
if x.name == zmqstream.gen_log.name
|
||||
]
|
||||
assert "Uncaught exception in ZMQStream callback" in "\n".join(messages)
|
||||
|
||||
|
||||
async def test_shadow_socket(context):
|
||||
with context.socket(zmq.PUSH, socket_class=zmq.asyncio.Socket) as socket:
|
||||
with pytest.warns(RuntimeWarning):
|
||||
stream = zmqstream.ZMQStream(socket)
|
||||
assert type(stream.socket) is zmq.Socket
|
||||
assert stream.socket.underlying == socket.underlying
|
||||
stream.close()
|
||||
|
||||
|
||||
async def test_shadow_socket_close(context, caplog):
|
||||
with context.socket(zmq.PUSH) as push, context.socket(zmq.PULL) as pull:
|
||||
push.linger = pull.linger = 0
|
||||
port = push.bind_to_random_port('tcp://127.0.0.1')
|
||||
pull.connect(f'tcp://127.0.0.1:{port}')
|
||||
shadow_pull = zmq.Socket.shadow(pull)
|
||||
stream = zmqstream.ZMQStream(shadow_pull)
|
||||
# send some messages
|
||||
for i in range(10):
|
||||
push.send_string(str(i))
|
||||
# make sure at least one message has been delivered
|
||||
pull.recv()
|
||||
# register callback
|
||||
# this should schedule event callback on the next tick
|
||||
stream.on_recv(print)
|
||||
# close the shadowed socket
|
||||
pull.close()
|
||||
# run the event loop, which should see some events on the shadow socket
|
||||
# but the socket has been closed!
|
||||
with warnings.catch_warnings(record=True) as records:
|
||||
await asyncio.sleep(0.2)
|
||||
warning_text = "\n".join(str(r.message) for r in records)
|
||||
assert "after closing socket" in warning_text
|
||||
assert "closed socket" in caplog.text
|
||||
Reference in New Issue
Block a user