This commit is contained in:
Sven Riwoldt
2024-04-01 20:30:24 +02:00
parent fd333f3514
commit c7bc862c6f
6804 changed files with 1065135 additions and 0 deletions

View File

@@ -0,0 +1,264 @@
# Copyright (c) PyZMQ Developers.
# Distributed under the terms of the Modified BSD License.
import os
import platform
import signal
import sys
import time
import warnings
from functools import partial
from threading import Thread
from typing import List
from unittest import SkipTest, TestCase
from pytest import mark
import zmq
from zmq.utils import jsonapi
try:
import gevent
from zmq import green as gzmq
have_gevent = True
except ImportError:
have_gevent = False
PYPY = platform.python_implementation() == 'PyPy'
# -----------------------------------------------------------------------------
# skip decorators (directly from unittest)
# -----------------------------------------------------------------------------
_id = lambda x: x
skip_pypy = mark.skipif(PYPY, reason="Doesn't work on PyPy")
require_zmq_4 = mark.skipif(zmq.zmq_version_info() < (4,), reason="requires zmq >= 4")
# -----------------------------------------------------------------------------
# Base test class
# -----------------------------------------------------------------------------
def term_context(ctx, timeout):
"""Terminate a context with a timeout"""
t = Thread(target=ctx.term)
t.daemon = True
t.start()
t.join(timeout=timeout)
if t.is_alive():
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
zmq.sugar.context.Context._instance = None
raise RuntimeError(
"context could not terminate, open sockets likely remain in test"
)
class BaseZMQTestCase(TestCase):
green = False
teardown_timeout = 10
test_timeout_seconds = int(os.environ.get("ZMQ_TEST_TIMEOUT") or 60)
sockets: List[zmq.Socket]
@property
def _is_pyzmq_test(self):
return self.__class__.__module__.split(".", 1)[0] == __name__.split(".", 1)[0]
@property
def _should_test_timeout(self):
return (
self._is_pyzmq_test
and hasattr(signal, 'SIGALRM')
and self.test_timeout_seconds
)
@property
def Context(self):
if self.green:
return gzmq.Context
else:
return zmq.Context
def socket(self, socket_type):
s = self.context.socket(socket_type)
self.sockets.append(s)
return s
def _alarm_timeout(self, timeout, *args):
raise TimeoutError(f"Test did not complete in {timeout} seconds")
def setUp(self):
super().setUp()
if not self._is_pyzmq_test:
warnings.warn(
"zmq.tests.BaseZMQTestCase is deprecated in pyzmq 25, we recommend managing your own contexts and sockets.",
DeprecationWarning,
stacklevel=3,
)
if self.green and not have_gevent:
raise SkipTest("requires gevent")
self.context = self.Context.instance()
self.sockets = []
if self._should_test_timeout:
# use SIGALRM to avoid test hangs
signal.signal(
signal.SIGALRM, partial(self._alarm_timeout, self.test_timeout_seconds)
)
signal.alarm(self.test_timeout_seconds)
def tearDown(self):
if self._should_test_timeout:
# cancel the timeout alarm, if there was one
signal.alarm(0)
contexts = {self.context}
while self.sockets:
sock = self.sockets.pop()
contexts.add(sock.context) # in case additional contexts are created
sock.close(0)
for ctx in contexts:
try:
term_context(ctx, self.teardown_timeout)
except Exception:
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
zmq.sugar.context.Context._instance = None
raise
super().tearDown()
def create_bound_pair(
self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'
):
"""Create a bound socket pair using a random port."""
s1 = self.context.socket(type1)
s1.setsockopt(zmq.LINGER, 0)
port = s1.bind_to_random_port(interface)
s2 = self.context.socket(type2)
s2.setsockopt(zmq.LINGER, 0)
s2.connect(f'{interface}:{port}')
self.sockets.extend([s1, s2])
return s1, s2
def ping_pong(self, s1, s2, msg):
s1.send(msg)
msg2 = s2.recv()
s2.send(msg2)
msg3 = s1.recv()
return msg3
def ping_pong_json(self, s1, s2, o):
if jsonapi.jsonmod is None:
raise SkipTest("No json library")
s1.send_json(o)
o2 = s2.recv_json()
s2.send_json(o2)
o3 = s1.recv_json()
return o3
def ping_pong_pyobj(self, s1, s2, o):
s1.send_pyobj(o)
o2 = s2.recv_pyobj()
s2.send_pyobj(o2)
o3 = s1.recv_pyobj()
return o3
def assertRaisesErrno(self, errno, func, *args, **kwargs):
try:
func(*args, **kwargs)
except zmq.ZMQError as e:
self.assertEqual(
e.errno,
errno,
"wrong error raised, expected '%s' \
got '%s'"
% (zmq.ZMQError(errno), zmq.ZMQError(e.errno)),
)
else:
self.fail("Function did not raise any error")
def _select_recv(self, multipart, socket, **kwargs):
"""call recv[_multipart] in a way that raises if there is nothing to receive"""
if zmq.zmq_version_info() >= (3, 1, 0):
# zmq 3.1 has a bug, where poll can return false positives,
# so we wait a little bit just in case
# See LIBZMQ-280 on JIRA
time.sleep(0.1)
r, w, x = zmq.select([socket], [], [], timeout=kwargs.pop('timeout', 5))
assert len(r) > 0, "Should have received a message"
kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0)
recv = socket.recv_multipart if multipart else socket.recv
return recv(**kwargs)
def recv(self, socket, **kwargs):
"""call recv in a way that raises if there is nothing to receive"""
return self._select_recv(False, socket, **kwargs)
def recv_multipart(self, socket, **kwargs):
"""call recv_multipart in a way that raises if there is nothing to receive"""
return self._select_recv(True, socket, **kwargs)
class PollZMQTestCase(BaseZMQTestCase):
pass
class GreenTest:
"""Mixin for making green versions of test classes"""
green = True
teardown_timeout = 10
def assertRaisesErrno(self, errno, func, *args, **kwargs):
if errno == zmq.EAGAIN:
raise SkipTest("Skipping because we're green.")
try:
func(*args, **kwargs)
except zmq.ZMQError:
e = sys.exc_info()[1]
self.assertEqual(
e.errno,
errno,
"wrong error raised, expected '%s' \
got '%s'"
% (zmq.ZMQError(errno), zmq.ZMQError(e.errno)),
)
else:
self.fail("Function did not raise any error")
def tearDown(self):
if self._should_test_timeout:
# cancel the timeout alarm, if there was one
signal.alarm(0)
contexts = {self.context}
while self.sockets:
sock = self.sockets.pop()
contexts.add(sock.context) # in case additional contexts are created
sock.close()
try:
gevent.joinall(
[gevent.spawn(ctx.term) for ctx in contexts],
timeout=self.teardown_timeout,
raise_error=True,
)
except gevent.Timeout:
raise RuntimeError(
"context could not terminate, open sockets likely remain in test"
)
def skip_green(self):
raise SkipTest("Skipping because we are green")
def skip_green(f):
def skipping_test(self, *args, **kwargs):
if self.green:
raise SkipTest("Skipping because we are green")
else:
return f(self, *args, **kwargs)
return skipping_test

View File

