asm
This commit is contained in:
416
asm/venv/lib/python3.11/site-packages/zmq/tests/test_auth.py
Normal file
416
asm/venv/lib/python3.11/site-packages/zmq/tests/test_auth.py
Normal file
@@ -0,0 +1,416 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
import zmq.auth
|
||||
from zmq.tests import SkipTest, skip_pypy
|
||||
|
||||
try:
|
||||
import tornado
|
||||
except ImportError:
|
||||
tornado = None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Context(event_loop):
|
||||
return zmq.asyncio.Context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_certs(tmpdir):
|
||||
"""Create CURVE certificates for a test"""
|
||||
|
||||
# Create temporary CURVE key pairs for this test run. We create all keys in a
|
||||
# temp directory and then move them into the appropriate private or public
|
||||
# directory.
|
||||
base_dir = str(tmpdir.join("certs").mkdir())
|
||||
keys_dir = os.path.join(base_dir, "certificates")
|
||||
public_keys_dir = os.path.join(base_dir, 'public_keys')
|
||||
secret_keys_dir = os.path.join(base_dir, 'private_keys')
|
||||
|
||||
os.mkdir(keys_dir)
|
||||
os.mkdir(public_keys_dir)
|
||||
os.mkdir(secret_keys_dir)
|
||||
|
||||
server_public_file, server_secret_file = zmq.auth.create_certificates(
|
||||
keys_dir, "server"
|
||||
)
|
||||
client_public_file, client_secret_file = zmq.auth.create_certificates(
|
||||
keys_dir, "client"
|
||||
)
|
||||
for key_file in os.listdir(keys_dir):
|
||||
if key_file.endswith(".key"):
|
||||
shutil.move(
|
||||
os.path.join(keys_dir, key_file), os.path.join(public_keys_dir, '.')
|
||||
)
|
||||
|
||||
for key_file in os.listdir(keys_dir):
|
||||
if key_file.endswith(".key_secret"):
|
||||
shutil.move(
|
||||
os.path.join(keys_dir, key_file), os.path.join(secret_keys_dir, '.')
|
||||
)
|
||||
|
||||
return (public_keys_dir, secret_keys_dir)
|
||||
|
||||
|
||||
def load_certs(secret_keys_dir):
|
||||
"""Return server and client certificate keys"""
|
||||
server_secret_file = os.path.join(secret_keys_dir, "server.key_secret")
|
||||
client_secret_file = os.path.join(secret_keys_dir, "client.key_secret")
|
||||
|
||||
server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
|
||||
client_public, client_secret = zmq.auth.load_certificate(client_secret_file)
|
||||
|
||||
return server_public, server_secret, client_public, client_secret
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def public_keys_dir(create_certs):
|
||||
public_keys_dir, secret_keys_dir = create_certs
|
||||
return public_keys_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def secret_keys_dir(create_certs):
|
||||
public_keys_dir, secret_keys_dir = create_certs
|
||||
return secret_keys_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def certs(secret_keys_dir):
|
||||
return load_certs(secret_keys_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def _async_setup(request, event_loop):
|
||||
"""pytest doesn't support async setup/teardown"""
|
||||
instance = request.instance
|
||||
await instance.async_setup()
|
||||
yield
|
||||
# make sure our teardown runs before the loop closes
|
||||
instance.async_teardown()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_async_setup")
|
||||
class AuthTest:
|
||||
auth = None
|
||||
|
||||
async def async_setup(self):
|
||||
self.context = zmq.asyncio.Context()
|
||||
if zmq.zmq_version_info() < (4, 0):
|
||||
raise SkipTest("security is new in libzmq 4.0")
|
||||
try:
|
||||
zmq.curve_keypair()
|
||||
except zmq.ZMQError:
|
||||
raise SkipTest("security requires libzmq to have curve support")
|
||||
# enable debug logging while we run tests
|
||||
logging.getLogger('zmq.auth').setLevel(logging.DEBUG)
|
||||
self.auth = self.make_auth()
|
||||
await self.start_auth()
|
||||
|
||||
def async_teardown(self):
|
||||
if self.auth:
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
self.context.term()
|
||||
|
||||
def make_auth(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def start_auth(self):
|
||||
self.auth.start()
|
||||
|
||||
async def can_connect(self, server, client, timeout=1000):
|
||||
"""Check if client can connect to server using tcp transport"""
|
||||
result = False
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
msg = [b"Hello World"]
|
||||
# run poll on server twice
|
||||
# to flush spurious events
|
||||
await server.poll(100, zmq.POLLOUT)
|
||||
|
||||
if await server.poll(timeout, zmq.POLLOUT):
|
||||
try:
|
||||
await server.send_multipart(msg, zmq.NOBLOCK)
|
||||
except zmq.Again:
|
||||
warnings.warn("server set POLLOUT, but cannot send", RuntimeWarning)
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
if await client.poll(timeout):
|
||||
try:
|
||||
rcvd_msg = await client.recv_multipart(zmq.NOBLOCK)
|
||||
except zmq.Again:
|
||||
warnings.warn("client set POLLIN, but cannot recv", RuntimeWarning)
|
||||
else:
|
||||
assert rcvd_msg == msg
|
||||
result = True
|
||||
return result
|
||||
|
||||
@contextmanager
|
||||
def push_pull(self):
|
||||
with self.context.socket(zmq.PUSH) as server, self.context.socket(
|
||||
zmq.PULL
|
||||
) as client:
|
||||
server.linger = 0
|
||||
server.sndtimeo = 2000
|
||||
client.linger = 0
|
||||
client.rcvtimeo = 2000
|
||||
yield server, client
|
||||
|
||||
@contextmanager
|
||||
def curve_push_pull(self, certs, client_key="ok"):
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
with self.push_pull() as (server, client):
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
if client_key is not None:
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
if client_key == "ok":
|
||||
client.curve_serverkey = server_public
|
||||
else:
|
||||
private, public = zmq.curve_keypair()
|
||||
client.curve_serverkey = public
|
||||
yield (server, client)
|
||||
|
||||
async def test_null(self):
|
||||
"""threaded auth - NULL"""
|
||||
# A default NULL connection should always succeed, and not
|
||||
# go through our authentication infrastructure at all.
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
|
||||
# use a new context, so ZAP isn't inherited
|
||||
self.context.term()
|
||||
self.context = zmq.asyncio.Context()
|
||||
|
||||
with self.push_pull() as (server, client):
|
||||
assert await self.can_connect(server, client)
|
||||
|
||||
# By setting a domain we switch on authentication for NULL sockets,
|
||||
# though no policies are configured yet. The client connection
|
||||
# should still be allowed.
|
||||
with self.push_pull() as (server, client):
|
||||
server.zap_domain = b'global'
|
||||
assert await self.can_connect(server, client)
|
||||
|
||||
async def test_deny(self):
|
||||
# deny 127.0.0.1, connection should fail
|
||||
self.auth.deny('127.0.0.1')
|
||||
with pytest.raises(ValueError):
|
||||
self.auth.allow("127.0.0.2")
|
||||
with self.push_pull() as (server, client):
|
||||
# By setting a domain we switch on authentication for NULL sockets,
|
||||
# though no policies are configured yet.
|
||||
server.zap_domain = b'global'
|
||||
assert not await self.can_connect(server, client, timeout=100)
|
||||
|
||||
async def test_allow(self):
|
||||
# allow 127.0.0.1, connection should pass
|
||||
self.auth.allow('127.0.0.1')
|
||||
with pytest.raises(ValueError):
|
||||
self.auth.deny("127.0.0.2")
|
||||
with self.push_pull() as (server, client):
|
||||
# By setting a domain we switch on authentication for NULL sockets,
|
||||
# though no policies are configured yet.
|
||||
server.zap_domain = b'global'
|
||||
assert await self.can_connect(server, client)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"enabled, password, success",
|
||||
[
|
||||
(True, "correct", True),
|
||||
(False, "correct", False),
|
||||
(True, "incorrect", False),
|
||||
],
|
||||
)
|
||||
async def test_plain(self, enabled, password, success):
|
||||
"""threaded auth - PLAIN"""
|
||||
|
||||
# Try PLAIN authentication - without configuring server, connection should fail
|
||||
with self.push_pull() as (server, client):
|
||||
server.plain_server = True
|
||||
if password:
|
||||
client.plain_username = b'admin'
|
||||
client.plain_password = password.encode("ascii")
|
||||
if enabled:
|
||||
self.auth.configure_plain(domain='*', passwords={'admin': 'correct'})
|
||||
if success:
|
||||
assert await self.can_connect(server, client)
|
||||
else:
|
||||
assert not await self.can_connect(server, client, timeout=100)
|
||||
|
||||
# Remove authenticator and check that a normal connection works
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
with self.push_pull() as (server, client):
|
||||
assert await self.can_connect(server, client)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"client_key, location, success",
|
||||
[
|
||||
('ok', zmq.auth.CURVE_ALLOW_ANY, True),
|
||||
('ok', "public_keys", True),
|
||||
('bad', "public_keys", False),
|
||||
(None, "public_keys", False),
|
||||
],
|
||||
)
|
||||
async def test_curve(self, certs, public_keys_dir, client_key, location, success):
|
||||
"""threaded auth - CURVE"""
|
||||
self.auth.allow('127.0.0.1')
|
||||
|
||||
# Try CURVE authentication - without configuring server, connection should fail
|
||||
with self.curve_push_pull(certs, client_key=client_key) as (server, client):
|
||||
if location:
|
||||
if location == 'public_keys':
|
||||
location = public_keys_dir
|
||||
self.auth.configure_curve(domain='*', location=location)
|
||||
if success:
|
||||
assert await self.can_connect(server, client, timeout=100)
|
||||
else:
|
||||
assert not await self.can_connect(server, client, timeout=100)
|
||||
|
||||
# Remove authenticator and check that a normal connection works
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
|
||||
# Try connecting using NULL and no authentication enabled, connection should pass
|
||||
with self.push_pull() as (server, client):
|
||||
assert await self.can_connect(server, client)
|
||||
|
||||
@pytest.mark.parametrize("key", ["ok", "wrong"])
|
||||
@pytest.mark.parametrize("async_", [True, False])
|
||||
async def test_curve_callback(self, certs, key, async_):
|
||||
"""threaded auth - CURVE with callback authentication"""
|
||||
self.auth.allow('127.0.0.1')
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
|
||||
class CredentialsProvider:
|
||||
def __init__(self):
|
||||
if key == "ok":
|
||||
self.client = client_public
|
||||
else:
|
||||
self.client = server_public
|
||||
|
||||
def callback(self, domain, key):
|
||||
if key == self.client:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
async def async_callback(self, domain, key):
|
||||
await asyncio.sleep(0.1)
|
||||
if key == self.client:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
if async_:
|
||||
CredentialsProvider.callback = CredentialsProvider.async_callback
|
||||
provider = CredentialsProvider()
|
||||
self.auth.configure_curve_callback(credentials_provider=provider)
|
||||
with self.curve_push_pull(certs) as (server, client):
|
||||
if key == "ok":
|
||||
assert await self.can_connect(server, client)
|
||||
else:
|
||||
assert not await self.can_connect(server, client, timeout=200)
|
||||
|
||||
@skip_pypy
|
||||
async def test_curve_user_id(self, certs, public_keys_dir):
|
||||
"""threaded auth - CURVE"""
|
||||
self.auth.allow('127.0.0.1')
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
self.auth.configure_curve(domain='*', location=public_keys_dir)
|
||||
# reverse server-client relationship, so server is PULL
|
||||
with self.push_pull() as (client, server):
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
|
||||
assert await self.can_connect(client, server)
|
||||
|
||||
# test default user-id map
|
||||
await client.send(b'test')
|
||||
msg = await server.recv(copy=False)
|
||||
assert msg.bytes == b'test'
|
||||
try:
|
||||
user_id = msg.get('User-Id')
|
||||
except zmq.ZMQVersionError:
|
||||
pass
|
||||
else:
|
||||
assert user_id == client_public.decode("utf8")
|
||||
|
||||
# test custom user-id map
|
||||
self.auth.curve_user_id = lambda client_key: 'custom'
|
||||
|
||||
with self.context.socket(zmq.PUSH) as client2:
|
||||
client2.curve_publickey = client_public
|
||||
client2.curve_secretkey = client_secret
|
||||
client2.curve_serverkey = server_public
|
||||
assert await self.can_connect(client2, server)
|
||||
|
||||
await client2.send(b'test2')
|
||||
msg = await server.recv(copy=False)
|
||||
assert msg.bytes == b'test2'
|
||||
try:
|
||||
user_id = msg.get('User-Id')
|
||||
except zmq.ZMQVersionError:
|
||||
pass
|
||||
else:
|
||||
assert user_id == 'custom'
|
||||
|
||||
|
||||
class TestThreadAuthentication(AuthTest):
|
||||
"""Test authentication running in a thread"""
|
||||
|
||||
def make_auth(self):
|
||||
from zmq.auth.thread import ThreadAuthenticator
|
||||
|
||||
return ThreadAuthenticator(self.context)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform == 'win32' and sys.version_info < (3, 7),
|
||||
reason="flaky event loop cleanup on windows+py36",
|
||||
)
|
||||
class TestAsyncioAuthentication(AuthTest):
|
||||
"""Test authentication running in a thread"""
|
||||
|
||||
def make_auth(self):
|
||||
from zmq.auth.asyncio import AsyncioAuthenticator
|
||||
|
||||
return AsyncioAuthenticator(self.context)
|
||||
|
||||
|
||||
async def test_ioloop_authenticator(context, event_loop, io_loop):
|
||||
from tornado.ioloop import IOLoop
|
||||
|
||||
with warnings.catch_warnings():
|
||||
from zmq.auth.ioloop import IOLoopAuthenticator
|
||||
|
||||
auth = IOLoopAuthenticator(context)
|
||||
assert auth.context is context
|
||||
|
||||
loop = IOLoop(make_current=False)
|
||||
with pytest.warns(DeprecationWarning):
|
||||
auth = IOLoopAuthenticator(io_loop=loop)
|
||||
Reference in New Issue
Block a user