@@ -0,0 +1,215 @@
"""pytest configuration and fixtures"""
import asyncio
import inspect
import os
import signal
import time
from functools import partial
from threading import Thread
try:
import tornado
from tornado import version_info
except ImportError:
tornado = None
else:
if version_info < (5,):
tornado = None
from tornado.ioloop import IOLoop
import pytest
import zmq
import zmq.asyncio
test_timeout_seconds = os.environ.get("ZMQ_TEST_TIMEOUT")
teardown_timeout = 10
def pytest_collection_modifyitems(items):
"""This function is automatically run by pytest passing all collected test
functions.
We use it to add asyncio marker to all async tests and assert we don't use
test functions that are async generators which wouldn't make sense.
It is no longer required with pytest-asyncio >= 0.17
"""
for item in items:
if inspect.iscoroutinefunction(item.obj):
item.add_marker('asyncio')
assert not inspect.isasyncgenfunction(item.obj)
@pytest.fixture
async def io_loop(event_loop, request):
"""Create tornado io_loop on current asyncio event loop"""
if tornado is None:
pytest.skip()
io_loop = IOLoop.current()
assert asyncio.get_event_loop() is event_loop
assert io_loop.asyncio_loop is event_loop
def _close():
io_loop.close(all_fds=True)
request.addfinalizer(_close)
return io_loop
def term_context(ctx, timeout):
"""Terminate a context with a timeout"""
t = Thread(target=ctx.term)
t.daemon = True
t.start()
t.join(timeout=timeout)
if t.is_alive():
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
zmq.sugar.context.Context._instance = None
raise RuntimeError(
f"context {ctx} could not terminate, open sockets likely remain in test"
)
@pytest.fixture
def event_loop():
loop = asyncio.new_event_loop()
yield loop
loop.close()
# make sure selectors are cleared
assert dict(zmq.asyncio._selectors) == {}
@pytest.fixture
def sigalrm_timeout():
"""Set timeout using SIGALRM
Avoids infinite hang in context.term for an unclean context,
raising an error instead.
"""
if not hasattr(signal, "SIGALRM") or not test_timeout_seconds:
return
def _alarm_timeout(*args):
raise TimeoutError(f"Test did not complete in {test_timeout_seconds} seconds")
signal.signal(signal.SIGALRM, _alarm_timeout)
signal.alarm(test_timeout_seconds)
@pytest.fixture
def Context():
"""Context class fixture
Override in modules to specify a different class (e.g. zmq.green)
"""
return zmq.Context
@pytest.fixture
def contexts(sigalrm_timeout):
"""Fixture to track contexts used in tests
For cleanup purposes
"""
contexts = set()
yield contexts
for ctx in contexts:
try:
term_context(ctx, teardown_timeout)
except Exception:
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
zmq.sugar.context.Context._instance = None
raise
@pytest.fixture
def context(Context, contexts):
"""Fixture for shared context"""
ctx = Context()
contexts.add(ctx)
return ctx
@pytest.fixture
def sockets(contexts):
sockets = []
yield sockets
# ensure any tracked sockets get their contexts cleaned up
for socket in sockets:
contexts.add(socket.context)
# close sockets
for socket in sockets:
socket.close(linger=0)
@pytest.fixture
def socket(context, sockets):
"""Fixture to create sockets, while tracking them for cleanup"""
def new_socket(*args, **kwargs):
s = context.socket(*args, **kwargs)
sockets.append(s)
return s
return new_socket
def assert_raises_errno(errno):
try:
yield
except zmq.ZMQError as e:
assert (
e.errno == errno
), f"wrong error raised, expected {zmq.ZMQError(errno)} got {zmq.ZMQError(e.errno)}"
else:
pytest.fail(f"Expected {zmq.ZMQError(errno)}, no error raised")
def recv(socket, *, timeout=5, flags=0, multipart=False, **kwargs):
"""call recv[_multipart] in a way that raises if there is nothing to receive"""
if zmq.zmq_version_info() >= (3, 1, 0):
# zmq 3.1 has a bug, where poll can return false positives,
# so we wait a little bit just in case
# See LIBZMQ-280 on JIRA
time.sleep(0.1)
r, w, x = zmq.select([socket], [], [], timeout=timeout)
assert r, "Should have received a message"
kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0)
recv = socket.recv_multipart if multipart else socket.recv
return recv(flags=flags, **kwargs)
recv_multipart = partial(recv, multipart=True)
@pytest.fixture
def create_bound_pair(socket):
def create_bound_pair(type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'):
"""Create a bound socket pair using a random port."""
s1 = socket(type1)
s1.linger = 0
port = s1.bind_to_random_port(interface)
s2 = socket(type2)
s2.linger = 0
s2.connect(f'{interface}:{port}')
return s1, s2
return create_bound_pair
@pytest.fixture
def bound_pair(create_bound_pair):
return create_bound_pair()
@pytest.fixture
def push_pull(create_bound_pair):
return create_bound_pair(zmq.PUSH, zmq.PULL)
@pytest.fixture
def dealer_router(create_bound_pair):
return create_bound_pair(zmq.DEALER, zmq.ROUTER)

View File

@@ -0,0 +1,23 @@
from zmq cimport Context, Frame, Socket, libzmq
cdef inline Frame c_send_recv(Socket a, Socket b, bytes to_send):
cdef Frame msg = Frame(to_send)
a.send(msg)
cdef Frame recvd = b.recv(flags=0, copy=False)
return recvd
cpdef bytes send_recv_test(bytes to_send):
cdef Context ctx = Context()
cdef Socket a = Socket(ctx, libzmq.ZMQ_PUSH)
cdef Socket b = Socket(ctx, libzmq.ZMQ_PULL)
url = 'inproc://test'
a.bind(url)
b.connect(url)
cdef Frame recvd_frame = c_send_recv(a, b, to_send)
a.close()
b.close()
ctx.term()
cdef bytes recvd_bytes = recvd_frame.bytes
return recvd_bytes

View File

@@ -0,0 +1,387 @@
"""Test asyncio support"""
# Copyright (c) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import asyncio
import json
import os
import sys
from concurrent.futures import CancelledError
from multiprocessing import Process
import pytest
from pytest import mark
import zmq
import zmq.asyncio as zaio
@pytest.fixture
def Context(event_loop):
return zaio.Context
def test_socket_class(context):
with context.socket(zmq.PUSH) as s:
assert isinstance(s, zaio.Socket)
def test_instance_subclass_first(context):
actx = zmq.asyncio.Context.instance()
ctx = zmq.Context.instance()
ctx.term()
actx.term()
assert type(ctx) is zmq.Context
assert type(actx) is zmq.asyncio.Context
def test_instance_subclass_second(context):
with zmq.Context.instance() as ctx:
assert type(ctx) is zmq.Context
with zmq.asyncio.Context.instance() as actx:
assert type(actx) is zmq.asyncio.Context
async def test_recv_multipart(context, create_bound_pair):
a, b = create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_multipart()
assert not f.done()
await a.send(b"hi")
recvd = await f
assert recvd == [b"hi"]
async def test_recv(create_bound_pair):
a, b = create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv()
assert not f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.done()
assert f1.result() == b"hi"
assert recvd == b"there"
@mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO")
async def test_recv_timeout(push_pull):
a, b = push_pull
b.rcvtimeo = 100
f1 = b.recv()
b.rcvtimeo = 1000
f2 = b.recv_multipart()
with pytest.raises(zmq.Again):
await f1
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f2.done()
assert recvd == [b"hi", b"there"]
@mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO")
async def test_send_timeout(socket):
s = socket(zmq.PUSH)
s.sndtimeo = 100
with pytest.raises(zmq.Again):
await s.send(b"not going anywhere")
async def test_recv_string(push_pull):
a, b = push_pull
f = b.recv_string()
assert not f.done()
msg = "πøøπ"
await a.send_string(msg)
recvd = await f
assert f.done()
assert f.result() == msg
assert recvd == msg
async def test_recv_json(push_pull):
a, b = push_pull
f = b.recv_json()
assert not f.done()
obj = dict(a=5)
await a.send_json(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj
async def test_recv_json_cancelled(push_pull):
a, b = push_pull
f = b.recv_json()
assert not f.done()
f.cancel()
# cycle eventloop to allow cancel events to fire
await asyncio.sleep(0)
obj = dict(a=5)
await a.send_json(obj)
# CancelledError change in 3.8 https://bugs.python.org/issue32528
if sys.version_info < (3, 8):
with pytest.raises(CancelledError):
recvd = await f
else:
with pytest.raises(asyncio.exceptions.CancelledError):
recvd = await f
assert f.done()
# give it a chance to incorrectly consume the event
events = await b.poll(timeout=5)
assert events
await asyncio.sleep(0)
# make sure cancelled recv didn't eat up event
f = b.recv_json()
recvd = await asyncio.wait_for(f, timeout=5)
assert recvd == obj
async def test_recv_pyobj(push_pull):
a, b = push_pull
f = b.recv_pyobj()
assert not f.done()
obj = dict(a=5)
await a.send_pyobj(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj
async def test_custom_serialize(create_bound_pair):
def serialize(msg):
frames = []
frames.extend(msg.get("identities", []))
content = json.dumps(msg["content"]).encode("utf8")
frames.append(content)
return frames
def deserialize(frames):
identities = frames[:-1]
content = json.loads(frames[-1].decode("utf8"))
return {
"identities": identities,
"content": content,
}
a, b = create_bound_pair(zmq.DEALER, zmq.ROUTER)
msg = {
"content": {
"a": 5,
"b": "bee",
}
}
await a.send_serialized(msg, serialize)
recvd = await b.recv_serialized(deserialize)
assert recvd["content"] == msg["content"]
assert recvd["identities"]
# bounce back, tests identities
await b.send_serialized(recvd, serialize)
r2 = await a.recv_serialized(deserialize)
assert r2["content"] == msg["content"]
assert not r2["identities"]
async def test_custom_serialize_error(dealer_router):
a, b = dealer_router
msg = {
"content": {
"a": 5,
"b": "bee",
}
}
with pytest.raises(TypeError):
await a.send_serialized(json, json.dumps)
await a.send(b"not json")
with pytest.raises(TypeError):
await b.recv_serialized(json.loads)
async def test_recv_dontwait(push_pull):
push, pull = push_pull
f = pull.recv(zmq.DONTWAIT)
with pytest.raises(zmq.Again):
await f
await push.send(b"ping")
await pull.poll() # ensure message will be waiting
f = pull.recv(zmq.DONTWAIT)
assert f.done()
msg = await f
assert msg == b"ping"
async def test_recv_cancel(push_pull):
a, b = push_pull
f1 = b.recv()
f2 = b.recv_multipart()
assert f1.cancel()
assert f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.cancelled()
assert f2.done()
assert recvd == [b"hi", b"there"]
async def test_poll(push_pull):
a, b = push_pull
f = b.poll(timeout=0)
await asyncio.sleep(0)
assert f.result() == 0
f = b.poll(timeout=1)
assert not f.done()
evt = await f
assert evt == 0
f = b.poll(timeout=1000)
assert not f.done()
await a.send_multipart([b"hi", b"there"])
evt = await f
assert evt == zmq.POLLIN
recvd = await b.recv_multipart()
assert recvd == [b"hi", b"there"]
async def test_poll_base_socket(sockets):
ctx = zmq.Context()
url = "inproc://test"
a = ctx.socket(zmq.PUSH)
b = ctx.socket(zmq.PULL)
sockets.extend([a, b])
a.bind(url)
b.connect(url)
poller = zaio.Poller()
poller.register(b, zmq.POLLIN)
f = poller.poll(timeout=1000)
assert not f.done()
a.send_multipart([b"hi", b"there"])
evt = await f
assert evt == [(b, zmq.POLLIN)]
recvd = b.recv_multipart()
assert recvd == [b"hi", b"there"]
async def test_poll_on_closed_socket(push_pull):
a, b = push_pull
f = b.poll(timeout=1)
b.close()
# The test might stall if we try to await f directly so instead just make a few
# passes through the event loop to schedule and execute all callbacks
for _ in range(5):
await asyncio.sleep(0)
if f.cancelled():
break
assert f.cancelled()
@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Windows does not support polling on files",
)
async def test_poll_raw():
p = zaio.Poller()
# make a pipe
r, w = os.pipe()
r = os.fdopen(r, "rb")
w = os.fdopen(w, "wb")
# POLLOUT
p.register(r, zmq.POLLIN)
p.register(w, zmq.POLLOUT)
evts = await p.poll(timeout=1)
evts = dict(evts)
assert r.fileno() not in evts
assert w.fileno() in evts
assert evts[w.fileno()] == zmq.POLLOUT
# POLLIN
p.unregister(w)
w.write(b"x")
w.flush()
evts = await p.poll(timeout=1000)
evts = dict(evts)
assert r.fileno() in evts
assert evts[r.fileno()] == zmq.POLLIN
assert r.read(1) == b"x"
r.close()
w.close()
def test_multiple_loops(push_pull):
a, b = push_pull
async def test():
await a.send(b'buf')
msg = await b.recv()
assert msg == b'buf'
for i in range(3):
loop = asyncio.new_event_loop()
loop.run_until_complete(asyncio.wait_for(test(), timeout=10))
loop.close()
def test_shadow():
with zmq.Context() as ctx:
s = ctx.socket(zmq.PULL)
async_s = zaio.Socket(s)
assert isinstance(async_s, zaio.Socket)
assert async_s.underlying == s.underlying
assert async_s.type == s.type
async def test_poll_leak():
ctx = zmq.asyncio.Context()
with ctx, ctx.socket(zmq.PULL) as s:
assert len(s._recv_futures) == 0
for i in range(10):
f = asyncio.ensure_future(s.poll(timeout=1000, flags=zmq.PollEvent.POLLIN))
f.cancel()
await asyncio.sleep(0)
# one more sleep allows further chained cleanup
await asyncio.sleep(0.1)
assert len(s._recv_futures) == 0
class ProcessForTeardownTest(Process):
def run(self):
"""Leave context, socket and event loop upon implicit disposal"""
actx = zaio.Context.instance()
socket = actx.socket(zmq.PAIR)
socket.bind_to_random_port("tcp://127.0.0.1")
async def never_ending_task(socket):
await socket.recv() # never ever receive anything
loop = asyncio.new_event_loop()
coro = asyncio.wait_for(never_ending_task(socket), timeout=1)
try:
loop.run_until_complete(coro)
except asyncio.TimeoutError:
pass # expected timeout
else:
assert False, "never_ending_task was completed unexpectedly"
finally:
loop.close()
def test_process_teardown(request):
proc = ProcessForTeardownTest()
proc.start()
request.addfinalizer(proc.terminate)
proc.join(10) # starting new Python process may cost a lot
assert proc.exitcode is not None, "process teardown hangs"
assert proc.exitcode == 0, f"Python process died with code {proc.exitcode}"

View File

@@ -0,0 +1,416 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import asyncio
import logging
import os
import shutil
import sys
import warnings
from contextlib import contextmanager
import pytest
import zmq
import zmq.asyncio
import zmq.auth
from zmq.tests import SkipTest, skip_pypy
try:
import tornado
except ImportError:
tornado = None
@pytest.fixture
def Context(event_loop):
return zmq.asyncio.Context
@pytest.fixture
def create_certs(tmpdir):
"""Create CURVE certificates for a test"""
# Create temporary CURVE key pairs for this test run. We create all keys in a
# temp directory and then move them into the appropriate private or public
# directory.
base_dir = str(tmpdir.join("certs").mkdir())
keys_dir = os.path.join(base_dir, "certificates")
public_keys_dir = os.path.join(base_dir, 'public_keys')
secret_keys_dir = os.path.join(base_dir, 'private_keys')
os.mkdir(keys_dir)
os.mkdir(public_keys_dir)
os.mkdir(secret_keys_dir)
server_public_file, server_secret_file = zmq.auth.create_certificates(
keys_dir, "server"
)
client_public_file, client_secret_file = zmq.auth.create_certificates(
keys_dir, "client"
)
for key_file in os.listdir(keys_dir):
if key_file.endswith(".key"):
shutil.move(
os.path.join(keys_dir, key_file), os.path.join(public_keys_dir, '.')
)
for key_file in os.listdir(keys_dir):
if key_file.endswith(".key_secret"):
shutil.move(
os.path.join(keys_dir, key_file), os.path.join(secret_keys_dir, '.')
)
return (public_keys_dir, secret_keys_dir)
def load_certs(secret_keys_dir):
"""Return server and client certificate keys"""
server_secret_file = os.path.join(secret_keys_dir, "server.key_secret")
client_secret_file = os.path.join(secret_keys_dir, "client.key_secret")
server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
client_public, client_secret = zmq.auth.load_certificate(client_secret_file)
return server_public, server_secret, client_public, client_secret
@pytest.fixture
def public_keys_dir(create_certs):
public_keys_dir, secret_keys_dir = create_certs
return public_keys_dir
@pytest.fixture
def secret_keys_dir(create_certs):
public_keys_dir, secret_keys_dir = create_certs
return secret_keys_dir
@pytest.fixture
def certs(secret_keys_dir):
return load_certs(secret_keys_dir)
@pytest.fixture
async def _async_setup(request, event_loop):
"""pytest doesn't support async setup/teardown"""
instance = request.instance
await instance.async_setup()
yield
# make sure our teardown runs before the loop closes
instance.async_teardown()
@pytest.mark.usefixtures("_async_setup")
class AuthTest:
auth = None
async def async_setup(self):
self.context = zmq.asyncio.Context()
if zmq.zmq_version_info() < (4, 0):
raise SkipTest("security is new in libzmq 4.0")
try:
zmq.curve_keypair()
except zmq.ZMQError:
raise SkipTest("security requires libzmq to have curve support")
# enable debug logging while we run tests
logging.getLogger('zmq.auth').setLevel(logging.DEBUG)
self.auth = self.make_auth()
await self.start_auth()
def async_teardown(self):
if self.auth:
self.auth.stop()
self.auth = None
self.context.term()
def make_auth(self):
raise NotImplementedError()
async def start_auth(self):
self.auth.start()
async def can_connect(self, server, client, timeout=1000):
"""Check if client can connect to server using tcp transport"""
result = False
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
msg = [b"Hello World"]
# run poll on server twice
# to flush spurious events
await server.poll(100, zmq.POLLOUT)
if await server.poll(timeout, zmq.POLLOUT):
try:
await server.send_multipart(msg, zmq.NOBLOCK)
except zmq.Again:
warnings.warn("server set POLLOUT, but cannot send", RuntimeWarning)
return False
else:
return False
if await client.poll(timeout):
try:
rcvd_msg = await client.recv_multipart(zmq.NOBLOCK)
except zmq.Again:
warnings.warn("client set POLLIN, but cannot recv", RuntimeWarning)
else:
assert rcvd_msg == msg
result = True
return result
@contextmanager
def push_pull(self):
with self.context.socket(zmq.PUSH) as server, self.context.socket(
zmq.PULL
) as client:
server.linger = 0
server.sndtimeo = 2000
client.linger = 0
client.rcvtimeo = 2000
yield server, client
@contextmanager
def curve_push_pull(self, certs, client_key="ok"):
server_public, server_secret, client_public, client_secret = certs
with self.push_pull() as (server, client):
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
if client_key is not None:
client.curve_publickey = client_public
client.curve_secretkey = client_secret
if client_key == "ok":
client.curve_serverkey = server_public
else:
private, public = zmq.curve_keypair()
client.curve_serverkey = public
yield (server, client)
async def test_null(self):
"""threaded auth - NULL"""
# A default NULL connection should always succeed, and not
# go through our authentication infrastructure at all.
self.auth.stop()
self.auth = None
# use a new context, so ZAP isn't inherited
self.context.term()
self.context = zmq.asyncio.Context()
with self.push_pull() as (server, client):
assert await self.can_connect(server, client)
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet. The client connection
# should still be allowed.
with self.push_pull() as (server, client):
server.zap_domain = b'global'
assert await self.can_connect(server, client)
async def test_deny(self):
# deny 127.0.0.1, connection should fail
self.auth.deny('127.0.0.1')
with pytest.raises(ValueError):
self.auth.allow("127.0.0.2")
with self.push_pull() as (server, client):
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet.
server.zap_domain = b'global'
assert not await self.can_connect(server, client, timeout=100)
async def test_allow(self):
# allow 127.0.0.1, connection should pass
self.auth.allow('127.0.0.1')
with pytest.raises(ValueError):
self.auth.deny("127.0.0.2")
with self.push_pull() as (server, client):
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet.
server.zap_domain = b'global'
assert await self.can_connect(server, client)
@pytest.mark.parametrize(
"enabled, password, success",
[
(True, "correct", True),
(False, "correct", False),
(True, "incorrect", False),
],
)
async def test_plain(self, enabled, password, success):
"""threaded auth - PLAIN"""
# Try PLAIN authentication - without configuring server, connection should fail
with self.push_pull() as (server, client):
server.plain_server = True
if password:
client.plain_username = b'admin'
client.plain_password = password.encode("ascii")
if enabled:
self.auth.configure_plain(domain='*', passwords={'admin': 'correct'})
if success:
assert await self.can_connect(server, client)
else:
assert not await self.can_connect(server, client, timeout=100)
# Remove authenticator and check that a normal connection works
self.auth.stop()
self.auth = None
with self.push_pull() as (server, client):
assert await self.can_connect(server, client)
@pytest.mark.parametrize(
"client_key, location, success",
[
('ok', zmq.auth.CURVE_ALLOW_ANY, True),
('ok', "public_keys", True),
('bad', "public_keys", False),
(None, "public_keys", False),
],
)
async def test_curve(self, certs, public_keys_dir, client_key, location, success):
"""threaded auth - CURVE"""
self.auth.allow('127.0.0.1')
# Try CURVE authentication - without configuring server, connection should fail
with self.curve_push_pull(certs, client_key=client_key) as (server, client):
if location:
if location == 'public_keys':
location = public_keys_dir
self.auth.configure_curve(domain='*', location=location)
if success:
assert await self.can_connect(server, client, timeout=100)
else:
assert not await self.can_connect(server, client, timeout=100)
# Remove authenticator and check that a normal connection works
self.auth.stop()
self.auth = None
# Try connecting using NULL and no authentication enabled, connection should pass
with self.push_pull() as (server, client):
assert await self.can_connect(server, client)
@pytest.mark.parametrize("key", ["ok", "wrong"])
@pytest.mark.parametrize("async_", [True, False])
async def test_curve_callback(self, certs, key, async_):
"""threaded auth - CURVE with callback authentication"""
self.auth.allow('127.0.0.1')
server_public, server_secret, client_public, client_secret = certs
class CredentialsProvider:
def __init__(self):
if key == "ok":
self.client = client_public
else:
self.client = server_public
def callback(self, domain, key):
if key == self.client:
return True
else:
return False
async def async_callback(self, domain, key):
await asyncio.sleep(0.1)
if key == self.client:
return True
else:
return False
if async_:
CredentialsProvider.callback = CredentialsProvider.async_callback
provider = CredentialsProvider()
self.auth.configure_curve_callback(credentials_provider=provider)
with self.curve_push_pull(certs) as (server, client):
if key == "ok":
assert await self.can_connect(server, client)
else:
assert not await self.can_connect(server, client, timeout=200)
@skip_pypy
async def test_curve_user_id(self, certs, public_keys_dir):
"""threaded auth - CURVE"""
self.auth.allow('127.0.0.1')
server_public, server_secret, client_public, client_secret = certs
self.auth.configure_curve(domain='*', location=public_keys_dir)
# reverse server-client relationship, so server is PULL
with self.push_pull() as (client, server):
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert await self.can_connect(client, server)
# test default user-id map
await client.send(b'test')
msg = await server.recv(copy=False)
assert msg.bytes == b'test'
try:
user_id = msg.get('User-Id')
except zmq.ZMQVersionError:
pass
else:
assert user_id == client_public.decode("utf8")
# test custom user-id map
self.auth.curve_user_id = lambda client_key: 'custom'
with self.context.socket(zmq.PUSH) as client2:
client2.curve_publickey = client_public
client2.curve_secretkey = client_secret
client2.curve_serverkey = server_public
assert await self.can_connect(client2, server)
await client2.send(b'test2')
msg = await server.recv(copy=False)
assert msg.bytes == b'test2'
try:
user_id = msg.get('User-Id')
except zmq.ZMQVersionError:
pass
else:
assert user_id == 'custom'
class TestThreadAuthentication(AuthTest):
"""Test authentication running in a thread"""
def make_auth(self):
from zmq.auth.thread import ThreadAuthenticator
return ThreadAuthenticator(self.context)
@pytest.mark.skipif(
sys.platform == 'win32' and sys.version_info < (3, 7),
reason="flaky event loop cleanup on windows+py36",
)
class TestAsyncioAuthentication(AuthTest):
"""Test authentication running in a thread"""
def make_auth(self):
from zmq.auth.asyncio import AsyncioAuthenticator
return AsyncioAuthenticator(self.context)
async def test_ioloop_authenticator(context, event_loop, io_loop):
from tornado.ioloop import IOLoop
with warnings.catch_warnings():
from zmq.auth.ioloop import IOLoopAuthenticator
auth = IOLoopAuthenticator(context)
assert auth.context is context
loop = IOLoop(make_current=False)
with pytest.warns(DeprecationWarning):
auth = IOLoopAuthenticator(io_loop=loop)

View File

@@ -0,0 +1,303 @@
import time
from unittest import TestCase
from zmq.tests import SkipTest
try:
from zmq.backend.cffi import ( # type: ignore
IDENTITY,
POLLIN,
POLLOUT,
PULL,
PUSH,
REP,
REQ,
zmq_version_info,
)
from zmq.backend.cffi._cffi import C, ffi
have_ffi_backend = True
except ImportError:
have_ffi_backend = False
class TestCFFIBackend(TestCase):
def setUp(self):
if not have_ffi_backend:
raise SkipTest('CFFI not available')
def test_zmq_version_info(self):
version = zmq_version_info()
assert version[0] in range(2, 11)
def test_zmq_ctx_new_destroy(self):
ctx = C.zmq_ctx_new()
assert ctx != ffi.NULL
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_socket_open_close(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, PUSH)
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_setsockopt(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, PUSH)
identity = ffi.new('char[3]', b'zmq')
ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3)
assert ret == 0
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_getsockopt(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, PUSH)
identity = ffi.new('char[]', b'zmq')
ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3)
assert ret == 0
option_len = ffi.new('size_t*', 3)
option = ffi.new('char[3]')
ret = C.zmq_getsockopt(socket, IDENTITY, ffi.cast('void*', option), option_len)
assert ret == 0
assert ffi.string(ffi.cast('char*', option))[0:1] == b"z"
assert ffi.string(ffi.cast('char*', option))[1:2] == b"m"
assert ffi.string(ffi.cast('char*', option))[2:3] == b"q"
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_bind(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, 8)
assert 0 == C.zmq_bind(socket, b'tcp://*:4444')
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_bind_connect(self):
ctx = C.zmq_ctx_new()
socket1 = C.zmq_socket(ctx, PUSH)
socket2 = C.zmq_socket(ctx, PULL)
assert 0 == C.zmq_bind(socket1, b'tcp://*:4444')
assert 0 == C.zmq_connect(socket2, b'tcp://127.0.0.1:4444')
assert ctx != ffi.NULL
assert ffi.NULL != socket1
assert ffi.NULL != socket2
assert 0 == C.zmq_close(socket1)
assert 0 == C.zmq_close(socket2)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_msg_init_close(self):
zmq_msg = ffi.new('zmq_msg_t*')
assert ffi.NULL != zmq_msg
assert 0 == C.zmq_msg_init(zmq_msg)
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_msg_init_size(self):
zmq_msg = ffi.new('zmq_msg_t*')
assert ffi.NULL != zmq_msg
assert 0 == C.zmq_msg_init_size(zmq_msg, 10)
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_msg_init_data(self):
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
assert 0 == C.zmq_msg_init_data(
zmq_msg, ffi.cast('void*', message), 5, ffi.NULL, ffi.NULL
)
assert ffi.NULL != zmq_msg
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_msg_data(self):
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[]', b'Hello')
assert 0 == C.zmq_msg_init_data(
zmq_msg, ffi.cast('void*', message), 5, ffi.NULL, ffi.NULL
)
data = C.zmq_msg_data(zmq_msg)
assert ffi.NULL != zmq_msg
assert ffi.string(ffi.cast("char*", data)) == b'Hello'
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_send(self):
ctx = C.zmq_ctx_new()
sender = C.zmq_socket(ctx, REQ)
receiver = C.zmq_socket(ctx, REP)
assert 0 == C.zmq_bind(receiver, b'tcp://*:7777')
assert 0 == C.zmq_connect(sender, b'tcp://127.0.0.1:7777')
time.sleep(0.1)
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
C.zmq_msg_init_data(
zmq_msg,
ffi.cast('void*', message),
ffi.cast('size_t', 5),
ffi.NULL,
ffi.NULL,
)
assert 5 == C.zmq_msg_send(zmq_msg, sender, 0)
assert 0 == C.zmq_msg_close(zmq_msg)
assert C.zmq_close(sender) == 0
assert C.zmq_close(receiver) == 0
assert C.zmq_ctx_destroy(ctx) == 0
def test_zmq_recv(self):
ctx = C.zmq_ctx_new()
sender = C.zmq_socket(ctx, REQ)
receiver = C.zmq_socket(ctx, REP)
assert 0 == C.zmq_bind(receiver, b'tcp://*:2222')
assert 0 == C.zmq_connect(sender, b'tcp://127.0.0.1:2222')
time.sleep(0.1)
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
C.zmq_msg_init_data(
zmq_msg,
ffi.cast('void*', message),
ffi.cast('size_t', 5),
ffi.NULL,
ffi.NULL,
)
zmq_msg2 = ffi.new('zmq_msg_t*')
C.zmq_msg_init(zmq_msg2)
assert 5 == C.zmq_msg_send(zmq_msg, sender, 0)
assert 5 == C.zmq_msg_recv(zmq_msg2, receiver, 0)
assert 5 == C.zmq_msg_size(zmq_msg2)
assert (
b"Hello"
== ffi.buffer(C.zmq_msg_data(zmq_msg2), C.zmq_msg_size(zmq_msg2))[:]
)
assert C.zmq_close(sender) == 0
assert C.zmq_close(receiver) == 0
assert C.zmq_ctx_destroy(ctx) == 0
def test_zmq_poll(self):
ctx = C.zmq_ctx_new()
sender = C.zmq_socket(ctx, REQ)
receiver = C.zmq_socket(ctx, REP)
r1 = C.zmq_bind(receiver, b'tcp://*:3333')
r2 = C.zmq_connect(sender, b'tcp://127.0.0.1:3333')
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
C.zmq_msg_init_data(
zmq_msg,
ffi.cast('void*', message),
ffi.cast('size_t', 5),
ffi.NULL,
ffi.NULL,
)
receiver_pollitem = ffi.new('zmq_pollitem_t*')
receiver_pollitem.socket = receiver
receiver_pollitem.fd = 0
receiver_pollitem.events = POLLIN | POLLOUT
receiver_pollitem.revents = 0
ret = C.zmq_poll(ffi.NULL, 0, 0)
assert ret == 0
ret = C.zmq_poll(receiver_pollitem, 1, 0)
assert ret == 0
ret = C.zmq_msg_send(zmq_msg, sender, 0)
print(ffi.string(C.zmq_strerror(C.zmq_errno())))
assert ret == 5
time.sleep(0.2)
ret = C.zmq_poll(receiver_pollitem, 1, 0)
assert ret == 1
assert int(receiver_pollitem.revents) & POLLIN
assert not int(receiver_pollitem.revents) & POLLOUT
zmq_msg2 = ffi.new('zmq_msg_t*')
C.zmq_msg_init(zmq_msg2)
ret_recv = C.zmq_msg_recv(zmq_msg2, receiver, 0)
assert ret_recv == 5
assert 5 == C.zmq_msg_size(zmq_msg2)
assert (
b"Hello"
== ffi.buffer(C.zmq_msg_data(zmq_msg2), C.zmq_msg_size(zmq_msg2))[:]
)
sender_pollitem = ffi.new('zmq_pollitem_t*')
sender_pollitem.socket = sender
sender_pollitem.fd = 0
sender_pollitem.events = POLLIN | POLLOUT
sender_pollitem.revents = 0
ret = C.zmq_poll(sender_pollitem, 1, 0)
assert ret == 0
zmq_msg_again = ffi.new('zmq_msg_t*')
message_again = ffi.new('char[11]', b'Hello Again')
C.zmq_msg_init_data(
zmq_msg_again,
ffi.cast('void*', message_again),
ffi.cast('size_t', 11),
ffi.NULL,
ffi.NULL,
)
assert 11 == C.zmq_msg_send(zmq_msg_again, receiver, 0)
time.sleep(0.2)
assert 0 <= C.zmq_poll(sender_pollitem, 1, 0)
assert int(sender_pollitem.revents) & POLLIN
assert 11 == C.zmq_msg_recv(zmq_msg2, sender, 0)
assert 11 == C.zmq_msg_size(zmq_msg2)
assert (
b"Hello Again"
== ffi.buffer(C.zmq_msg_data(zmq_msg2), int(C.zmq_msg_size(zmq_msg2)))[:]
)
assert 0 == C.zmq_close(sender)
assert 0 == C.zmq_close(receiver)
assert 0 == C.zmq_ctx_destroy(ctx)
assert 0 == C.zmq_msg_close(zmq_msg)
assert 0 == C.zmq_msg_close(zmq_msg2)
assert 0 == C.zmq_msg_close(zmq_msg_again)

View File

@@ -0,0 +1,31 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import pytest
import zmq
import zmq.constants
def test_constants():
assert zmq.POLLIN is zmq.PollEvent.POLLIN
assert zmq.PUSH is zmq.SocketType.PUSH
assert zmq.constants.SUBSCRIBE is zmq.SocketOption.SUBSCRIBE
assert (
zmq.RECONNECT_STOP_AFTER_DISCONNECT
is zmq.constants.ReconnectStop.AFTER_DISCONNECT
)
def test_socket_options():
assert zmq.IDENTITY is zmq.SocketOption.ROUTING_ID
assert zmq.IDENTITY._opt_type is zmq.constants._OptType.bytes
assert zmq.AFFINITY._opt_type is zmq.constants._OptType.int64
assert zmq.CURVE_SERVER._opt_type is zmq.constants._OptType.int
assert zmq.FD._opt_type is zmq.constants._OptType.fd
@pytest.mark.parametrize("event_name", list(zmq.Event.__members__))
def test_event_reprs(event_name):
event = getattr(zmq.Event, event_name)
assert event_name in repr(event)

View File

@@ -0,0 +1,425 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import copy
import gc
import os
import sys
import time
from queue import Queue
from threading import Event, Thread
from unittest import mock
import pytest
from pytest import mark
import zmq
from zmq.tests import PYPY, BaseZMQTestCase, GreenTest, SkipTest
class KwargTestSocket(zmq.Socket):
test_kwarg_value = None
def __init__(self, *args, **kwargs):
self.test_kwarg_value = kwargs.pop('test_kwarg', None)
super().__init__(*args, **kwargs)
class KwargTestContext(zmq.Context):
_socket_class = KwargTestSocket
class TestContext(BaseZMQTestCase):
def test_init(self):
c1 = self.Context()
assert isinstance(c1, self.Context)
c1.term()
c2 = self.Context()
assert isinstance(c2, self.Context)
c2.term()
c3 = self.Context()
assert isinstance(c3, self.Context)
c3.term()
_repr_cls = "zmq.Context"
def test_repr(self):
with self.Context() as ctx:
assert f'{self._repr_cls}()' in repr(ctx)
assert 'closed' not in repr(ctx)
with ctx.socket(zmq.PUSH) as push:
assert f'{self._repr_cls}(1 socket)' in repr(ctx)
with ctx.socket(zmq.PULL) as pull:
assert f'{self._repr_cls}(2 sockets)' in repr(ctx)
assert f'{self._repr_cls}()' in repr(ctx)
assert 'closed' in repr(ctx)
def test_dir(self):
ctx = self.Context()
assert 'socket' in dir(ctx)
if zmq.zmq_version_info() > (3,):
assert 'IO_THREADS' in dir(ctx)
ctx.term()
@mark.skipif(mock is None, reason="requires unittest.mock")
def test_mockable(self):
m = mock.Mock(spec=self.context)
def test_term(self):
c = self.Context()
c.term()
assert c.closed
def test_context_manager(self):
with pytest.warns(ResourceWarning):
with self.Context() as ctx:
s = ctx.socket(zmq.PUSH)
# context exit destroys sockets
assert s.closed
assert ctx.closed
def test_fail_init(self):
self.assertRaisesErrno(zmq.EINVAL, self.Context, -1)
def test_term_hang(self):
rep, req = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
req.setsockopt(zmq.LINGER, 0)
req.send(b'hello', copy=False)
req.close()
rep.close()
self.context.term()
def test_instance(self):
ctx = self.Context.instance()
c2 = self.Context.instance(io_threads=2)
assert c2 is ctx
c2.term()
c3 = self.Context.instance()
c4 = self.Context.instance()
assert not c3 is c2
assert not c3.closed
assert c3 is c4
def test_instance_subclass_first(self):
self.context.term()
class SubContext(zmq.Context):
pass
sctx = SubContext.instance()
ctx = zmq.Context.instance()
ctx.term()
sctx.term()
assert type(ctx) is zmq.Context
assert type(sctx) is SubContext
def test_instance_subclass_second(self):
self.context.term()
class SubContextInherit(zmq.Context):
pass
class SubContextNoInherit(zmq.Context):
_instance = None
ctx = zmq.Context.instance()
sctx = SubContextInherit.instance()
sctx2 = SubContextNoInherit.instance()
ctx.term()
sctx.term()
sctx2.term()
assert type(ctx) is zmq.Context
assert type(sctx) is zmq.Context
assert type(sctx2) is SubContextNoInherit
def test_instance_threadsafe(self):
self.context.term() # clear default context
q = Queue()
# slow context initialization,
# to ensure that we are both trying to create one at the same time
class SlowContext(self.Context):
def __init__(self, *a, **kw):
time.sleep(1)
super().__init__(*a, **kw)
def f():
q.put(SlowContext.instance())
# call ctx.instance() in several threads at once
N = 16
threads = [Thread(target=f) for i in range(N)]
[t.start() for t in threads]
# also call it in the main thread (not first)
ctx = SlowContext.instance()
assert isinstance(ctx, SlowContext)
# check that all the threads got the same context
for i in range(N):
thread_ctx = q.get(timeout=5)
assert thread_ctx is ctx
# cleanup
ctx.term()
[t.join(timeout=5) for t in threads]
def test_socket_passes_kwargs(self):
test_kwarg_value = 'testing one two three'
with KwargTestContext() as ctx:
with ctx.socket(zmq.DEALER, test_kwarg=test_kwarg_value) as socket:
assert socket.test_kwarg_value is test_kwarg_value
def test_socket_class_arg(self):
class CustomSocket(zmq.Socket):
pass
with self.Context() as ctx:
with ctx.socket(zmq.PUSH, socket_class=CustomSocket) as s:
assert isinstance(s, CustomSocket)
def test_many_sockets(self):
"""opening and closing many sockets shouldn't cause problems"""
ctx = self.Context()
for i in range(16):
sockets = [ctx.socket(zmq.REP) for i in range(65)]
[s.close() for s in sockets]
# give the reaper a chance
time.sleep(1e-2)
ctx.term()
def test_sockopts(self):
"""setting socket options with ctx attributes"""
ctx = self.Context()
ctx.linger = 5
assert ctx.linger == 5
s = ctx.socket(zmq.REQ)
assert s.linger == 5
assert s.getsockopt(zmq.LINGER) == 5
s.close()
# check that subscribe doesn't get set on sockets that don't subscribe:
ctx.subscribe = b''
s = ctx.socket(zmq.REQ)
s.close()
ctx.term()
@mark.skipif(sys.platform.startswith('win'), reason='Segfaults on Windows')
def test_destroy(self):
"""Context.destroy should close sockets"""
ctx = self.Context()
sockets = [ctx.socket(zmq.REP) for i in range(65)]
# close half of the sockets
[s.close() for s in sockets[::2]]
ctx.destroy()
# reaper is not instantaneous
time.sleep(1e-2)
for s in sockets:
assert s.closed
def test_destroy_linger(self):
"""Context.destroy should set linger on closing sockets"""
req, rep = self.create_bound_pair(zmq.REQ, zmq.REP)
req.send(b'hi')
time.sleep(1e-2)
self.context.destroy(linger=0)
# reaper is not instantaneous
time.sleep(1e-2)
for s in (req, rep):
assert s.closed
def test_term_noclose(self):
"""Context.term won't close sockets"""
ctx = self.Context()
s = ctx.socket(zmq.REQ)
assert not s.closed
t = Thread(target=ctx.term)
t.start()
t.join(timeout=0.1)
assert t.is_alive(), "Context should be waiting"
s.close()
t.join(timeout=0.1)
assert not t.is_alive(), "Context should have closed"
def test_gc(self):
"""test close&term by garbage collection alone"""
if PYPY:
raise SkipTest("GC doesn't work ")
# test credit @dln (GH #137):
def gcf():
def inner():
ctx = self.Context()
ctx.socket(zmq.PUSH)
# can't seem to catch these with pytest.warns(ResourceWarning)
inner()
gc.collect()
t = Thread(target=gcf)
t.start()
t.join(timeout=1)
assert not t.is_alive(), "Garbage collection should have cleaned up context"
def test_cyclic_destroy(self):
"""ctx.destroy should succeed when cyclic ref prevents gc"""
# test credit @dln (GH #137):
class CyclicReference:
def __init__(self, parent=None):
self.parent = parent
def crash(self, sock):
self.sock = sock
self.child = CyclicReference(self)
def crash_zmq():
ctx = self.Context()
sock = ctx.socket(zmq.PULL)
c = CyclicReference()
c.crash(sock)
ctx.destroy()
crash_zmq()
def test_term_thread(self):
"""ctx.term should not crash active threads (#139)"""
ctx = self.Context()
evt = Event()
evt.clear()
def block():
s = ctx.socket(zmq.REP)
s.bind_to_random_port('tcp://127.0.0.1')
evt.set()
try:
s.recv()
except zmq.ZMQError as e:
assert e.errno == zmq.ETERM
return
finally:
s.close()
self.fail("recv should have been interrupted with ETERM")
t = Thread(target=block)
t.start()
evt.wait(1)
assert evt.is_set(), "sync event never fired"
time.sleep(0.01)
ctx.term()
t.join(timeout=1)
assert not t.is_alive(), "term should have interrupted s.recv()"
def test_destroy_no_sockets(self):
ctx = self.Context()
s = ctx.socket(zmq.PUB)
s.bind_to_random_port('tcp://127.0.0.1')
s.close()
ctx.destroy()
assert s.closed
assert ctx.closed
def test_ctx_opts(self):
if zmq.zmq_version_info() < (3,):
raise SkipTest("context options require libzmq 3")
ctx = self.Context()
ctx.set(zmq.MAX_SOCKETS, 2)
assert ctx.get(zmq.MAX_SOCKETS) == 2
ctx.max_sockets = 100
assert ctx.max_sockets == 100
assert ctx.get(zmq.MAX_SOCKETS) == 100
def test_copy(self):
c1 = self.Context()
c2 = copy.copy(c1)
c2b = copy.deepcopy(c1)
c3 = copy.deepcopy(c2)
assert c2._shadow
assert c3._shadow
assert c1.underlying == c2.underlying
assert c1.underlying == c3.underlying
assert c1.underlying == c2b.underlying
s = c3.socket(zmq.PUB)
s.close()
c1.term()
def test_shadow(self):
ctx = self.Context()
ctx2 = self.Context.shadow(ctx.underlying)
assert ctx.underlying == ctx2.underlying
s = ctx.socket(zmq.PUB)
s.close()
del ctx2
assert not ctx.closed
s = ctx.socket(zmq.PUB)
ctx2 = self.Context.shadow(ctx)
with ctx2.socket(zmq.PUB) as s2:
pass
assert s2.closed
assert not s.closed
s.close()
ctx3 = self.Context(ctx)
assert ctx3.underlying == ctx.underlying
del ctx3
assert not ctx.closed
ctx.term()
self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB)
del ctx2
def test_shadow_pyczmq(self):
try:
from pyczmq import zctx, zsocket, zstr
except Exception:
raise SkipTest("Requires pyczmq")
ctx = zctx.new()
a = zsocket.new(ctx, zmq.PUSH)
zsocket.bind(a, "inproc://a")
ctx2 = self.Context.shadow_pyczmq(ctx)
b = ctx2.socket(zmq.PULL)
b.connect("inproc://a")
zstr.send(a, b'hi')
rcvd = self.recv(b)
assert rcvd == b'hi'
b.close()
@mark.skipif(sys.platform.startswith('win'), reason='No fork on Windows')
def test_fork_instance(self):
ctx = self.Context.instance()
parent_ctx_id = id(ctx)
r_fd, w_fd = os.pipe()
reader = os.fdopen(r_fd, 'r')
child_pid = os.fork()
if child_pid == 0:
ctx = self.Context.instance()
writer = os.fdopen(w_fd, 'w')
child_ctx_id = id(ctx)
ctx.term()
writer.write(str(child_ctx_id) + "\n")
writer.flush()
writer.close()
os._exit(0)
else:
os.close(w_fd)
child_id_s = reader.readline()
reader.close()
assert child_id_s
assert int(child_id_s) != parent_ctx_id
ctx.term()
if False: # disable green context tests
class TestContextGreen(GreenTest, TestContext):
"""gevent subclass of context tests"""
# skip tests that use real threads:
test_gc = GreenTest.skip_green
test_term_thread = GreenTest.skip_green
test_destroy_linger = GreenTest.skip_green
_repr_cls = "zmq.green.Context"

View File

@@ -0,0 +1,51 @@
import os
import sys
import pytest
import zmq
pyximport = pytest.importorskip("pyximport")
HERE = os.path.dirname(__file__)
cython_ext = os.path.join(HERE, "cython_ext.pyx")
@pytest.mark.skipif(
not os.path.exists(cython_ext),
reason=f"Requires cython test file {cython_ext}",
)
@pytest.mark.skipif(
'zmq.backend.cython' not in sys.modules, reason="Requires cython backend"
)
@pytest.mark.skipif(
sys.platform.startswith('win'), reason="Don't try runtime Cython on Windows"
)
@pytest.mark.parametrize('language_level', [3, 2])
def test_cython(language_level, request, tmpdir):
assert 'zmq.tests.cython_ext' not in sys.modules
importers = pyximport.install(
setup_args=dict(include_dirs=zmq.get_includes()),
language_level=language_level,
build_dir=str(tmpdir),
)
cython_ext = None
def unimport():
pyximport.uninstall(*importers)
sys.modules.pop('zmq.tests.cython_ext', None)
request.addfinalizer(unimport)
# this import tests the compilation
from . import cython_ext
assert hasattr(cython_ext, 'send_recv_test')
# call the compiled function
# this shouldn't do much
msg = b'my msg'
received = cython_ext.send_recv_test(msg)
assert received == msg

View File

@@ -0,0 +1,396 @@
import threading
from pytest import fixture, raises
import zmq
from zmq.decorators import context, socket
from zmq.tests import BaseZMQTestCase, term_context
##############################################
# Test cases for @context
##############################################
@fixture(autouse=True)
def term_context_instance(request):
request.addfinalizer(lambda: term_context(zmq.Context.instance(), timeout=10))
def test_ctx():
@context()
def test(ctx):
assert isinstance(ctx, zmq.Context), ctx
test()
def test_ctx_orig_args():
@context()
def f(foo, bar, ctx, baz=None):
assert isinstance(ctx, zmq.Context), ctx
assert foo == 42
assert bar is True
assert baz == 'mock'
f(42, True, baz='mock')
def test_ctx_arg_naming():
@context('myctx')
def test(myctx):
assert isinstance(myctx, zmq.Context), myctx
test()
def test_ctx_args():
@context('ctx', 5)
def test(ctx):
assert isinstance(ctx, zmq.Context), ctx
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
test()
def test_ctx_arg_kwarg():
@context('ctx', io_threads=5)
def test(ctx):
assert isinstance(ctx, zmq.Context), ctx
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
test()
def test_ctx_kw_naming():
@context(name='myctx')
def test(myctx):
assert isinstance(myctx, zmq.Context), myctx
test()
def test_ctx_kwargs():
@context(name='ctx', io_threads=5)
def test(ctx):
assert isinstance(ctx, zmq.Context), ctx
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
test()
def test_ctx_kwargs_default():
@context(name='ctx', io_threads=5)
def test(ctx=None):
assert isinstance(ctx, zmq.Context), ctx
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
test()
def test_ctx_keyword_miss():
@context(name='ctx')
def test(other_name):
pass # the keyword ``ctx`` not found
with raises(TypeError):
test()
def test_ctx_multi_assign():
@context(name='ctx')
def test(ctx):
pass # explosion
with raises(TypeError):
test('mock')
def test_ctx_reinit():
result = {'foo': None, 'bar': None}
@context()
def f(key, ctx):
assert isinstance(ctx, zmq.Context), ctx
result[key] = ctx
foo_t = threading.Thread(target=f, args=('foo',))
bar_t = threading.Thread(target=f, args=('bar',))
foo_t.start()
bar_t.start()
foo_t.join()
bar_t.join()
assert result['foo'] is not None, result
assert result['bar'] is not None, result
assert result['foo'] is not result['bar'], result
def test_ctx_multi_thread():
@context()
@context()
def f(foo, bar):
assert isinstance(foo, zmq.Context), foo
assert isinstance(bar, zmq.Context), bar
assert len(set(map(id, [foo, bar]))) == 2, set(map(id, [foo, bar]))
threads = [threading.Thread(target=f) for i in range(8)]
[t.start() for t in threads]
[t.join() for t in threads]
##############################################
# Test cases for @socket
##############################################
def test_ctx_skt():
@context()
@socket(zmq.PUB)
def test(ctx, skt):
assert isinstance(ctx, zmq.Context), ctx
assert isinstance(skt, zmq.Socket), skt
assert skt.type == zmq.PUB
test()
def test_skt_name():
@context()
@socket('myskt', zmq.PUB)
def test(ctx, myskt):
assert isinstance(myskt, zmq.Socket), myskt
assert isinstance(ctx, zmq.Context), ctx
assert myskt.type == zmq.PUB
test()
def test_skt_kwarg():
@context()
@socket(zmq.PUB, name='myskt')
def test(ctx, myskt):
assert isinstance(myskt, zmq.Socket), myskt
assert isinstance(ctx, zmq.Context), ctx
assert myskt.type == zmq.PUB
test()
def test_ctx_skt_name():
@context('ctx')
@socket('skt', zmq.PUB, context_name='ctx')
def test(ctx, skt):
assert isinstance(skt, zmq.Socket), skt
assert isinstance(ctx, zmq.Context), ctx
assert skt.type == zmq.PUB
test()
def test_skt_default_ctx():
@socket(zmq.PUB)
def test(skt):
assert isinstance(skt, zmq.Socket), skt
assert skt.context is zmq.Context.instance()
assert skt.type == zmq.PUB
test()
def test_skt_reinit():
result = {'foo': None, 'bar': None}
@socket(zmq.PUB)
def f(key, skt):
assert isinstance(skt, zmq.Socket), skt
result[key] = skt
foo_t = threading.Thread(target=f, args=('foo',))
bar_t = threading.Thread(target=f, args=('bar',))
foo_t.start()
bar_t.start()
foo_t.join()
bar_t.join()
assert result['foo'] is not None, result
assert result['bar'] is not None, result
assert result['foo'] is not result['bar'], result
def test_ctx_skt_reinit():
result = {'foo': {'ctx': None, 'skt': None}, 'bar': {'ctx': None, 'skt': None}}
@context()
@socket(zmq.PUB)
def f(key, ctx, skt):
assert isinstance(ctx, zmq.Context), ctx
assert isinstance(skt, zmq.Socket), skt
result[key]['ctx'] = ctx
result[key]['skt'] = skt
foo_t = threading.Thread(target=f, args=('foo',))
bar_t = threading.Thread(target=f, args=('bar',))
foo_t.start()
bar_t.start()
foo_t.join()
bar_t.join()
assert result['foo']['ctx'] is not None, result
assert result['foo']['skt'] is not None, result
assert result['bar']['ctx'] is not None, result
assert result['bar']['skt'] is not None, result
assert result['foo']['ctx'] is not result['bar']['ctx'], result
assert result['foo']['skt'] is not result['bar']['skt'], result
def test_skt_type_miss():
@context()
@socket('myskt')
def f(ctx, myskt):
pass # the socket type is missing
with raises(TypeError):
f()
def test_multi_skts():
@socket(zmq.PUB)
@socket(zmq.SUB)
@socket(zmq.PUSH)
def test(pub, sub, push):
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert isinstance(push, zmq.Socket), push
assert pub.context is zmq.Context.instance()
assert sub.context is zmq.Context.instance()
assert push.context is zmq.Context.instance()
assert pub.type == zmq.PUB
assert sub.type == zmq.SUB
assert push.type == zmq.PUSH
test()
def test_multi_skts_single_ctx():
@context()
@socket(zmq.PUB)
@socket(zmq.SUB)
@socket(zmq.PUSH)
def test(ctx, pub, sub, push):
assert isinstance(ctx, zmq.Context), ctx
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert isinstance(push, zmq.Socket), push
assert pub.context is ctx
assert sub.context is ctx
assert push.context is ctx
assert pub.type == zmq.PUB
assert sub.type == zmq.SUB
assert push.type == zmq.PUSH
test()
def test_multi_skts_with_name():
@socket('foo', zmq.PUSH)
@socket('bar', zmq.SUB)
@socket('baz', zmq.PUB)
def test(foo, bar, baz):
assert isinstance(foo, zmq.Socket), foo
assert isinstance(bar, zmq.Socket), bar
assert isinstance(baz, zmq.Socket), baz
assert foo.context is zmq.Context.instance()
assert bar.context is zmq.Context.instance()
assert baz.context is zmq.Context.instance()
assert foo.type == zmq.PUSH
assert bar.type == zmq.SUB
assert baz.type == zmq.PUB
test()
def test_func_return():
@context()
def f(ctx):
assert isinstance(ctx, zmq.Context), ctx
return 'something'
assert f() == 'something'
def test_skt_multi_thread():
@socket(zmq.PUB)
@socket(zmq.SUB)
@socket(zmq.PUSH)
def f(pub, sub, push):
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert isinstance(push, zmq.Socket), push
assert pub.context is zmq.Context.instance()
assert sub.context is zmq.Context.instance()
assert push.context is zmq.Context.instance()
assert pub.type == zmq.PUB
assert sub.type == zmq.SUB
assert push.type == zmq.PUSH
assert len(set(map(id, [pub, sub, push]))) == 3
threads = [threading.Thread(target=f) for i in range(8)]
[t.start() for t in threads]
[t.join() for t in threads]
class TestMethodDecorators(BaseZMQTestCase):
@context()
@socket(zmq.PUB)
@socket(zmq.SUB)
def multi_skts_method(self, ctx, pub, sub, foo='bar'):
assert isinstance(self, TestMethodDecorators), self
assert isinstance(ctx, zmq.Context), ctx
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert foo == 'bar'
assert pub.context is ctx
assert sub.context is ctx
assert pub.type == zmq.PUB
assert sub.type == zmq.SUB
def test_multi_skts_method(self):
self.multi_skts_method()
def test_multi_skts_method_other_args(self):
@socket(zmq.PUB)
@socket(zmq.SUB)
def f(foo, pub, sub, bar=None):
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert foo == 'mock'
assert bar == 'fake'
assert pub.context is zmq.Context.instance()
assert sub.context is zmq.Context.instance()
assert pub.type == zmq.PUB
assert sub.type == zmq.SUB
f('mock', bar='fake')

View File

@@ -0,0 +1,168 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import time
import zmq
from zmq import devices
from zmq.tests import PYPY, BaseZMQTestCase, GreenTest, SkipTest, have_gevent
if PYPY:
# cleanup of shared Context doesn't work on PyPy
devices.Device.context_factory = zmq.Context
class TestDevice(BaseZMQTestCase):
def test_device_types(self):
for devtype in (zmq.STREAMER, zmq.FORWARDER, zmq.QUEUE):
dev = devices.Device(devtype, zmq.PAIR, zmq.PAIR)
assert dev.device_type == devtype
del dev
def test_device_attributes(self):
dev = devices.Device(zmq.QUEUE, zmq.SUB, zmq.PUB)
assert dev.in_type == zmq.SUB
assert dev.out_type == zmq.PUB
assert dev.device_type == zmq.QUEUE
assert dev.daemon == True
del dev
def test_single_socket_forwarder_connect(self):
if zmq.zmq_version() in ('4.1.1', '4.0.6'):
raise SkipTest("libzmq-%s broke single-socket devices" % zmq.zmq_version())
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
req = self.context.socket(zmq.REQ)
port = req.bind_to_random_port('tcp://127.0.0.1')
dev.connect_in('tcp://127.0.0.1:%i' % port)
dev.start()
time.sleep(0.25)
msg = b'hello'
req.send(msg)
assert msg == self.recv(req)
del dev
req.close()
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
req = self.context.socket(zmq.REQ)
port = req.bind_to_random_port('tcp://127.0.0.1')
dev.connect_out('tcp://127.0.0.1:%i' % port)
dev.start()
time.sleep(0.25)
msg = b'hello again'
req.send(msg)
assert msg == self.recv(req)
del dev
req.close()
def test_single_socket_forwarder_bind(self):
if zmq.zmq_version() in ('4.1.1', '4.0.6'):
raise SkipTest("libzmq-%s broke single-socket devices" % zmq.zmq_version())
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
port = dev.bind_in_to_random_port('tcp://127.0.0.1')
req = self.context.socket(zmq.REQ)
req.connect('tcp://127.0.0.1:%i' % port)
dev.start()
time.sleep(0.25)
msg = b'hello'
req.send(msg)
assert msg == self.recv(req)
del dev
req.close()
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
port = dev.bind_in_to_random_port('tcp://127.0.0.1')
req = self.context.socket(zmq.REQ)
req.connect('tcp://127.0.0.1:%i' % port)
dev.start()
time.sleep(0.25)
msg = b'hello again'
req.send(msg)
assert msg == self.recv(req)
del dev
req.close()
def test_device_bind_to_random_with_args(self):
dev = devices.ThreadDevice(zmq.PULL, zmq.PUSH, -1)
iface = 'tcp://127.0.0.1'
ports = []
min, max = 5000, 5050
ports.extend(
[
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
]
)
for port in ports:
if port < min or port > max:
self.fail('Unexpected port number: %i' % port)
def test_device_bind_to_random_binderror(self):
dev = devices.ThreadDevice(zmq.PULL, zmq.PUSH, -1)
iface = 'tcp://127.0.0.1'
try:
for i in range(11):
dev.bind_in_to_random_port(iface, min_port=10000, max_port=10010)
except zmq.ZMQBindError as e:
return
else:
self.fail('Should have failed')
def test_proxy(self):
if zmq.zmq_version_info() < (3, 2):
raise SkipTest("Proxies only in libzmq >= 3")
dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH)
iface = 'tcp://127.0.0.1'
port = dev.bind_in_to_random_port(iface)
port2 = dev.bind_out_to_random_port(iface)
port3 = dev.bind_mon_to_random_port(iface)
dev.start()
time.sleep(0.25)
msg = b'hello'
push = self.context.socket(zmq.PUSH)
push.connect("%s:%i" % (iface, port))
pull = self.context.socket(zmq.PULL)
pull.connect("%s:%i" % (iface, port2))
mon = self.context.socket(zmq.PULL)
mon.connect("%s:%i" % (iface, port3))
push.send(msg)
self.sockets.extend([push, pull, mon])
assert msg == self.recv(pull)
assert msg == self.recv(mon)
def test_proxy_bind_to_random_with_args(self):
if zmq.zmq_version_info() < (3, 2):
raise SkipTest("Proxies only in libzmq >= 3")
dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH)
iface = 'tcp://127.0.0.1'
ports = []
min, max = 5000, 5050
ports.extend(
[
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
dev.bind_mon_to_random_port(iface, min_port=min, max_port=max),
]
)
for port in ports:
if port < min or port > max:
self.fail('Unexpected port number: %i' % port)
if have_gevent:
import gevent
import zmq.green
class TestDeviceGreen(GreenTest, BaseZMQTestCase):
def test_green_device(self):
rep = self.context.socket(zmq.REP)
req = self.context.socket(zmq.REQ)
self.sockets.extend([req, rep])
port = rep.bind_to_random_port('tcp://127.0.0.1')
g = gevent.spawn(zmq.green.device, zmq.QUEUE, rep, rep)
req.connect('tcp://127.0.0.1:%i' % port)
req.send(b'hi')
timeout = gevent.Timeout(3)
timeout.start()
receiver = gevent.spawn(req.recv)
assert receiver.get(2) == b'hi'
timeout.cancel()
g.kill(block=True)

View File

@@ -0,0 +1,47 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import time
import pytest
import zmq
from zmq.tests import BaseZMQTestCase
class TestDraftSockets(BaseZMQTestCase):
def setUp(self):
if not zmq.DRAFT_API:
pytest.skip("draft api unavailable")
super().setUp()
def test_client_server(self):
client, server = self.create_bound_pair(zmq.CLIENT, zmq.SERVER)
client.send(b'request')
msg = self.recv(server, copy=False)
assert msg.routing_id is not None
server.send(b'reply', routing_id=msg.routing_id)
reply = self.recv(client)
assert reply == b'reply'
def test_radio_dish(self):
dish, radio = self.create_bound_pair(zmq.DISH, zmq.RADIO)
dish.rcvtimeo = 250
group = 'mygroup'
dish.join(group)
received_count = 0
received = set()
sent = set()
for i in range(10):
msg = str(i).encode('ascii')
sent.add(msg)
radio.send(msg, group=group)
try:
recvd = dish.recv()
except zmq.Again:
time.sleep(0.1)
else:
received.add(recvd)
received_count += 1
# assert that we got *something*
assert len(received.intersection(sent)) >= 5

View File

@@ -0,0 +1,37 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from threading import Thread
import zmq
from zmq import Again, ContextTerminated, ZMQError, strerror
from zmq.tests import BaseZMQTestCase
class TestZMQError(BaseZMQTestCase):
def test_strerror(self):
"""test that strerror gets the right type."""
for i in range(10):
e = strerror(i)
assert isinstance(e, str)
def test_zmqerror(self):
for errno in range(10):
e = ZMQError(errno)
assert e.errno == errno
assert str(e) == strerror(errno)
def test_again(self):
s = self.context.socket(zmq.REP)
self.assertRaises(Again, s.recv, zmq.NOBLOCK)
self.assertRaisesErrno(zmq.EAGAIN, s.recv, zmq.NOBLOCK)
s.close()
def atest_ctxterm(self):
s = self.context.socket(zmq.REP)
t = Thread(target=self.context.term)
t.start()
self.assertRaises(ContextTerminated, s.recv, zmq.NOBLOCK)
self.assertRaisesErrno(zmq.TERM, s.recv, zmq.NOBLOCK)
s.close()
t.join()

View File

@@ -0,0 +1,26 @@
# Copyright (c) PyZMQ Developers.
# Distributed under the terms of the Modified BSD License.
from pytest import mark
import zmq
only_bundled = mark.skipif(not hasattr(zmq, '_libzmq'), reason="bundled libzmq")
@mark.skipif('zmq.zmq_version_info() < (4, 1)')
def test_has():
assert not zmq.has('something weird')
@only_bundled
def test_has_curve():
"""bundled libzmq has curve support"""
assert zmq.has('curve')
@only_bundled
def test_has_ipc():
"""bundled libzmq has ipc support"""
assert zmq.has('ipc')

View File

@@ -0,0 +1,34 @@
"""tests for extending pyzmq"""
import zmq
class CustomSocket(zmq.Socket):
custom_attr: int
def __init__(self, context, socket_type, custom_attr: int = 0):
super().__init__(context, socket_type)
self.custom_attr = custom_attr
class CustomContext(zmq.Context):
extra_arg: str
_socket_class = CustomSocket
def __init__(self, extra_arg: str = 'x'):
super().__init__()
self.extra_arg = extra_arg
def test_custom_context():
ctx = CustomContext('s')
assert isinstance(ctx, CustomContext)
assert ctx.extra_arg == 's'
s = ctx.socket(zmq.PUSH, custom_attr=10)
assert isinstance(s, CustomSocket)
assert s.custom_attr == 10
assert s.context is ctx
assert s.type == zmq.PUSH
s.close()
ctx.term()

View File

@@ -0,0 +1,354 @@
# Copyright (c) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import json
import os
import sys
from datetime import timedelta
import pytest
gen = pytest.importorskip('tornado.gen')
from tornado.ioloop import IOLoop
import zmq
from zmq.eventloop import future
from zmq.tests import BaseZMQTestCase
class TestFutureSocket(BaseZMQTestCase):
Context = future.Context
def setUp(self):
self.loop = IOLoop(make_current=False)
super().setUp()
def tearDown(self):
super().tearDown()
if self.loop:
self.loop.close(all_fds=True)
def test_socket_class(self):
s = self.context.socket(zmq.PUSH)
assert isinstance(s, future.Socket)
s.close()
def test_instance_subclass_first(self):
actx = self.Context.instance()
ctx = zmq.Context.instance()
ctx.term()
actx.term()
assert type(ctx) is zmq.Context
assert type(actx) is self.Context
def test_instance_subclass_second(self):
ctx = zmq.Context.instance()
actx = self.Context.instance()
ctx.term()
actx.term()
assert type(ctx) is zmq.Context
assert type(actx) is self.Context
def test_recv_multipart(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_multipart()
assert not f.done()
await a.send(b"hi")
recvd = await f
assert recvd == [b'hi']
self.loop.run_sync(test)
def test_recv(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv()
assert not f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.done()
assert f1.result() == b'hi'
assert recvd == b'there'
self.loop.run_sync(test)
def test_recv_cancel(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv_multipart()
assert f1.cancel()
assert f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.cancelled()
assert f2.done()
assert recvd == [b'hi', b'there']
self.loop.run_sync(test)
@pytest.mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
def test_recv_timeout(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
b.rcvtimeo = 100
f1 = b.recv()
b.rcvtimeo = 1000
f2 = b.recv_multipart()
with pytest.raises(zmq.Again):
await f1
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f2.done()
assert recvd == [b'hi', b'there']
self.loop.run_sync(test)
@pytest.mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
def test_send_timeout(self):
async def test():
s = self.socket(zmq.PUSH)
s.sndtimeo = 100
with pytest.raises(zmq.Again):
await s.send(b"not going anywhere")
self.loop.run_sync(test)
def test_send_noblock(self):
async def test():
s = self.socket(zmq.PUSH)
with pytest.raises(zmq.Again):
await s.send(b"not going anywhere", flags=zmq.NOBLOCK)
self.loop.run_sync(test)
def test_send_multipart_noblock(self):
async def test():
s = self.socket(zmq.PUSH)
with pytest.raises(zmq.Again):
await s.send_multipart([b"not going anywhere"], flags=zmq.NOBLOCK)
self.loop.run_sync(test)
def test_recv_string(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_string()
assert not f.done()
msg = 'πøøπ'
await a.send_string(msg)
recvd = await f
assert f.done()
assert f.result() == msg
assert recvd == msg
self.loop.run_sync(test)
def test_recv_json(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_json()
assert not f.done()
obj = dict(a=5)
await a.send_json(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj
self.loop.run_sync(test)
def test_recv_json_cancelled(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_json()
assert not f.done()
f.cancel()
# cycle eventloop to allow cancel events to fire
await gen.sleep(0)
obj = dict(a=5)
await a.send_json(obj)
with pytest.raises(future.CancelledError):
recvd = await f
assert f.done()
# give it a chance to incorrectly consume the event
events = await b.poll(timeout=5)
assert events
await gen.sleep(0)
# make sure cancelled recv didn't eat up event
recvd = await gen.with_timeout(timedelta(seconds=5), b.recv_json())
assert recvd == obj
self.loop.run_sync(test)
def test_recv_pyobj(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_pyobj()
assert not f.done()
obj = dict(a=5)
await a.send_pyobj(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj
self.loop.run_sync(test)
def test_custom_serialize(self):
def serialize(msg):
frames = []
frames.extend(msg.get('identities', []))
content = json.dumps(msg['content']).encode('utf8')
frames.append(content)
return frames
def deserialize(frames):
identities = frames[:-1]
content = json.loads(frames[-1].decode('utf8'))
return {
'identities': identities,
'content': content,
}
async def test():
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
msg = {
'content': {
'a': 5,
'b': 'bee',
}
}
await a.send_serialized(msg, serialize)
recvd = await b.recv_serialized(deserialize)
assert recvd['content'] == msg['content']
assert recvd['identities']
# bounce back, tests identities
await b.send_serialized(recvd, serialize)
r2 = await a.recv_serialized(deserialize)
assert r2['content'] == msg['content']
assert not r2['identities']
self.loop.run_sync(test)
def test_custom_serialize_error(self):
async def test():
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
msg = {
'content': {
'a': 5,
'b': 'bee',
}
}
with pytest.raises(TypeError):
await a.send_serialized(json, json.dumps)
await a.send(b"not json")
with pytest.raises(TypeError):
await b.recv_serialized(json.loads)
self.loop.run_sync(test)
def test_poll(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.poll(timeout=0)
assert f.done()
assert f.result() == 0
f = b.poll(timeout=1)
assert not f.done()
evt = await f
assert evt == 0
f = b.poll(timeout=1000)
assert not f.done()
await a.send_multipart([b"hi", b"there"])
evt = await f
assert evt == zmq.POLLIN
recvd = await b.recv_multipart()
assert recvd == [b'hi', b'there']
self.loop.run_sync(test)
@pytest.mark.skipif(
sys.platform.startswith('win'), reason='Windows unsupported socket type'
)
def test_poll_base_socket(self):
async def test():
ctx = zmq.Context()
url = 'inproc://test'
a = ctx.socket(zmq.PUSH)
b = ctx.socket(zmq.PULL)
self.sockets.extend([a, b])
a.bind(url)
b.connect(url)
poller = future.Poller()
poller.register(b, zmq.POLLIN)
f = poller.poll(timeout=1000)
assert not f.done()
a.send_multipart([b'hi', b'there'])
evt = await f
assert evt == [(b, zmq.POLLIN)]
recvd = b.recv_multipart()
assert recvd == [b'hi', b'there']
a.close()
b.close()
ctx.term()
self.loop.run_sync(test)
def test_close_all_fds(self):
s = self.socket(zmq.PUB)
async def attach():
s._get_loop()
self.loop.run_sync(attach)
self.loop.close(all_fds=True)
self.loop = None # avoid second close later
assert s.closed
@pytest.mark.skipif(
sys.platform.startswith('win'),
reason='Windows does not support polling on files',
)
def test_poll_raw(self):
async def test():
p = future.Poller()
# make a pipe
r, w = os.pipe()
r = os.fdopen(r, 'rb')
w = os.fdopen(w, 'wb')
# POLLOUT
p.register(r, zmq.POLLIN)
p.register(w, zmq.POLLOUT)
evts = await p.poll(timeout=1)
evts = dict(evts)
assert r.fileno() not in evts
assert w.fileno() in evts
assert evts[w.fileno()] == zmq.POLLOUT
# POLLIN
p.unregister(w)
w.write(b'x')
w.flush()
evts = await p.poll(timeout=1000)
evts = dict(evts)
assert r.fileno() in evts
assert evts[r.fileno()] == zmq.POLLIN
assert r.read(1) == b'x'
r.close()
w.close()
self.loop.run_sync(test)

View File

@@ -0,0 +1,99 @@
"""
Test Imports - the quickest test to ensure that we haven't
introduced version-incompatible syntax errors.
"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
# flake8: noqa: F401
import pytest
def test_toplevel():
"""test toplevel import"""
import zmq
def test_core():
"""test core imports"""
from zmq import (
Context,
Frame,
Poller,
Socket,
constants,
device,
proxy,
pyzmq_version,
pyzmq_version_info,
zmq_version,
zmq_version_info,
)
def test_devices():
"""test device imports"""
import zmq.devices
from zmq.devices import basedevice, monitoredqueue, monitoredqueuedevice
def test_log():
"""test log imports"""
import zmq.log
from zmq.log import handlers
def test_eventloop():
"""test eventloop imports"""
pytest.importorskip("tornado")
import zmq.eventloop
from zmq.eventloop import ioloop, zmqstream
def test_utils():
"""test util imports"""
import zmq.utils
from zmq.utils import jsonapi, strtypes
def test_ssh():
"""test ssh imports"""
from zmq.ssh import tunnel
def test_decorators():
"""test decorators imports"""
from zmq.decorators import context, socket
def test_zmq_all():
import zmq
for name in zmq.__all__:
assert hasattr(zmq, name)
@pytest.mark.parametrize("pkgname", ["zmq", "zmq.green"])
@pytest.mark.parametrize(
"attr",
[
"RCVTIMEO",
"PUSH",
"zmq_version_info",
"SocketOption",
"device",
"Socket",
"Context",
],
)
def test_all_exports(pkgname, attr):
import zmq
subpkg = pytest.importorskip(pkgname)
for name in zmq.__all__:
assert hasattr(subpkg, name)
assert attr in subpkg.__all__
if attr not in ("Socket", "Context", "device"):
assert getattr(subpkg, attr) is getattr(zmq, attr)

View File

@@ -0,0 +1,33 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from unittest import TestCase
import zmq
class TestIncludes(TestCase):
def test_get_includes(self):
from os.path import basename
includes = zmq.get_includes()
assert isinstance(includes, list)
assert len(includes) >= 2
parent = includes[0]
assert isinstance(parent, str)
utilsdir = includes[1]
assert isinstance(utilsdir, str)
utils = basename(utilsdir)
assert utils == "utils"
def test_get_library_dirs(self):
from os.path import basename
libdirs = zmq.get_library_dirs()
assert isinstance(libdirs, list)
assert len(libdirs) == 1
parent = libdirs[0]
assert isinstance(parent, str)
libdir = basename(parent)
assert libdir == "zmq"

View File

@@ -0,0 +1,33 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import pytest
try:
import tornado.ioloop
except ImportError:
_tornado = False
else:
_tornado = True
def setup():
if not _tornado:
pytest.skip("requires tornado")
def test_ioloop():
# may have been imported before,
# can't capture the warning
from zmq.eventloop import ioloop
assert ioloop.IOLoop is tornado.ioloop.IOLoop
assert ioloop.ZMQIOLoop is ioloop.IOLoop
def test_ioloop_install():
from zmq.eventloop import ioloop
with pytest.warns(DeprecationWarning):
ioloop.install()

View File

@@ -0,0 +1,193 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import logging
import time
import zmq
from zmq.log import handlers
from zmq.tests import BaseZMQTestCase
class TestPubLog(BaseZMQTestCase):
iface = 'inproc://zmqlog'
topic = 'zmq'
@property
def logger(self):
# print dir(self)
logger = logging.getLogger('zmqtest')
logger.setLevel(logging.DEBUG)
return logger
def connect_handler(self, topic=None):
topic = self.topic if topic is None else topic
logger = self.logger
pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
handler = handlers.PUBHandler(pub)
handler.setLevel(logging.DEBUG)
handler.root_topic = topic
logger.addHandler(handler)
sub.setsockopt(zmq.SUBSCRIBE, topic.encode())
time.sleep(0.1)
return logger, handler, sub
def test_init_iface(self):
logger = self.logger
ctx = self.context
handler = handlers.PUBHandler(self.iface)
assert not handler.ctx is ctx
self.sockets.append(handler.socket)
# handler.ctx.term()
handler = handlers.PUBHandler(self.iface, self.context)
self.sockets.append(handler.socket)
assert handler.ctx is ctx
handler.setLevel(logging.DEBUG)
handler.root_topic = self.topic
logger.addHandler(handler)
sub = ctx.socket(zmq.SUB)
self.sockets.append(sub)
sub.setsockopt(zmq.SUBSCRIBE, self.topic.encode())
sub.connect(self.iface)
import time
time.sleep(0.25)
msg1 = 'message'
logger.info(msg1)
(topic, msg2) = sub.recv_multipart()
assert topic == b'zmq.INFO'
assert msg2 == (msg1 + "\n").encode("utf8")
logger.removeHandler(handler)
def test_init_socket(self):
pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
logger = self.logger
handler = handlers.PUBHandler(pub)
handler.setLevel(logging.DEBUG)
handler.root_topic = self.topic
logger.addHandler(handler)
assert handler.socket is pub
assert handler.ctx is pub.context
assert handler.ctx is self.context
sub.setsockopt(zmq.SUBSCRIBE, self.topic.encode())
import time
time.sleep(0.1)
msg1 = 'message'
logger.info(msg1)
(topic, msg2) = sub.recv_multipart()
assert topic == b'zmq.INFO'
assert msg2 == (msg1 + "\n").encode("utf8")
logger.removeHandler(handler)
def test_root_topic(self):
logger, handler, sub = self.connect_handler()
handler.socket.bind(self.iface)
sub2 = sub.context.socket(zmq.SUB)
self.sockets.append(sub2)
sub2.connect(self.iface)
sub2.setsockopt(zmq.SUBSCRIBE, b'')
handler.root_topic = b'twoonly'
msg1 = 'ignored'
logger.info(msg1)
self.assertRaisesErrno(zmq.EAGAIN, sub.recv, zmq.NOBLOCK)
topic, msg2 = sub2.recv_multipart()
assert topic == b'twoonly.INFO'
assert msg2 == (msg1 + '\n').encode()
logger.removeHandler(handler)
def test_blank_root_topic(self):
logger, handler, sub_everything = self.connect_handler()
sub_everything.setsockopt(zmq.SUBSCRIBE, b'')
handler.socket.bind(self.iface)
sub_only_info = sub_everything.context.socket(zmq.SUB)
self.sockets.append(sub_only_info)
sub_only_info.connect(self.iface)
sub_only_info.setsockopt(zmq.SUBSCRIBE, b'INFO')
handler.setRootTopic(b'')
msg_debug = 'debug_message'
logger.debug(msg_debug)
self.assertRaisesErrno(zmq.EAGAIN, sub_only_info.recv, zmq.NOBLOCK)
topic, msg_debug_response = sub_everything.recv_multipart()
assert topic == b'DEBUG'
msg_info = 'info_message'
logger.info(msg_info)
topic, msg_info_response_everything = sub_everything.recv_multipart()
assert topic == b'INFO'
topic, msg_info_response_onlyinfo = sub_only_info.recv_multipart()
assert topic == b'INFO'
assert msg_info_response_everything == msg_info_response_onlyinfo
logger.removeHandler(handler)
def test_unicode_message(self):
logger, handler, sub = self.connect_handler()
base_topic = (self.topic + '.INFO').encode()
for msg, expected in [
('hello', [base_topic, b'hello\n']),
('héllo', [base_topic, 'héllo\n'.encode()]),
('tøpic::héllo', [base_topic + '.tøpic'.encode(), 'héllo\n'.encode()]),
]:
logger.info(msg)
received = sub.recv_multipart()
assert received == expected
logger.removeHandler(handler)
def test_set_info_formatter_via_property(self):
logger, handler, sub = self.connect_handler()
handler.formatters[logging.INFO] = logging.Formatter("%(message)s UNITTEST\n")
handler.socket.bind(self.iface)
sub.setsockopt(zmq.SUBSCRIBE, handler.root_topic.encode())
logger.info('info message')
topic, msg = sub.recv_multipart()
assert msg == b'info message UNITTEST\n'
logger.removeHandler(handler)
def test_custom_global_formatter(self):
logger, handler, sub = self.connect_handler()
formatter = logging.Formatter("UNITTEST %(message)s")
handler.setFormatter(formatter)
handler.socket.bind(self.iface)
sub.setsockopt(zmq.SUBSCRIBE, handler.root_topic.encode())
logger.info('info message')
topic, msg = sub.recv_multipart()
assert msg == b'UNITTEST info message'
logger.debug('debug message')
topic, msg = sub.recv_multipart()
assert msg == b'UNITTEST debug message'
logger.removeHandler(handler)
def test_custom_debug_formatter(self):
logger, handler, sub = self.connect_handler()
formatter = logging.Formatter("UNITTEST DEBUG %(message)s")
handler.setFormatter(formatter, logging.DEBUG)
handler.socket.bind(self.iface)
sub.setsockopt(zmq.SUBSCRIBE, handler.root_topic.encode())
logger.info('info message')
topic, msg = sub.recv_multipart()
assert msg == b'info message\n'
logger.debug('debug message')
topic, msg = sub.recv_multipart()
assert msg == b'UNITTEST DEBUG debug message'
logger.removeHandler(handler)
def test_custom_message_type(self):
class Message:
def __init__(self, msg: str):
self.msg = msg
def __str__(self) -> str:
return self.msg
logger, handler, sub = self.connect_handler()
msg = "hello"
logger.info(Message(msg))
topic, received = sub.recv_multipart()
assert topic == b'zmq.INFO'
assert received == b'hello\n'
logger.removeHandler(handler)

View File

@@ -0,0 +1,370 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import copy
import gc
import sys
try:
from sys import getrefcount
except ImportError:
grc = None
else:
grc = getrefcount
import time
import zmq
from zmq.tests import PYPY, BaseZMQTestCase, SkipTest, skip_pypy
# some useful constants:
x = b'x'
if grc:
rc0 = grc(x)
v = memoryview(x)
view_rc = grc(x) - rc0
def await_gc(obj, rc):
"""wait for refcount on an object to drop to an expected value
Necessary because of the zero-copy gc thread,
which can take some time to receive its DECREF message.
"""
# count refs for this function
if sys.version_info < (3, 11):
my_refs = 2
else:
my_refs = 1
for i in range(50):
# rc + 2 because of the refs in this function
if grc(obj) <= rc + my_refs:
return
time.sleep(0.05)
class TestFrame(BaseZMQTestCase):
def tearDown(self):
super().tearDown()
for i in range(3):
gc.collect()
@skip_pypy
def test_above_30(self):
"""Message above 30 bytes are never copied by 0MQ."""
for i in range(5, 16): # 32, 64,..., 65536
s = (2**i) * x
rc = grc(s)
m = zmq.Frame(s, copy=False)
assert grc(s) == rc + 2
del m
await_gc(s, rc)
assert grc(s) == rc
del s
def test_str(self):
"""Test the str representations of the Frames."""
for i in range(16):
s = (2**i) * x
m = zmq.Frame(s)
m_str = str(m)
m_str_b = m_str.encode()
assert s == m_str_b
def test_bytes(self):
"""Test the Frame.bytes property."""
for i in range(1, 16):
s = (2**i) * x
m = zmq.Frame(s)
b = m.bytes
assert s == m.bytes
if not PYPY:
# check that it copies
assert b is not s
# check that it copies only once
assert b is m.bytes
def test_unicode(self):
"""Test the unicode representations of the Frames."""
s = 'asdf'
self.assertRaises(TypeError, zmq.Frame, s)
for i in range(16):
s = (2**i) * '§'
m = zmq.Frame(s.encode('utf8'))
assert s == m.bytes.decode('utf8')
def test_len(self):
"""Test the len of the Frames."""
for i in range(16):
s = (2**i) * x
m = zmq.Frame(s)
assert len(s) == len(m)
@skip_pypy
def test_lifecycle1(self):
"""Run through a ref counting cycle with a copy."""
for i in range(5, 16): # 32, 64,..., 65536
s = (2**i) * x
rc = rc_0 = grc(s)
m = zmq.Frame(s, copy=False)
rc += 2
assert grc(s) == rc
m2 = copy.copy(m)
rc += 1
assert grc(s) == rc
# no increase in refcount for accessing buffer
# which references m2 directly
buf = m2.buffer
assert grc(s) == rc
assert s == str(m).encode()
assert s == bytes(m2)
assert s == m.bytes
assert s == bytes(buf)
# assert s is str(m)
# assert s is str(m2)
del m2
assert grc(s) == rc
# buf holds direct reference to m2 which holds
del buf
rc -= 1
assert grc(s) == rc
del m
rc -= 2
await_gc(s, rc)
assert grc(s) == rc
assert rc == rc_0
del s
@skip_pypy
def test_lifecycle2(self):
"""Run through a different ref counting cycle with a copy."""
for i in range(5, 16): # 32, 64,..., 65536
s = (2**i) * x
rc = rc_0 = grc(s)
m = zmq.Frame(s, copy=False)
rc += 2
assert grc(s) == rc
m2 = copy.copy(m)
rc += 1
assert grc(s) == rc
# no increase in refcount for accessing buffer
# which references m directly
buf = m.buffer
assert grc(s) == rc
assert s == str(m).encode()
assert s == bytes(m2)
assert s == m2.bytes
assert s == m.bytes
assert s == bytes(buf)
# assert s is str(m)
# assert s is str(m2)
del buf
assert grc(s) == rc
del m
rc -= 1
assert grc(s) == rc
del m2
rc -= 2
await_gc(s, rc)
assert grc(s) == rc
assert rc == rc_0
del s
def test_tracker(self):
m = zmq.Frame(b'asdf', copy=False, track=True)
assert not m.tracker.done
pm = zmq.MessageTracker(m)
assert not pm.done
del m
for i in range(3):
gc.collect()
for i in range(10):
if pm.done:
break
time.sleep(0.1)
assert pm.done
def test_no_tracker(self):
m = zmq.Frame(b'asdf', track=False)
assert m.tracker == None
m2 = copy.copy(m)
assert m2.tracker == None
self.assertRaises(ValueError, zmq.MessageTracker, m)
def test_multi_tracker(self):
m = zmq.Frame(b'asdf', copy=False, track=True)
m2 = zmq.Frame(b'whoda', copy=False, track=True)
mt = zmq.MessageTracker(m, m2)
assert not m.tracker.done
assert not mt.done
self.assertRaises(zmq.NotDone, mt.wait, 0.1)
del m
for i in range(3):
gc.collect()
self.assertRaises(zmq.NotDone, mt.wait, 0.1)
assert not mt.done
del m2
for i in range(3):
gc.collect()
assert mt.wait(0.1) is None
assert mt.done
def test_buffer_in(self):
"""test using a buffer as input"""
ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
zmq.Frame(memoryview(ins))
def test_bad_buffer_in(self):
"""test using a bad object"""
self.assertRaises(TypeError, zmq.Frame, 5)
self.assertRaises(TypeError, zmq.Frame, object())
def test_buffer_out(self):
"""receiving buffered output"""
ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
m = zmq.Frame(ins)
outb = m.buffer
assert isinstance(outb, memoryview)
assert outb is m.buffer
assert m.buffer is m.buffer
def test_memoryview_shape(self):
"""memoryview shape info"""
data = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
n = len(data)
f = zmq.Frame(data)
view1 = f.buffer
assert view1.ndim == 1
assert view1.shape == (n,)
assert view1.tobytes() == data
view2 = memoryview(f)
assert view2.ndim == 1
assert view2.shape == (n,)
assert view2.tobytes() == data
def test_multisend(self):
"""ensure that a message remains intact after multiple sends"""
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
s = b"message"
m = zmq.Frame(s)
assert s == m.bytes
a.send(m, copy=False)
time.sleep(0.1)
assert s == m.bytes
a.send(m, copy=False)
time.sleep(0.1)
assert s == m.bytes
a.send(m, copy=True)
time.sleep(0.1)
assert s == m.bytes
a.send(m, copy=True)
time.sleep(0.1)
assert s == m.bytes
for i in range(4):
r = b.recv()
assert s == r
assert s == m.bytes
def test_memoryview(self):
"""test messages from memoryview"""
s = b'carrotjuice'
memoryview(s)
m = zmq.Frame(s)
buf = m.buffer
s2 = buf.tobytes()
assert s2 == s
assert m.bytes == s
def test_noncopying_recv(self):
"""check for clobbering message buffers"""
null = b'\0' * 64
sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
for i in range(32):
# try a few times
sb.send(null, copy=False)
m = sa.recv(copy=False)
mb = m.bytes
# buf = memoryview(m)
buf = m.buffer
del m
for i in range(5):
ff = b'\xff' * (40 + i * 10)
sb.send(ff, copy=False)
m2 = sa.recv(copy=False)
b = buf.tobytes()
assert b == null
assert mb == null
assert m2.bytes == ff
assert type(m2.bytes) is bytes
def test_noncopying_memoryview(self):
"""test non-copying memmoryview messages"""
null = b'\0' * 64
sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
for i in range(32):
# try a few times
sb.send(memoryview(null), copy=False)
m = sa.recv(copy=False)
buf = memoryview(m)
for i in range(5):
ff = b'\xff' * (40 + i * 10)
sb.send(memoryview(ff), copy=False)
m2 = sa.recv(copy=False)
buf2 = memoryview(m2)
assert buf.tobytes() == null
assert not buf.readonly
assert buf2.tobytes() == ff
assert not buf2.readonly
assert type(buf) is memoryview
def test_buffer_numpy(self):
"""test non-copying numpy array messages"""
try:
import numpy
from numpy.testing import assert_array_equal
except ImportError:
raise SkipTest("requires numpy")
rand = numpy.random.randint
shapes = [rand(2, 5) for i in range(5)]
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
dtypes = [int, float, '>i4', 'B']
for i in range(1, len(shapes) + 1):
shape = shapes[:i]
for dt in dtypes:
A = numpy.empty(shape, dtype=dt)
a.send(A, copy=False)
msg = b.recv(copy=False)
B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
assert_array_equal(A, B)
A = numpy.empty(shape, dtype=[('a', int), ('b', float), ('c', 'a32')])
A['a'] = 1024
A['b'] = 1e9
A['c'] = 'hello there'
a.send(A, copy=False)
msg = b.recv(copy=False)
B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
assert_array_equal(A, B)
@skip_pypy
def test_frame_more(self):
"""test Frame.more attribute"""
frame = zmq.Frame(b"hello")
assert not frame.more
sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
sa.send_multipart([b'hi', b'there'])
frame = self.recv(sb, copy=False)
assert frame.more
if zmq.zmq_version_info()[0] >= 3 and not PYPY:
assert frame.get(zmq.MORE)
frame = self.recv(sb, copy=False)
assert not frame.more
if zmq.zmq_version_info()[0] >= 3 and not PYPY:
assert not frame.get(zmq.MORE)

View File

@@ -0,0 +1,94 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
import zmq.asyncio
from zmq.tests import require_zmq_4
from zmq.utils.monitor import recv_monitor_message
pytestmark = require_zmq_4
import pytest
@pytest.fixture(params=["zmq", "asyncio"])
def Context(request, event_loop):
if request.param == "asyncio":
return zmq.asyncio.Context
else:
return zmq.Context
async def test_monitor(context, socket):
"""Test monitoring interface for sockets."""
s_rep = socket(zmq.REP)
s_req = socket(zmq.REQ)
s_req.bind("tcp://127.0.0.1:6666")
# try monitoring the REP socket
s_rep.monitor(
"inproc://monitor.rep",
zmq.EVENT_CONNECT_DELAYED | zmq.EVENT_CONNECTED | zmq.EVENT_MONITOR_STOPPED,
)
# create listening socket for monitor
s_event = socket(zmq.PAIR)
s_event.connect("inproc://monitor.rep")
s_event.linger = 0
# test receive event for connect event
s_rep.connect("tcp://127.0.0.1:6666")
m = recv_monitor_message(s_event)
if isinstance(context, zmq.asyncio.Context):
m = await m
if m['event'] == zmq.EVENT_CONNECT_DELAYED:
assert m['endpoint'] == b"tcp://127.0.0.1:6666"
# test receive event for connected event
m = recv_monitor_message(s_event)
if isinstance(context, zmq.asyncio.Context):
m = await m
assert m['event'] == zmq.EVENT_CONNECTED
assert m['endpoint'] == b"tcp://127.0.0.1:6666"
# test monitor can be disabled.
s_rep.disable_monitor()
m = recv_monitor_message(s_event)
if isinstance(context, zmq.asyncio.Context):
m = await m
assert m['event'] == zmq.EVENT_MONITOR_STOPPED
async def test_monitor_repeat(context, socket, sockets):
s = socket(zmq.PULL)
m = s.get_monitor_socket()
sockets.append(m)
m2 = s.get_monitor_socket()
assert m is m2
s.disable_monitor()
evt = recv_monitor_message(m)
if isinstance(context, zmq.asyncio.Context):
evt = await evt
assert evt['event'] == zmq.EVENT_MONITOR_STOPPED
m.close()
s.close()
async def test_monitor_connected(context, socket, sockets):
"""Test connected monitoring socket."""
s_rep = socket(zmq.REP)
s_req = socket(zmq.REQ)
s_req.bind("tcp://127.0.0.1:6667")
# try monitoring the REP socket
# create listening socket for monitor
s_event = s_rep.get_monitor_socket()
s_event.linger = 0
sockets.append(s_event)
# test receive event for connect event
s_rep.connect("tcp://127.0.0.1:6667")
m = recv_monitor_message(s_event)
if isinstance(context, zmq.asyncio.Context):
m = await m
if m['event'] == zmq.EVENT_CONNECT_DELAYED:
assert m['endpoint'] == b"tcp://127.0.0.1:6667"
# test receive event for connected event
m = recv_monitor_message(s_event)
if isinstance(context, zmq.asyncio.Context):
m = await m
assert m['event'] == zmq.EVENT_CONNECTED
assert m['endpoint'] == b"tcp://127.0.0.1:6667"

View File

@@ -0,0 +1,235 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import threading
import time
import zmq
from zmq import devices
from zmq.tests import PYPY, BaseZMQTestCase
if PYPY or zmq.zmq_version_info() >= (4, 1):
# cleanup of shared Context doesn't work on PyPy
# there also seems to be a bug in cleanup in libzmq-4.1 (zeromq/libzmq#1052)
devices.Device.context_factory = zmq.Context
class TestMonitoredQueue(BaseZMQTestCase):
def build_device(self, mon_sub=b"", in_prefix=b'in', out_prefix=b'out'):
self.device = devices.ThreadMonitoredQueue(
zmq.PAIR, zmq.PAIR, zmq.PUB, in_prefix, out_prefix
)
alice = self.context.socket(zmq.PAIR)
bob = self.context.socket(zmq.PAIR)
mon = self.context.socket(zmq.SUB)
aport = alice.bind_to_random_port('tcp://127.0.0.1')
bport = bob.bind_to_random_port('tcp://127.0.0.1')
mport = mon.bind_to_random_port('tcp://127.0.0.1')
mon.setsockopt(zmq.SUBSCRIBE, mon_sub)
self.device.connect_in("tcp://127.0.0.1:%i" % aport)
self.device.connect_out("tcp://127.0.0.1:%i" % bport)
self.device.connect_mon("tcp://127.0.0.1:%i" % mport)
self.device.start()
time.sleep(0.2)
try:
# this is currently necessary to ensure no dropped monitor messages
# see LIBZMQ-248 for more info
mon.recv_multipart(zmq.NOBLOCK)
except zmq.ZMQError:
pass
self.sockets.extend([alice, bob, mon])
return alice, bob, mon
def teardown_device(self):
# spawn term in a background thread
for i in range(50):
# wait for device._context to be populated
context = getattr(self.device, "_context", None)
if context is not None:
break
time.sleep(0.1)
if context is not None:
t = threading.Thread(target=self.device._context.term, daemon=True)
t.start()
for socket in self.sockets:
socket.close()
if context is not None:
t.join(timeout=5)
self.device.join(timeout=5)
def test_reply(self):
alice, bob, mon = self.build_device()
alices = b"hello bob".split()
alice.send_multipart(alices)
bobs = self.recv_multipart(bob)
assert alices == bobs
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
assert alices == bobs
self.teardown_device()
def test_queue(self):
alice, bob, mon = self.build_device()
alices = b"hello bob".split()
alice.send_multipart(alices)
alices2 = b"hello again".split()
alice.send_multipart(alices2)
alices3 = b"hello again and again".split()
alice.send_multipart(alices3)
bobs = self.recv_multipart(bob)
assert alices == bobs
bobs = self.recv_multipart(bob)
assert alices2 == bobs
bobs = self.recv_multipart(bob)
assert alices3 == bobs
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
assert alices == bobs
self.teardown_device()
def test_monitor(self):
alice, bob, mon = self.build_device()
alices = b"hello bob".split()
alice.send_multipart(alices)
alices2 = b"hello again".split()
alice.send_multipart(alices2)
alices3 = b"hello again and again".split()
alice.send_multipart(alices3)
bobs = self.recv_multipart(bob)
assert alices == bobs
mons = self.recv_multipart(mon)
assert [b'in'] + bobs == mons
bobs = self.recv_multipart(bob)
assert alices2 == bobs
bobs = self.recv_multipart(bob)
assert alices3 == bobs
mons = self.recv_multipart(mon)
assert [b'in'] + alices2 == mons
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
assert alices == bobs
mons = self.recv_multipart(mon)
assert [b'in'] + alices3 == mons
mons = self.recv_multipart(mon)
assert [b'out'] + bobs == mons
self.teardown_device()
def test_prefix(self):
alice, bob, mon = self.build_device(b"", b'foo', b'bar')
alices = b"hello bob".split()
alice.send_multipart(alices)
alices2 = b"hello again".split()
alice.send_multipart(alices2)
alices3 = b"hello again and again".split()
alice.send_multipart(alices3)
bobs = self.recv_multipart(bob)
assert alices == bobs
mons = self.recv_multipart(mon)
assert [b'foo'] + bobs == mons
bobs = self.recv_multipart(bob)
assert alices2 == bobs
bobs = self.recv_multipart(bob)
assert alices3 == bobs
mons = self.recv_multipart(mon)
assert [b'foo'] + alices2 == mons
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
assert alices == bobs
mons = self.recv_multipart(mon)
assert [b'foo'] + alices3 == mons
mons = self.recv_multipart(mon)
assert [b'bar'] + bobs == mons
self.teardown_device()
def test_monitor_subscribe(self):
alice, bob, mon = self.build_device(b"out")
alices = b"hello bob".split()
alice.send_multipart(alices)
alices2 = b"hello again".split()
alice.send_multipart(alices2)
alices3 = b"hello again and again".split()
alice.send_multipart(alices3)
bobs = self.recv_multipart(bob)
assert alices == bobs
bobs = self.recv_multipart(bob)
assert alices2 == bobs
bobs = self.recv_multipart(bob)
assert alices3 == bobs
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
assert alices == bobs
mons = self.recv_multipart(mon)
assert [b'out'] + bobs == mons
self.teardown_device()
def test_router_router(self):
"""test router-router MQ devices"""
dev = devices.ThreadMonitoredQueue(
zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out'
)
self.device = dev
dev.setsockopt_in(zmq.LINGER, 0)
dev.setsockopt_out(zmq.LINGER, 0)
dev.setsockopt_mon(zmq.LINGER, 0)
porta = dev.bind_in_to_random_port('tcp://127.0.0.1')
portb = dev.bind_out_to_random_port('tcp://127.0.0.1')
a = self.context.socket(zmq.DEALER)
a.identity = b'a'
b = self.context.socket(zmq.DEALER)
b.identity = b'b'
self.sockets.extend([a, b])
a.connect('tcp://127.0.0.1:%i' % porta)
b.connect('tcp://127.0.0.1:%i' % portb)
dev.start()
time.sleep(1)
if zmq.zmq_version_info() >= (3, 1, 0):
# flush erroneous poll state, due to LIBZMQ-280
ping_msg = [b'ping', b'pong']
for s in (a, b):
s.send_multipart(ping_msg)
try:
s.recv(zmq.NOBLOCK)
except zmq.ZMQError:
pass
msg = [b'hello', b'there']
a.send_multipart([b'b'] + msg)
bmsg = self.recv_multipart(b)
assert bmsg == [b'a'] + msg
b.send_multipart(bmsg)
amsg = self.recv_multipart(a)
assert amsg == [b'b'] + msg
self.teardown_device()
def test_default_mq_args(self):
self.device = dev = devices.ThreadMonitoredQueue(
zmq.ROUTER, zmq.DEALER, zmq.PUB
)
dev.setsockopt_in(zmq.LINGER, 0)
dev.setsockopt_out(zmq.LINGER, 0)
dev.setsockopt_mon(zmq.LINGER, 0)
# this will raise if default args are wrong
dev.start()
self.teardown_device()
def test_mq_check_prefix(self):
ins = self.context.socket(zmq.ROUTER)
outs = self.context.socket(zmq.DEALER)
mons = self.context.socket(zmq.PUB)
self.sockets.extend([ins, outs, mons])
ins = 'in'
outs = 'out'
self.assertRaises(TypeError, devices.monitoredqueue, ins, outs, mons)

View File

@@ -0,0 +1,34 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
class TestMultipart(BaseZMQTestCase):
def test_router_dealer(self):
router, dealer = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
msg1 = b'message1'
dealer.send(msg1)
self.recv(router)
more = router.rcvmore
assert more == True
msg2 = self.recv(router)
assert msg1 == msg2
more = router.rcvmore
assert more == False
def test_basic_multipart(self):
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
msg = [b'hi', b'there', b'b']
a.send_multipart(msg)
recvd = b.recv_multipart()
assert msg == recvd
if have_gevent:
class TestMultipartGreen(GreenTest, TestMultipart):
pass

View File

@@ -0,0 +1,73 @@
"""
Test our typing with mypy
"""
import os
import sys
from subprocess import PIPE, STDOUT, Popen
import pytest
import zmq
pytest.importorskip("mypy")
zmq_dir = os.path.dirname(zmq.__file__)
def resolve_repo_dir(path):
"""Resolve a dir in the repo
Resolved relative to zmq dir
fallback on CWD (e.g. test run from repo, zmq installed, not -e)
"""
resolved_path = os.path.join(os.path.dirname(zmq_dir), path)
# fallback on CWD
if not os.path.exists(resolved_path):
resolved_path = path
return resolved_path
examples_dir = resolve_repo_dir("examples")
mypy_dir = resolve_repo_dir("mypy_tests")
def run_mypy(*mypy_args):
"""Run mypy for a path
Captures output and reports it on errors
"""
p = Popen(
[sys.executable, "-m", "mypy"] + list(mypy_args), stdout=PIPE, stderr=STDOUT
)
o, _ = p.communicate()
out = o.decode("utf8", "replace")
print(out)
assert p.returncode == 0, out
if os.path.exists(examples_dir):
examples = [
d
for d in os.listdir(examples_dir)
if os.path.isdir(os.path.join(examples_dir, d))
]
@pytest.mark.skipif(
not os.path.exists(examples_dir), reason="only test from examples directory"
)
@pytest.mark.parametrize("example", examples)
def test_mypy_example(example):
example_dir = os.path.join(examples_dir, example)
run_mypy("--disallow-untyped-calls", example_dir)
if os.path.exists(mypy_dir):
mypy_tests = [p for p in os.listdir(mypy_dir) if p.endswith(".py")]
@pytest.mark.parametrize("filename", mypy_tests)
def test_mypy(filename):
run_mypy("--disallow-untyped-calls", os.path.join(mypy_dir, filename))

View File

@@ -0,0 +1,52 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
x = b' '
class TestPair(BaseZMQTestCase):
def test_basic(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
msg1 = b'message1'
msg2 = self.ping_pong(s1, s2, msg1)
assert msg1 == msg2
def test_multiple(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
for i in range(10):
msg = i * x
s1.send(msg)
for i in range(10):
msg = i * x
s2.send(msg)
for i in range(10):
msg = s1.recv()
assert msg == i * x
for i in range(10):
msg = s2.recv()
assert msg == i * x
def test_json(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
o = dict(a=10, b=list(range(10)))
self.ping_pong_json(s1, s2, o)
def test_pyobj(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
o = dict(a=10, b=range(10))
self.ping_pong_pyobj(s1, s2, o)
if have_gevent:
class TestReqRepGreen(GreenTest, TestPair):
pass

View File

@@ -0,0 +1,238 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import os
import sys
import time
from pytest import mark
import zmq
from zmq.tests import GreenTest, PollZMQTestCase, have_gevent
def wait():
time.sleep(0.25)
class TestPoll(PollZMQTestCase):
Poller = zmq.Poller
def test_pair(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
# Sleep to allow sockets to connect.
wait()
poller = self.Poller()
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
poller.register(s2, zmq.POLLIN | zmq.POLLOUT)
# Poll result should contain both sockets
socks = dict(poller.poll())
# Now make sure that both are send ready.
assert socks[s1] == zmq.POLLOUT
assert socks[s2] == zmq.POLLOUT
# Now do a send on both, wait and test for zmq.POLLOUT|zmq.POLLIN
s1.send(b'msg1')
s2.send(b'msg2')
wait()
socks = dict(poller.poll())
assert socks[s1] == zmq.POLLOUT | zmq.POLLIN
assert socks[s2] == zmq.POLLOUT | zmq.POLLIN
# Make sure that both are in POLLOUT after recv.
s1.recv()
s2.recv()
socks = dict(poller.poll())
assert socks[s1] == zmq.POLLOUT
assert socks[s2] == zmq.POLLOUT
poller.unregister(s1)
poller.unregister(s2)
def test_reqrep(self):
s1, s2 = self.create_bound_pair(zmq.REP, zmq.REQ)
# Sleep to allow sockets to connect.
wait()
poller = self.Poller()
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
poller.register(s2, zmq.POLLIN | zmq.POLLOUT)
# Make sure that s1 is in state 0 and s2 is in POLLOUT
socks = dict(poller.poll())
assert s1 not in socks
assert socks[s2] == zmq.POLLOUT
# Make sure that s2 goes immediately into state 0 after send.
s2.send(b'msg1')
socks = dict(poller.poll())
assert s2 not in socks
# Make sure that s1 goes into POLLIN state after a time.sleep().
time.sleep(0.5)
socks = dict(poller.poll())
assert socks[s1] == zmq.POLLIN
# Make sure that s1 goes into POLLOUT after recv.
s1.recv()
socks = dict(poller.poll())
assert socks[s1] == zmq.POLLOUT
# Make sure s1 goes into state 0 after send.
s1.send(b'msg2')
socks = dict(poller.poll())
assert s1 not in socks
# Wait and then see that s2 is in POLLIN.
time.sleep(0.5)
socks = dict(poller.poll())
assert socks[s2] == zmq.POLLIN
# Make sure that s2 is in POLLOUT after recv.
s2.recv()
socks = dict(poller.poll())
assert socks[s2] == zmq.POLLOUT
poller.unregister(s1)
poller.unregister(s2)
def test_no_events(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
poller = self.Poller()
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
poller.register(s2, 0)
assert s1 in poller
assert s2 not in poller
poller.register(s1, 0)
assert s1 not in poller
def test_pubsub(self):
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
s2.setsockopt(zmq.SUBSCRIBE, b'')
# Sleep to allow sockets to connect.
wait()
poller = self.Poller()
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
poller.register(s2, zmq.POLLIN)
# Now make sure that both are send ready.
socks = dict(poller.poll())
assert socks[s1] == zmq.POLLOUT
assert s2 not in socks
# Make sure that s1 stays in POLLOUT after a send.
s1.send(b'msg1')
socks = dict(poller.poll())
assert socks[s1] == zmq.POLLOUT
# Make sure that s2 is POLLIN after waiting.
wait()
socks = dict(poller.poll())
assert socks[s2] == zmq.POLLIN
# Make sure that s2 goes into 0 after recv.
s2.recv()
socks = dict(poller.poll())
assert s2 not in socks
poller.unregister(s1)
poller.unregister(s2)
@mark.skipif(sys.platform.startswith('win'), reason='Windows')
def test_raw(self):
r, w = os.pipe()
r = os.fdopen(r, 'rb')
w = os.fdopen(w, 'wb')
p = self.Poller()
p.register(r, zmq.POLLIN)
socks = dict(p.poll(1))
assert socks == {}
w.write(b'x')
w.flush()
socks = dict(p.poll(1))
assert socks == {r.fileno(): zmq.POLLIN}
w.close()
r.close()
@mark.flaky(reruns=3)
def test_timeout(self):
"""make sure Poller.poll timeout has the right units (milliseconds)."""
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
poller = self.Poller()
poller.register(s1, zmq.POLLIN)
tic = time.perf_counter()
poller.poll(0.005)
toc = time.perf_counter()
toc - tic < 0.5
tic = time.perf_counter()
poller.poll(50)
toc = time.perf_counter()
assert toc - tic < 0.5
assert toc - tic > 0.01
tic = time.perf_counter()
poller.poll(500)
toc = time.perf_counter()
assert toc - tic < 1
assert toc - tic > 0.1
class TestSelect(PollZMQTestCase):
def test_pair(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
# Sleep to allow sockets to connect.
wait()
rlist, wlist, xlist = zmq.select([s1, s2], [s1, s2], [s1, s2])
assert s1 in wlist
assert s2 in wlist
assert s1 not in rlist
assert s2 not in rlist
@mark.flaky(reruns=3)
def test_timeout(self):
"""make sure select timeout has the right units (seconds)."""
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
tic = time.perf_counter()
r, w, x = zmq.select([s1, s2], [], [], 0.005)
toc = time.perf_counter()
assert toc - tic < 1
assert toc - tic > 0.001
tic = time.perf_counter()
r, w, x = zmq.select([s1, s2], [], [], 0.25)
toc = time.perf_counter()
assert toc - tic < 1
assert toc - tic > 0.1
if have_gevent:
import gevent
from zmq import green as gzmq
class TestPollGreen(GreenTest, TestPoll):
Poller = gzmq.Poller
def test_wakeup(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
poller = self.Poller()
poller.register(s2, zmq.POLLIN)
tic = time.perf_counter()
r = gevent.spawn(lambda: poller.poll(10000))
s = gevent.spawn(lambda: s1.send(b'msg1'))
r.join()
toc = time.perf_counter()
assert toc - tic < 1
def test_socket_poll(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
tic = time.perf_counter()
r = gevent.spawn(lambda: s2.poll(10000))
s = gevent.spawn(lambda: s1.send(b'msg1'))
r.join()
toc = time.perf_counter()
assert toc - tic < 1

View File

@@ -0,0 +1,95 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import struct
import time
import zmq
from zmq import devices
from zmq.tests import PYPY, BaseZMQTestCase, SkipTest
if PYPY:
# cleanup of shared Context doesn't work on PyPy
devices.Device.context_factory = zmq.Context
class TestProxySteerable(BaseZMQTestCase):
def test_proxy_steerable(self):
if zmq.zmq_version_info() < (4, 1):
raise SkipTest("Steerable Proxies only in libzmq >= 4.1")
dev = devices.ThreadProxySteerable(zmq.PULL, zmq.PUSH, zmq.PUSH, zmq.PAIR)
iface = 'tcp://127.0.0.1'
port = dev.bind_in_to_random_port(iface)
port2 = dev.bind_out_to_random_port(iface)
port3 = dev.bind_mon_to_random_port(iface)
port4 = dev.bind_ctrl_to_random_port(iface)
dev.start()
time.sleep(0.25)
msg = b'hello'
push = self.context.socket(zmq.PUSH)
push.connect("%s:%i" % (iface, port))
pull = self.context.socket(zmq.PULL)
pull.connect("%s:%i" % (iface, port2))
mon = self.context.socket(zmq.PULL)
mon.connect("%s:%i" % (iface, port3))
ctrl = self.context.socket(zmq.PAIR)
ctrl.connect("%s:%i" % (iface, port4))
push.send(msg)
self.sockets.extend([push, pull, mon, ctrl])
assert msg == self.recv(pull)
assert msg == self.recv(mon)
ctrl.send(b'TERMINATE')
dev.join()
def test_proxy_steerable_bind_to_random_with_args(self):
if zmq.zmq_version_info() < (4, 1):
raise SkipTest("Steerable Proxies only in libzmq >= 4.1")
dev = devices.ThreadProxySteerable(zmq.PULL, zmq.PUSH, zmq.PUSH, zmq.PAIR)
iface = 'tcp://127.0.0.1'
ports = []
min, max = 5000, 5050
ports.extend(
[
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
dev.bind_mon_to_random_port(iface, min_port=min, max_port=max),
dev.bind_ctrl_to_random_port(iface, min_port=min, max_port=max),
]
)
for port in ports:
if port < min or port > max:
self.fail('Unexpected port number: %i' % port)
def test_proxy_steerable_statistics(self):
if zmq.zmq_version_info() < (4, 3):
raise SkipTest("STATISTICS only in libzmq >= 4.3")
dev = devices.ThreadProxySteerable(zmq.PULL, zmq.PUSH, zmq.PUSH, zmq.PAIR)
iface = 'tcp://127.0.0.1'
port = dev.bind_in_to_random_port(iface)
port2 = dev.bind_out_to_random_port(iface)
port3 = dev.bind_mon_to_random_port(iface)
port4 = dev.bind_ctrl_to_random_port(iface)
dev.start()
time.sleep(0.25)
msg = b'hello'
push = self.context.socket(zmq.PUSH)
push.connect("%s:%i" % (iface, port))
pull = self.context.socket(zmq.PULL)
pull.connect("%s:%i" % (iface, port2))
mon = self.context.socket(zmq.PULL)
mon.connect("%s:%i" % (iface, port3))
ctrl = self.context.socket(zmq.PAIR)
ctrl.connect("%s:%i" % (iface, port4))
push.send(msg)
self.sockets.extend([push, pull, mon, ctrl])
assert msg == self.recv(pull)
assert msg == self.recv(mon)
ctrl.send(b'STATISTICS')
stats = self.recv_multipart(ctrl)
stats_int = [struct.unpack("=Q", x)[0] for x in stats]
assert 1 == stats_int[0]
assert len(msg) == stats_int[1]
assert 1 == stats_int[6]
assert len(msg) == stats_int[7]
ctrl.send(b'TERMINATE')
dev.join()

View File

@@ -0,0 +1,40 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import time
import zmq
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
class TestPubSub(BaseZMQTestCase):
pass
# We are disabling this test while an issue is being resolved.
def test_basic(self):
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
s2.setsockopt(zmq.SUBSCRIBE, b'')
time.sleep(0.1)
msg1 = b'message'
s1.send(msg1)
msg2 = s2.recv() # This is blocking!
assert msg1 == msg2
def test_topic(self):
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
s2.setsockopt(zmq.SUBSCRIBE, b'x')
time.sleep(0.1)
msg1 = b'message'
s1.send(msg1)
self.assertRaisesErrno(zmq.EAGAIN, s2.recv, zmq.NOBLOCK)
msg1 = b'xmessage'
s1.send(msg1)
msg2 = s2.recv()
assert msg1 == msg2
if have_gevent:
class TestPubSubGreen(GreenTest, TestPubSub):
pass

View File

@@ -0,0 +1,61 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
class TestReqRep(BaseZMQTestCase):
def test_basic(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
msg1 = b'message 1'
msg2 = self.ping_pong(s1, s2, msg1)
assert msg1 == msg2
def test_multiple(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
for i in range(10):
msg1 = i * b' '
msg2 = self.ping_pong(s1, s2, msg1)
assert msg1 == msg2
def test_bad_send_recv(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
if zmq.zmq_version() != '2.1.8':
# this doesn't work on 2.1.8
for copy in (True, False):
self.assertRaisesErrno(zmq.EFSM, s1.recv, copy=copy)
self.assertRaisesErrno(zmq.EFSM, s2.send, b'asdf', copy=copy)
# I have to have this or we die on an Abort trap.
msg1 = b'asdf'
msg2 = self.ping_pong(s1, s2, msg1)
assert msg1 == msg2
def test_json(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
o = dict(a=10, b=list(range(10)))
self.ping_pong_json(s1, s2, o)
def test_pyobj(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
o = dict(a=10, b=range(10))
self.ping_pong_pyobj(s1, s2, o)
def test_large_msg(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
msg1 = 10000 * b'X'
for i in range(10):
msg2 = self.ping_pong(s1, s2, msg1)
assert msg1 == msg2
if have_gevent:
class TestReqRepGreen(GreenTest, TestReqRep):
pass

View File

@@ -0,0 +1,94 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import signal
import time
from threading import Thread
from pytest import mark
import zmq
from zmq.tests import BaseZMQTestCase, SkipTest
# Partially based on EINTRBaseTest from CPython 3.5 eintr_tester
class TestEINTRSysCall(BaseZMQTestCase):
"""Base class for EINTR tests."""
# delay for initial signal delivery
signal_delay = 0.1
# timeout for tests. Must be > signal_delay
timeout = 0.25
timeout_ms = int(timeout * 1e3)
def alarm(self, t=None):
"""start a timer to fire only once
like signal.alarm, but with better resolution than integer seconds.
"""
if not hasattr(signal, 'setitimer'):
raise SkipTest('EINTR tests require setitimer')
if t is None:
t = self.signal_delay
self.timer_fired = False
self.orig_handler = signal.signal(signal.SIGALRM, self.stop_timer)
# signal_period ignored, since only one timer event is allowed to fire
signal.setitimer(signal.ITIMER_REAL, t, 1000)
def stop_timer(self, *args):
self.timer_fired = True
signal.setitimer(signal.ITIMER_REAL, 0, 0)
signal.signal(signal.SIGALRM, self.orig_handler)
@mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
def test_retry_recv(self):
pull = self.socket(zmq.PULL)
pull.rcvtimeo = self.timeout_ms
self.alarm()
self.assertRaises(zmq.Again, pull.recv)
assert self.timer_fired
@mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
def test_retry_send(self):
push = self.socket(zmq.PUSH)
push.sndtimeo = self.timeout_ms
self.alarm()
self.assertRaises(zmq.Again, push.send, b'buf')
assert self.timer_fired
@mark.flaky(reruns=3)
def test_retry_poll(self):
x, y = self.create_bound_pair()
poller = zmq.Poller()
poller.register(x, zmq.POLLIN)
self.alarm()
def send():
time.sleep(2 * self.signal_delay)
y.send(b'ping')
t = Thread(target=send)
t.start()
evts = dict(poller.poll(2 * self.timeout_ms))
t.join()
assert x in evts
assert self.timer_fired
x.recv()
def test_retry_term(self):
push = self.socket(zmq.PUSH)
push.linger = self.timeout_ms
push.connect('tcp://127.0.0.1:5555')
push.send(b'ping')
time.sleep(0.1)
self.alarm()
self.context.destroy()
assert self.timer_fired
assert self.context.closed
def test_retry_getsockopt(self):
raise SkipTest("TODO: find a way to interrupt getsockopt")
def test_retry_setsockopt(self):
raise SkipTest("TODO: find a way to interrupt setsockopt")

View File

@@ -0,0 +1,238 @@
"""Test libzmq security (libzmq >= 3.3.0)"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import contextlib
import os
import time
from threading import Thread
import zmq
from zmq.tests import PYPY, BaseZMQTestCase, SkipTest
from zmq.utils import z85
USER = b"admin"
PASS = b"password"
class TestSecurity(BaseZMQTestCase):
def setUp(self):
if zmq.zmq_version_info() < (4, 0):
raise SkipTest("security is new in libzmq 4.0")
try:
zmq.curve_keypair()
except zmq.ZMQError:
raise SkipTest("security requires libzmq to be built with CURVE support")
super().setUp()
def zap_handler(self):
socket = self.context.socket(zmq.REP)
socket.bind("inproc://zeromq.zap.01")
try:
msg = self.recv_multipart(socket)
version, sequence, domain, address, identity, mechanism = msg[:6]
if mechanism == b'PLAIN':
username, password = msg[6:]
elif mechanism == b'CURVE':
msg[6]
assert version == b"1.0"
assert identity == b"IDENT"
reply = [version, sequence]
if (
mechanism == b'CURVE'
or (mechanism == b'PLAIN' and username == USER and password == PASS)
or (mechanism == b'NULL')
):
reply.extend(
[
b"200",
b"OK",
b"anonymous",
b"\5Hello\0\0\0\5World",
]
)
else:
reply.extend(
[
b"400",
b"Invalid username or password",
b"",
b"",
]
)
socket.send_multipart(reply)
finally:
socket.close()
@contextlib.contextmanager
def zap(self):
self.start_zap()
time.sleep(0.5) # allow time for the Thread to start
try:
yield
finally:
self.stop_zap()
def start_zap(self):
self.zap_thread = Thread(target=self.zap_handler)
self.zap_thread.start()
def stop_zap(self):
self.zap_thread.join()
def bounce(self, server, client, test_metadata=True):
msg = [os.urandom(64), os.urandom(64)]
client.send_multipart(msg)
frames = self.recv_multipart(server, copy=False)
recvd = list(map(lambda x: x.bytes, frames))
try:
if test_metadata and not PYPY:
for frame in frames:
assert frame.get('User-Id') == 'anonymous'
assert frame.get('Hello') == 'World'
assert frame['Socket-Type'] == 'DEALER'
except zmq.ZMQVersionError:
pass
assert recvd == msg
server.send_multipart(recvd)
msg2 = self.recv_multipart(client)
assert msg2 == msg
def test_null(self):
"""test NULL (default) security"""
server = self.socket(zmq.DEALER)
client = self.socket(zmq.DEALER)
assert client.MECHANISM == zmq.NULL
assert server.mechanism == zmq.NULL
assert client.plain_server == 0
assert server.plain_server == 0
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
self.bounce(server, client, False)
def test_plain(self):
"""test PLAIN authentication"""
server = self.socket(zmq.DEALER)
server.identity = b'IDENT'
client = self.socket(zmq.DEALER)
assert client.plain_username == b''
assert client.plain_password == b''
client.plain_username = USER
client.plain_password = PASS
assert client.getsockopt(zmq.PLAIN_USERNAME) == USER
assert client.getsockopt(zmq.PLAIN_PASSWORD) == PASS
assert client.plain_server == 0
assert server.plain_server == 0
server.plain_server = True
assert server.mechanism == zmq.PLAIN
assert client.mechanism == zmq.PLAIN
assert not client.plain_server
assert server.plain_server
with self.zap():
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
self.bounce(server, client)
def skip_plain_inauth(self):
"""test PLAIN failed authentication"""
server = self.socket(zmq.DEALER)
server.identity = b'IDENT'
client = self.socket(zmq.DEALER)
self.sockets.extend([server, client])
client.plain_username = USER
client.plain_password = b'incorrect'
server.plain_server = True
assert server.mechanism == zmq.PLAIN
assert client.mechanism == zmq.PLAIN
with self.zap():
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
client.send(b'ping')
server.rcvtimeo = 250
self.assertRaisesErrno(zmq.EAGAIN, server.recv)
def test_keypair(self):
"""test curve_keypair"""
try:
public, secret = zmq.curve_keypair()
except zmq.ZMQError:
raise SkipTest("CURVE unsupported")
assert type(secret) == bytes
assert type(public) == bytes
assert len(secret) == 40
assert len(public) == 40
# verify that it is indeed Z85
bsecret, bpublic = (z85.decode(key) for key in (public, secret))
assert type(bsecret) == bytes
assert type(bpublic) == bytes
assert len(bsecret) == 32
assert len(bpublic) == 32
def test_curve_public(self):
"""test curve_public"""
try:
public, secret = zmq.curve_keypair()
except zmq.ZMQError:
raise SkipTest("CURVE unsupported")
if zmq.zmq_version_info() < (4, 2):
raise SkipTest("curve_public is new in libzmq 4.2")
derived_public = zmq.curve_public(secret)
assert type(derived_public) == bytes
assert len(derived_public) == 40
# verify that it is indeed Z85
bpublic = z85.decode(derived_public)
assert type(bpublic) == bytes
assert len(bpublic) == 32
# verify that it is equal to the known public key
assert derived_public == public
def test_curve(self):
"""test CURVE encryption"""
server = self.socket(zmq.DEALER)
server.identity = b'IDENT'
client = self.socket(zmq.DEALER)
self.sockets.extend([server, client])
try:
server.curve_server = True
except zmq.ZMQError as e:
# will raise EINVAL if no CURVE support
if e.errno == zmq.EINVAL:
raise SkipTest("CURVE unsupported")
server_public, server_secret = zmq.curve_keypair()
client_public, client_secret = zmq.curve_keypair()
server.curve_secretkey = server_secret
server.curve_publickey = server_public
client.curve_serverkey = server_public
client.curve_publickey = client_public
client.curve_secretkey = client_secret
assert server.mechanism == zmq.CURVE
assert client.mechanism == zmq.CURVE
assert server.get(zmq.CURVE_SERVER) == True
assert client.get(zmq.CURVE_SERVER) == False
with self.zap():
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
self.bounce(server, client)

View File

@@ -0,0 +1,690 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import copy
import errno
import json
import os
import platform
import socket
import sys
import time
import warnings
from unittest import mock
import pytest
from pytest import mark
import zmq
from zmq.tests import BaseZMQTestCase, GreenTest, SkipTest, have_gevent, skip_pypy
pypy = platform.python_implementation().lower() == 'pypy'
windows = platform.platform().lower().startswith('windows')
on_ci = bool(os.environ.get('CI'))
# polling on windows is slow
POLL_TIMEOUT = 1000 if windows else 100
class TestSocket(BaseZMQTestCase):
def test_create(self):
ctx = self.Context()
s = ctx.socket(zmq.PUB)
# Superluminal protocol not yet implemented
self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.bind, 'ftl://a')
self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.connect, 'ftl://a')
self.assertRaisesErrno(zmq.EINVAL, s.bind, 'tcp://')
s.close()
ctx.term()
def test_context_manager(self):
url = 'inproc://a'
with self.Context() as ctx:
with ctx.socket(zmq.PUSH) as a:
a.bind(url)
with ctx.socket(zmq.PULL) as b:
b.connect(url)
msg = b'hi'
a.send(msg)
rcvd = self.recv(b)
assert rcvd == msg
assert b.closed == True
assert a.closed == True
assert ctx.closed == True
def test_connectbind_context_managers(self):
url = 'inproc://a'
msg = b'hi'
with self.Context() as ctx:
# Test connect() context manager
with ctx.socket(zmq.PUSH) as a, ctx.socket(zmq.PULL) as b:
a.bind(url)
connect_context = b.connect(url)
assert f'connect={url!r}' in repr(connect_context)
with connect_context:
a.send(msg)
rcvd = self.recv(b)
assert rcvd == msg
# b should now be disconnected, so sending and receiving don't work
with pytest.raises(zmq.Again):
a.send(msg, flags=zmq.DONTWAIT)
with pytest.raises(zmq.Again):
b.recv(flags=zmq.DONTWAIT)
a.unbind(url)
# Test bind() context manager
with ctx.socket(zmq.PUSH) as a, ctx.socket(zmq.PULL) as b:
# unbind() just stops accepting of new connections, so we have to disconnect to test that
# unbind happened.
bind_context = a.bind(url)
assert f'bind={url!r}' in repr(bind_context)
with bind_context:
b.connect(url)
a.send(msg)
rcvd = self.recv(b)
assert rcvd == msg
b.disconnect(url)
b.connect(url)
# Since a is unbound from url, b is not connected to anything
with pytest.raises(zmq.Again):
a.send(msg, flags=zmq.DONTWAIT)
with pytest.raises(zmq.Again):
b.recv(flags=zmq.DONTWAIT)
_repr_cls = "zmq.Socket"
def test_repr(self):
with self.context.socket(zmq.PUSH) as s:
assert f'{self._repr_cls}(zmq.PUSH)' in repr(s)
assert 'closed' not in repr(s)
assert f'{self._repr_cls}(zmq.PUSH)' in repr(s)
assert 'closed' in repr(s)
def test_dir(self):
ctx = self.Context()
s = ctx.socket(zmq.PUB)
assert 'send' in dir(s)
assert 'IDENTITY' in dir(s)
assert 'AFFINITY' in dir(s)
assert 'FD' in dir(s)
s.close()
ctx.term()
@mark.skipif(mock is None, reason="requires unittest.mock")
def test_mockable(self):
s = self.socket(zmq.SUB)
m = mock.Mock(spec=s)
s.close()
def test_bind_unicode(self):
s = self.socket(zmq.PUB)
p = s.bind_to_random_port("tcp://*")
def test_connect_unicode(self):
s = self.socket(zmq.PUB)
s.connect("tcp://127.0.0.1:5555")
def test_bind_to_random_port(self):
# Check that bind_to_random_port do not hide useful exception
ctx = self.Context()
s = ctx.socket(zmq.PUB)
# Invalid format
try:
s.bind_to_random_port('tcp:*')
except zmq.ZMQError as e:
assert e.errno == zmq.EINVAL
# Invalid protocol
try:
s.bind_to_random_port('rand://*')
except zmq.ZMQError as e:
assert e.errno == zmq.EPROTONOSUPPORT
s.close()
ctx.term()
def test_bind_connect_addr_error(self):
with self.socket(zmq.PUSH) as s:
url = "tcp://1.2.3.4.5:1234567"
with pytest.raises(zmq.ZMQError) as exc:
s.bind(url)
assert url in str(exc.value)
url = "noproc://no/such/file"
with pytest.raises(zmq.ZMQError) as exc:
s.connect(url)
assert url in str(exc.value)
def test_identity(self):
s = self.context.socket(zmq.PULL)
self.sockets.append(s)
ident = b'identity\0\0'
s.identity = ident
assert s.get(zmq.IDENTITY) == ident
def test_unicode_sockopts(self):
"""test setting/getting sockopts with unicode strings"""
topic = "tést"
p, s = self.create_bound_pair(zmq.PUB, zmq.SUB)
assert s.send_unicode == s.send_unicode
assert p.recv_unicode == p.recv_unicode
self.assertRaises(TypeError, s.setsockopt, zmq.SUBSCRIBE, topic)
self.assertRaises(TypeError, s.setsockopt, zmq.IDENTITY, topic)
s.setsockopt_unicode(zmq.IDENTITY, topic, 'utf16')
self.assertRaises(TypeError, s.setsockopt, zmq.AFFINITY, topic)
s.setsockopt_unicode(zmq.SUBSCRIBE, topic)
self.assertRaises(TypeError, s.getsockopt_unicode, zmq.AFFINITY)
self.assertRaisesErrno(zmq.EINVAL, s.getsockopt_unicode, zmq.SUBSCRIBE)
identb = s.getsockopt(zmq.IDENTITY)
identu = identb.decode('utf16')
identu2 = s.getsockopt_unicode(zmq.IDENTITY, 'utf16')
assert identu == identu2
time.sleep(0.1) # wait for connection/subscription
p.send_unicode(topic, zmq.SNDMORE)
p.send_unicode(topic * 2, encoding='latin-1')
assert topic == s.recv_unicode()
assert topic * 2 == s.recv_unicode(encoding='latin-1')
def test_int_sockopts(self):
"test integer sockopts"
v = zmq.zmq_version_info()
if v < (3, 0):
default_hwm = 0
else:
default_hwm = 1000
p, s = self.create_bound_pair(zmq.PUB, zmq.SUB)
p.setsockopt(zmq.LINGER, 0)
assert p.getsockopt(zmq.LINGER) == 0
p.setsockopt(zmq.LINGER, -1)
assert p.getsockopt(zmq.LINGER) == -1
assert p.hwm == default_hwm
p.hwm = 11
assert p.hwm == 11
# p.setsockopt(zmq.EVENTS, zmq.POLLIN)
assert p.getsockopt(zmq.EVENTS) == zmq.POLLOUT
self.assertRaisesErrno(zmq.EINVAL, p.setsockopt, zmq.EVENTS, 2**7 - 1)
assert p.getsockopt(zmq.TYPE) == p.socket_type
assert p.getsockopt(zmq.TYPE) == zmq.PUB
assert s.getsockopt(zmq.TYPE) == s.socket_type
assert s.getsockopt(zmq.TYPE) == zmq.SUB
# check for overflow / wrong type:
errors = []
backref = {}
constants = zmq.constants
for name in constants.__all__:
value = getattr(constants, name)
if isinstance(value, int):
backref[value] = name
for opt in zmq.constants.SocketOption:
if opt._opt_type not in {
zmq.constants._OptType.int,
zmq.constants._OptType.int64,
}:
continue
if opt.name.startswith(
(
'HWM',
'ROUTER',
'XPUB',
'TCP',
'FAIL',
'REQ_',
'CURVE_',
'PROBE_ROUTER',
'IPC_FILTER',
'GSSAPI',
'STREAM_',
'VMCI_BUFFER_SIZE',
'VMCI_BUFFER_MIN_SIZE',
'VMCI_BUFFER_MAX_SIZE',
'VMCI_CONNECT_TIMEOUT',
'BLOCKY',
'IN_BATCH_SIZE',
'OUT_BATCH_SIZE',
'WSS_TRUST_SYSTEM',
'ONLY_FIRST_SUBSCRIBE',
'PRIORITY',
'RECONNECT_STOP',
)
):
# some sockopts are write-only
continue
try:
n = p.getsockopt(opt)
except zmq.ZMQError as e:
errors.append(f"getsockopt({opt!r}) raised {e}.")
else:
if n > 2**31:
errors.append(
f"getsockopt({opt!r}) returned a ridiculous value."
" It is probably the wrong type."
)
if errors:
self.fail('\n'.join([''] + errors))
def test_bad_sockopts(self):
"""Test that appropriate errors are raised on bad socket options"""
s = self.context.socket(zmq.PUB)
self.sockets.append(s)
s.setsockopt(zmq.LINGER, 0)
# unrecognized int sockopts pass through to libzmq, and should raise EINVAL
self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, 9999, 5)
self.assertRaisesErrno(zmq.EINVAL, s.getsockopt, 9999)
# but only int sockopts are allowed through this way, otherwise raise a TypeError
self.assertRaises(TypeError, s.setsockopt, 9999, b"5")
# some sockopts are valid in general, but not on every socket:
self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, zmq.SUBSCRIBE, b'hi')
def test_sockopt_roundtrip(self):
"test set/getsockopt roundtrip."
p = self.context.socket(zmq.PUB)
self.sockets.append(p)
p.setsockopt(zmq.LINGER, 11)
assert p.getsockopt(zmq.LINGER) == 11
def test_send_unicode(self):
"test sending unicode objects"
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
self.sockets.extend([a, b])
u = "çπ§"
self.assertRaises(TypeError, a.send, u, copy=False)
self.assertRaises(TypeError, a.send, u, copy=True)
a.send_unicode(u)
s = b.recv()
assert s == u.encode('utf8')
assert s.decode('utf8') == u
a.send_unicode(u, encoding='utf16')
s = b.recv_unicode(encoding='utf16')
assert s == u
def test_send_multipart_check_type(self):
"check type on all frames in send_multipart"
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
self.sockets.extend([a, b])
self.assertRaises(TypeError, a.send_multipart, [b'a', 5])
a.send_multipart([b'b'])
rcvd = self.recv_multipart(b)
assert rcvd == [b'b']
@skip_pypy
def test_tracker(self):
"test the MessageTracker object for tracking when zmq is done with a buffer"
addr = 'tcp://127.0.0.1'
# get a port:
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
iface = "%s:%i" % (addr, port)
sock.close()
time.sleep(0.1)
a = self.context.socket(zmq.PUSH)
b = self.context.socket(zmq.PULL)
self.sockets.extend([a, b])
a.connect(iface)
time.sleep(0.1)
p1 = a.send(b'something', copy=False, track=True)
assert isinstance(p1, zmq.MessageTracker)
assert p1 is zmq._FINISHED_TRACKER
# small message, should start done
assert p1.done
# disable zero-copy threshold
a.copy_threshold = 0
p2 = a.send_multipart([b'something', b'else'], copy=False, track=True)
assert isinstance(p2, zmq.MessageTracker)
assert not p2.done
b.bind(iface)
msg = self.recv_multipart(b)
for i in range(10):
if p1.done:
break
time.sleep(0.1)
assert p1.done == True
assert msg == [b'something']
msg = self.recv_multipart(b)
for i in range(10):
if p2.done:
break
time.sleep(0.1)
assert p2.done == True
assert msg == [b'something', b'else']
m = zmq.Frame(b"again", copy=False, track=True)
assert m.tracker.done == False
p1 = a.send(m, copy=False)
p2 = a.send(m, copy=False)
assert m.tracker.done == False
assert p1.done == False
assert p2.done == False
msg = self.recv_multipart(b)
assert m.tracker.done == False
assert msg == [b'again']
msg = self.recv_multipart(b)
assert m.tracker.done == False
assert msg == [b'again']
assert p1.done == False
assert p2.done == False
m.tracker
del m
for i in range(10):
if p1.done:
break
time.sleep(0.1)
assert p1.done == True
assert p2.done == True
m = zmq.Frame(b'something', track=False)
self.assertRaises(ValueError, a.send, m, copy=False, track=True)
def test_close(self):
ctx = self.Context()
s = ctx.socket(zmq.PUB)
s.close()
self.assertRaisesErrno(zmq.ENOTSOCK, s.bind, b'')
self.assertRaisesErrno(zmq.ENOTSOCK, s.connect, b'')
self.assertRaisesErrno(zmq.ENOTSOCK, s.setsockopt, zmq.SUBSCRIBE, b'')
self.assertRaisesErrno(zmq.ENOTSOCK, s.send, b'asdf')
self.assertRaisesErrno(zmq.ENOTSOCK, s.recv)
ctx.term()
def test_attr(self):
"""set setting/getting sockopts as attributes"""
s = self.context.socket(zmq.DEALER)
self.sockets.append(s)
linger = 10
s.linger = linger
assert linger == s.linger
assert linger == s.getsockopt(zmq.LINGER)
assert s.fd == s.getsockopt(zmq.FD)
def test_bad_attr(self):
s = self.context.socket(zmq.DEALER)
self.sockets.append(s)
try:
s.apple = 'foo'
except AttributeError:
pass
else:
self.fail("bad setattr should have raised AttributeError")
try:
s.apple
except AttributeError:
pass
else:
self.fail("bad getattr should have raised AttributeError")
def test_subclass(self):
"""subclasses can assign attributes"""
class S(zmq.Socket):
a = None
def __init__(self, *a, **kw):
self.a = -1
super().__init__(*a, **kw)
s = S(self.context, zmq.REP)
self.sockets.append(s)
assert s.a == -1
s.a = 1
assert s.a == 1
a = s.a
assert a == 1
def test_recv_multipart(self):
a, b = self.create_bound_pair()
msg = b'hi'
for i in range(3):
a.send(msg)
time.sleep(0.1)
for i in range(3):
assert self.recv_multipart(b) == [msg]
def test_close_after_destroy(self):
"""s.close() after ctx.destroy() should be fine"""
ctx = self.Context()
s = ctx.socket(zmq.REP)
ctx.destroy()
# reaper is not instantaneous
time.sleep(1e-2)
s.close()
assert s.closed
def test_poll(self):
a, b = self.create_bound_pair()
time.time()
evt = a.poll(POLL_TIMEOUT)
assert evt == 0
evt = a.poll(POLL_TIMEOUT, zmq.POLLOUT)
assert evt == zmq.POLLOUT
msg = b'hi'
a.send(msg)
evt = b.poll(POLL_TIMEOUT)
assert evt == zmq.POLLIN
msg2 = self.recv(b)
evt = b.poll(POLL_TIMEOUT)
assert evt == 0
assert msg2 == msg
def test_ipc_path_max_length(self):
"""IPC_PATH_MAX_LEN is a sensible value"""
if zmq.IPC_PATH_MAX_LEN == 0:
raise SkipTest("IPC_PATH_MAX_LEN undefined")
msg = "Surprising value for IPC_PATH_MAX_LEN: %s" % zmq.IPC_PATH_MAX_LEN
assert zmq.IPC_PATH_MAX_LEN > 30, msg
assert zmq.IPC_PATH_MAX_LEN < 1025, msg
def test_ipc_path_max_length_msg(self):
if zmq.IPC_PATH_MAX_LEN == 0:
raise SkipTest("IPC_PATH_MAX_LEN undefined")
s = self.context.socket(zmq.PUB)
self.sockets.append(s)
try:
s.bind('ipc://{}'.format('a' * (zmq.IPC_PATH_MAX_LEN + 1)))
except zmq.ZMQError as e:
assert str(zmq.IPC_PATH_MAX_LEN) in e.strerror
@mark.skipif(windows, reason="ipc not supported on Windows.")
def test_ipc_path_no_such_file_or_directory_message(self):
"""Display the ipc path in case of an ENOENT exception"""
s = self.context.socket(zmq.PUB)
self.sockets.append(s)
invalid_path = '/foo/bar'
with pytest.raises(zmq.ZMQError) as error:
s.bind(f'ipc://{invalid_path}')
assert error.value.errno == errno.ENOENT
error_message = str(error.value)
assert invalid_path in error_message
assert "no such file or directory" in error_message.lower()
def test_hwm(self):
zmq3 = zmq.zmq_version_info()[0] >= 3
for stype in (zmq.PUB, zmq.ROUTER, zmq.SUB, zmq.REQ, zmq.DEALER):
s = self.context.socket(stype)
s.hwm = 100
assert s.hwm == 100
if zmq3:
try:
assert s.sndhwm == 100
except AttributeError:
pass
try:
assert s.rcvhwm == 100
except AttributeError:
pass
s.close()
def test_copy(self):
s = self.socket(zmq.PUB)
scopy = copy.copy(s)
sdcopy = copy.deepcopy(s)
assert scopy._shadow
assert sdcopy._shadow
assert s.underlying == scopy.underlying
assert s.underlying == sdcopy.underlying
s.close()
def test_send_buffer(self):
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
for buffer_type in (memoryview, bytearray):
rawbytes = str(buffer_type).encode('ascii')
msg = buffer_type(rawbytes)
a.send(msg)
recvd = b.recv()
assert recvd == rawbytes
def test_shadow(self):
p = self.socket(zmq.PUSH)
p.bind("tcp://127.0.0.1:5555")
p2 = zmq.Socket.shadow(p.underlying)
assert p.underlying == p2.underlying
s = self.socket(zmq.PULL)
s2 = zmq.Socket.shadow(s)
assert s2._shadow_obj is s
assert s.underlying != p.underlying
assert s2.underlying == s.underlying
s3 = zmq.Socket(s)
assert s3._shadow_obj is s
assert s3.underlying == s.underlying
s2.connect("tcp://127.0.0.1:5555")
sent = b'hi'
p2.send(sent)
rcvd = self.recv(s2)
assert rcvd == sent
def test_shadow_pyczmq(self):
try:
from pyczmq import zctx, zsocket
except Exception:
raise SkipTest("Requires pyczmq")
ctx = zctx.new()
ca = zsocket.new(ctx, zmq.PUSH)
cb = zsocket.new(ctx, zmq.PULL)
a = zmq.Socket.shadow(ca)
b = zmq.Socket.shadow(cb)
a.bind("inproc://a")
b.connect("inproc://a")
a.send(b'hi')
rcvd = self.recv(b)
assert rcvd == b'hi'
def test_subscribe_method(self):
pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
sub.subscribe('prefix')
sub.subscribe = 'c'
p = zmq.Poller()
p.register(sub, zmq.POLLIN)
# wait for subscription handshake
for i in range(100):
pub.send(b'canary')
events = p.poll(250)
if events:
break
self.recv(sub)
pub.send(b'prefixmessage')
msg = self.recv(sub)
assert msg == b'prefixmessage'
sub.unsubscribe('prefix')
pub.send(b'prefixmessage')
events = p.poll(1000)
assert events == []
# CI often can't handle how much memory PyPy uses on this test
@mark.skipif(
(pypy and on_ci) or (sys.maxsize < 2**32) or (windows),
reason="only run on 64b and not on CI.",
)
@mark.large
def test_large_send(self):
c = os.urandom(1)
N = 2**31 + 1
try:
buf = c * N
except MemoryError as e:
raise SkipTest("Not enough memory: %s" % e)
a, b = self.create_bound_pair()
try:
a.send(buf, copy=False)
rcvd = b.recv(copy=False)
except MemoryError as e:
raise SkipTest("Not enough memory: %s" % e)
# sample the front and back of the received message
# without checking the whole content
byte = ord(c)
view = memoryview(rcvd)
assert len(view) == N
assert view[0] == byte
assert view[-1] == byte
def test_custom_serialize(self):
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
def serialize(msg):
frames = []
frames.extend(msg.get('identities', []))
content = json.dumps(msg['content']).encode('utf8')
frames.append(content)
return frames
def deserialize(frames):
identities = frames[:-1]
content = json.loads(frames[-1].decode('utf8'))
return {
'identities': identities,
'content': content,
}
msg = {
'content': {
'a': 5,
'b': 'bee',
}
}
a.send_serialized(msg, serialize)
recvd = b.recv_serialized(deserialize)
assert recvd['content'] == msg['content']
assert recvd['identities']
# bounce back, tests identities
b.send_serialized(recvd, serialize)
r2 = a.recv_serialized(deserialize)
assert r2['content'] == msg['content']
assert not r2['identities']
if have_gevent and not windows:
import gevent
class TestSocketGreen(GreenTest, TestSocket):
test_bad_attr = GreenTest.skip_green
test_close_after_destroy = GreenTest.skip_green
_repr_cls = "zmq.green.Socket"
def test_timeout(self):
a, b = self.create_bound_pair()
g = gevent.spawn_later(0.5, lambda: a.send(b'hi'))
timeout = gevent.Timeout(0.1)
timeout.start()
self.assertRaises(gevent.Timeout, b.recv)
g.kill()
@mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
def test_warn_set_timeo(self):
s = self.context.socket(zmq.REQ)
with warnings.catch_warnings(record=True) as w:
s.rcvtimeo = 5
s.close()
assert len(w) == 1
assert w[0].category == UserWarning
@mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
def test_warn_get_timeo(self):
s = self.context.socket(zmq.REQ)
with warnings.catch_warnings(record=True) as w:
s.sndtimeo
s.close()
assert len(w) == 1
assert w[0].category == UserWarning

View File

@@ -0,0 +1,9 @@
from zmq.ssh.tunnel import select_random_ports
def test_random_ports():
for i in range(4096):
ports = select_random_ports(10)
assert len(ports) == 10
for p in ports:
assert ports.count(p) == 1

View File

@@ -0,0 +1,43 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from unittest import TestCase
import zmq
from zmq.sugar import version
class TestVersion(TestCase):
def test_pyzmq_version(self):
vs = zmq.pyzmq_version()
vs2 = zmq.__version__
assert isinstance(vs, str)
if zmq.__revision__:
assert vs == '@'.join(vs2, zmq.__revision__)
else:
assert vs == vs2
if version.VERSION_EXTRA:
assert version.VERSION_EXTRA in vs
assert version.VERSION_EXTRA in vs2
def test_pyzmq_version_info(self):
info = zmq.pyzmq_version_info()
assert isinstance(info, tuple)
for n in info[:3]:
assert isinstance(n, int)
if version.VERSION_EXTRA:
assert len(info) == 4
assert info[-1] == float('inf')
else:
assert len(info) == 3
def test_zmq_version_info(self):
info = zmq.zmq_version_info()
assert isinstance(info, tuple)
for n in info[:3]:
assert isinstance(n, int)
def test_zmq_version(self):
v = zmq.zmq_version()
assert isinstance(v, str)

View File

@@ -0,0 +1,58 @@
import sys
import time
from functools import wraps
from pytest import mark
from zmq.tests import BaseZMQTestCase
from zmq.utils.win32 import allow_interrupt
def count_calls(f):
@wraps(f)
def _(*args, **kwds):
try:
return f(*args, **kwds)
finally:
_.__calls__ += 1
_.__calls__ = 0
return _
@mark.new_console
class TestWindowsConsoleControlHandler(BaseZMQTestCase):
@mark.new_console
@mark.skipif(not sys.platform.startswith('win'), reason='Windows only test')
def test_handler(self):
@count_calls
def interrupt_polling():
print('Caught CTRL-C!')
from ctypes import windll
from ctypes.wintypes import BOOL, DWORD
kernel32 = windll.LoadLibrary('kernel32')
# <http://msdn.microsoft.com/en-us/library/ms683155.aspx>
GenerateConsoleCtrlEvent = kernel32.GenerateConsoleCtrlEvent
GenerateConsoleCtrlEvent.argtypes = (DWORD, DWORD)
GenerateConsoleCtrlEvent.restype = BOOL
# Simulate CTRL-C event while handler is active.
try:
with allow_interrupt(interrupt_polling) as context:
result = GenerateConsoleCtrlEvent(0, 0)
# Sleep so that we give time to the handler to
# capture the Ctrl-C event.
time.sleep(0.5)
except KeyboardInterrupt:
pass
else:
if result == 0:
raise OSError()
else:
self.fail('Expecting `KeyboardInterrupt` exception!')
# Make sure our handler was called.
assert interrupt_polling.__calls__ == 1

View File

@@ -0,0 +1,65 @@
"""Test Z85 encoding
confirm values and roundtrip with test values from the reference implementation.
"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from unittest import TestCase
from zmq.utils import z85
class TestZ85(TestCase):
def test_client_public(self):
client_public = (
b"\xBB\x88\x47\x1D\x65\xE2\x65\x9B"
b"\x30\xC5\x5A\x53\x21\xCE\xBB\x5A"
b"\xAB\x2B\x70\xA3\x98\x64\x5C\x26"
b"\xDC\xA2\xB2\xFC\xB4\x3F\xC5\x18"
)
encoded = z85.encode(client_public)
assert encoded == b"Yne@$w-vo<fVvi]a<NY6T1ed:M$fCG*[IaLV{hID"
decoded = z85.decode(encoded)
assert decoded == client_public
def test_client_secret(self):
client_secret = (
b"\x7B\xB8\x64\xB4\x89\xAF\xA3\x67"
b"\x1F\xBE\x69\x10\x1F\x94\xB3\x89"
b"\x72\xF2\x48\x16\xDF\xB0\x1B\x51"
b"\x65\x6B\x3F\xEC\x8D\xFD\x08\x88"
)
encoded = z85.encode(client_secret)
assert encoded == b"D:)Q[IlAW!ahhC2ac:9*A}h:p?([4%wOTJ%JR%cs"
decoded = z85.decode(encoded)
assert decoded == client_secret
def test_server_public(self):
server_public = (
b"\x54\xFC\xBA\x24\xE9\x32\x49\x96"
b"\x93\x16\xFB\x61\x7C\x87\x2B\xB0"
b"\xC1\xD1\xFF\x14\x80\x04\x27\xC5"
b"\x94\xCB\xFA\xCF\x1B\xC2\xD6\x52"
)
encoded = z85.encode(server_public)
assert encoded == b"rq:rM>}U?@Lns47E1%kR.o@n%FcmmsL/@{H8]yf7"
decoded = z85.decode(encoded)
assert decoded == server_public
def test_server_secret(self):
server_secret = (
b"\x8E\x0B\xDD\x69\x76\x28\xB9\x1D"
b"\x8F\x24\x55\x87\xEE\x95\xC5\xB0"
b"\x4D\x48\x96\x3F\x79\x25\x98\x77"
b"\xB4\x9C\xD9\x06\x3A\xEA\xD3\xB7"
)
encoded = z85.encode(server_secret)
assert encoded == b"JTKVSB%%)wK0E.X)V>+}o?pNmC{O&4W4b!Ni{Lh6"
decoded = z85.decode(encoded)
assert decoded == server_secret

View File

@@ -0,0 +1,159 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import asyncio
import logging
import warnings
import pytest
import zmq
import zmq.asyncio
try:
import tornado
from zmq.eventloop import zmqstream
except ImportError:
tornado = None # type: ignore
pytestmark = pytest.mark.usefixtures("io_loop")
@pytest.fixture
async def push_pull(socket):
push = zmqstream.ZMQStream(socket(zmq.PUSH))
pull = zmqstream.ZMQStream(socket(zmq.PULL))
port = push.bind_to_random_port('tcp://127.0.0.1')
pull.connect('tcp://127.0.0.1:%i' % port)
return (push, pull)
@pytest.fixture
def push(push_pull):
push, pull = push_pull
return push
@pytest.fixture
def pull(push_pull):
push, pull = push_pull
return pull
async def test_callable_check(pull):
"""Ensure callable check works."""
pull.on_send(lambda *args: None)
pull.on_recv(lambda *args: None)
with pytest.raises(AssertionError):
pull.on_recv(1)
with pytest.raises(AssertionError):
pull.on_send(1)
with pytest.raises(AssertionError):
pull.on_recv(zmq)
async def test_on_recv_basic(push, pull):
sent = [b'basic']
push.send_multipart(sent)
f = asyncio.Future()
def callback(msg):
f.set_result(msg)
pull.on_recv(callback)
recvd = await asyncio.wait_for(f, timeout=5)
assert recvd == sent
async def test_on_recv_wake(push, pull):
sent = [b'wake']
f = asyncio.Future()
pull.on_recv(f.set_result)
await asyncio.sleep(0.5)
push.send_multipart(sent)
recvd = await asyncio.wait_for(f, timeout=5)
assert recvd == sent
async def test_on_recv_async(push, pull):
if tornado.version_info < (5,):
pytest.skip()
sent = [b'wake']
f = asyncio.Future()
async def callback(msg):
await asyncio.sleep(0.1)
f.set_result(msg)
pull.on_recv(callback)
await asyncio.sleep(0.5)
push.send_multipart(sent)
recvd = await asyncio.wait_for(f, timeout=5)
assert recvd == sent
async def test_on_recv_async_error(push, pull, caplog):
sent = [b'wake']
f = asyncio.Future()
async def callback(msg):
f.set_result(msg)
1 / 0
pull.on_recv(callback)
await asyncio.sleep(0.1)
with caplog.at_level(logging.ERROR, logger=zmqstream.gen_log.name):
push.send_multipart(sent)
recvd = await asyncio.wait_for(f, timeout=5)
assert recvd == sent
# logging error takes a tick later
await asyncio.sleep(0.5)
messages = [
x.message
for x in caplog.get_records("call")
if x.name == zmqstream.gen_log.name
]
assert "Uncaught exception in ZMQStream callback" in "\n".join(messages)
async def test_shadow_socket(context):
with context.socket(zmq.PUSH, socket_class=zmq.asyncio.Socket) as socket:
with pytest.warns(RuntimeWarning):
stream = zmqstream.ZMQStream(socket)
assert type(stream.socket) is zmq.Socket
assert stream.socket.underlying == socket.underlying
stream.close()
async def test_shadow_socket_close(context, caplog):
with context.socket(zmq.PUSH) as push, context.socket(zmq.PULL) as pull:
push.linger = pull.linger = 0
port = push.bind_to_random_port('tcp://127.0.0.1')
pull.connect(f'tcp://127.0.0.1:{port}')
shadow_pull = zmq.Socket.shadow(pull)
stream = zmqstream.ZMQStream(shadow_pull)
# send some messages
for i in range(10):
push.send_string(str(i))
# make sure at least one message has been delivered
pull.recv()
# register callback
# this should schedule event callback on the next tick
stream.on_recv(print)
# close the shadowed socket
pull.close()
# run the event loop, which should see some events on the shadow socket
# but the socket has been closed!
with warnings.catch_warnings(record=True) as records:
await asyncio.sleep(0.2)
warning_text = "\n".join(str(r.message) for r in records)
assert "after closing socket" in warning_text
assert "closed socket" in caplog.text