From 80df3a8346f6a4c12c6ac59d6f8793668f25d7d2 Mon Sep 17 00:00:00 2001 From: Giuseppe Carboni Date: Wed, 7 Jan 2026 13:34:14 +0100 Subject: [PATCH 1/6] First implementation of sending commands feature --- discos_client/__init__.py | 14 +- discos_client/client.py | 264 +++++++++++++++++++-------- discos_client/servers/srt/server.key | 9 + discos_client/utils.py | 26 ++- pyproject.toml | 4 +- scripts/discos-keygen | 69 +++++++ tests/test_client.py | 205 +++++++++++++++++---- tests/test_keys/dummy.key | 7 + tests/test_keys/dummy.key_secret | 8 + 9 files changed, 493 insertions(+), 113 deletions(-) create mode 100644 discos_client/servers/srt/server.key create mode 100644 scripts/discos-keygen create mode 100644 tests/test_keys/dummy.key create mode 100644 tests/test_keys/dummy.key_secret diff --git a/discos_client/__init__.py b/discos_client/__init__.py index 0bd20d4..97407e0 100644 --- a/discos_client/__init__.py +++ b/discos_client/__init__.py @@ -1,27 +1,31 @@ from functools import partial -from .client import DISCOSClient, DEFAULT_PORT +from .client import DISCOSClient, DEFAULT_SUB_PORT, DEFAULT_REQ_PORT SRTClient = partial( DISCOSClient, address="192.168.200.203", - port=DEFAULT_PORT, + sub_port=DEFAULT_SUB_PORT, + req_port=DEFAULT_REQ_PORT, telescope="SRT" ) MedicinaClient = partial( DISCOSClient, address="192.168.1.100", - port=DEFAULT_PORT, + sub_port=DEFAULT_SUB_PORT, + req_port=DEFAULT_REQ_PORT, telescope="Medicina" ) NotoClient = partial( DISCOSClient, address="192.167.187.17", - port=DEFAULT_PORT, + sub_port=DEFAULT_SUB_PORT, + req_port=DEFAULT_REQ_PORT, telescope="Noto" ) del partial -del DEFAULT_PORT +del DEFAULT_SUB_PORT +del DEFAULT_REQ_PORT __all__ = [ "DISCOSClient", diff --git a/discos_client/client.py b/discos_client/client.py index bdf3662..02639c8 100644 --- a/discos_client/client.py +++ b/discos_client/client.py @@ -3,18 +3,22 @@ import weakref from threading import Thread, Lock, Event from collections import defaultdict +from typing import Any import zmq +from zmq.utils.monitor import recv_monitor_message from .namespace import DISCOSNamespace -from .utils import rand_id +from .utils import rand_id, get_auth_keys from .initializer import NSInitializer __all__ = [ - "DEFAULT_PORT", + "DEFAULT_SUB_PORT", + "DEFAULT_REQ_PORT", "DISCOSClient" ] -DEFAULT_PORT = 16000 +DEFAULT_SUB_PORT = 16000 +DEFAULT_REQ_PORT = 16010 class DISCOSClient: @@ -27,7 +31,8 @@ def __init__( self, *topics: str, address: str, - port: int, + sub_port: int, + req_port: int | None = None, telescope: str | None = None ) -> None: """ @@ -35,49 +40,50 @@ def __init__( :param topics: topic names to subscribe to. :param address: IP address to subscribe to. - :param port: TCP port to subscribe to. + :param sub_port: TCP port where the subscriber socket will connect. + :param req_port: TCP port where the requester socket will connect. :param telescope: name of the telescope the client is connecting to. :raises ValueError: If one or more given topics are not known. """ + if telescope not in ("Medicina", "Noto", "SRT", None): + raise ValueError(f"Unknown telescope: '{telescope}'") initializer = NSInitializer(telescope) - valid_topics = initializer.get_topics() - invalid = [t for t in topics if t not in valid_topics] - if invalid: - if len(invalid) > 1: - invalid = \ - f"""s '{"', '".join(invalid[:-1])}'""" \ - f" and '{invalid[-1]}' are" - else: - invalid = f""" '{invalid[0]}' is""" - raise ValueError( - f"Topic{invalid} not known, choose among " - f"""'{"', '".join(valid_topics[:-1])} and """ - f"'{valid_topics[-1]}'" - ) - if not topics: - topics = initializer.get_topics() - self._topics = topics + self._topics = self.__validate_topics__(initializer, topics) self._client_id = rand_id() - self._event = Event() + self._stop = Event() self._context = zmq.Context() - self._socket = self._context.socket(zmq.SUB) - self._socket.setsockopt(zmq.LINGER, 0) - self._socket.setsockopt(zmq.RCVTIMEO, 10) - self._socket.connect(f"tcp://{address}:{port}") + + events = {} + events["stop"] = self._stop + + self._sub = self._context.socket(zmq.SUB) + self._sub.setsockopt(zmq.LINGER, 0) + self._sub.setsockopt(zmq.RCVTIMEO, 10) + self._sub.setsockopt(zmq.RECONNECT_IVL, 1000) + self._sub.setsockopt(zmq.CONNECT_TIMEOUT, 500) + self._sub.connect(f"tcp://{address}:{sub_port}") + + sockets = {} + sockets["sub"] = self._sub + + if req_port and telescope: + self.__init_req_socket__( + address, req_port, telescope, events, sockets + ) self._locks = defaultdict(Lock) for topic in self._topics: self.__dict__[topic] = initializer.initialize(topic) - self._receiver = Thread( - target=self.__receive__, + self._updater = Thread( + target=self.__update__, args=( - self._socket, - self._locks, self._client_id, + sockets, + self._locks, self.__dict__, - self._event + events, ), daemon=True ) @@ -85,67 +91,181 @@ def __init__( self._finalizer = weakref.finalize( self, self.__cleanup__, - self._event, - self._socket, - self._context, - self._receiver + self._stop, + self._updater, + sockets, + self._context ) - self._receiver.start() + self._updater.start() for topic in self._topics: - self._socket.subscribe(f"{self._client_id}{topic}") + self._sub.subscribe(f"{self._client_id}{topic}") + + def __init_req_socket__( + self, + address: str, + req_port: int, + telescope: str, + events: dict[str, Event], + sockets: dict[str, zmq.Socket] + ) -> None: + try: + client_public, client_secret, server_public = get_auth_keys( + telescope + ) + except OSError: + # A curve key is missing, this + # telemetry and will not be able to send commands + return + self._req = self._context.socket(zmq.REQ) + self._req.setsockopt(zmq.LINGER, 0) + self._req.setsockopt(zmq.IMMEDIATE, 1) + self._req.setsockopt(zmq.SNDTIMEO, 0) + self._req.setsockopt(zmq.RECONNECT_IVL, 1000) + self._req.setsockopt(zmq.CONNECT_TIMEOUT, 500) + self._req.setsockopt(zmq.HEARTBEAT_IVL, 1000) + self._req.setsockopt(zmq.HEARTBEAT_TIMEOUT, 1000) + self._req.curve_publickey = client_public + self._req.curve_secretkey = client_secret + self._req.curve_serverkey = server_public + self._mon = self._req.get_monitor_socket() + self._online = Event() + events["online"] = self._online + self._req.connect(f"tcp://{address}:{req_port}") + sockets["req"] = self._req + sockets["mon"] = self._mon + self.command = self.__command__ + + @staticmethod + def __validate_topics__( + initializer: NSInitializer, + topics: tuple[str] + ) -> list[str]: + valid_topics = initializer.get_topics() + invalid = [t for t in topics if t not in valid_topics] + if not invalid: + return topics or valid_topics + + if len(invalid) > 1: + invalid = f"""s '{"', '".join(invalid[:-1])}'""" \ + f" and '{invalid[-1]}' are" + else: + invalid = f""" '{invalid[0]}' is""" + + raise ValueError( + f"Topic{invalid} not known, choose among " + f"""'{"', '".join(valid_topics[:-1])} and """ + f"'{valid_topics[-1]}'" + ) @staticmethod def __cleanup__( - event: Event, - socket: zmq.Socket, - context: zmq.Context, - receiver: Thread + stop: Event, + updater: Thread, + sockets: dict[str, zmq.Socket], + context: zmq.Context ) -> None: """ - Joins the receiver thread and closes the ZMQ socket and context. + Joins the updater thread and closes the ZMQ sockets and context. - :param event: the Event object that will stop the receiver thread. - :param socket: the ZMQ socket object. + :param stop: the Event object that will stop the updater thread. + :param sub: the ZMQ SUB socket object. :param context: the ZMQ context object. - :param receiver: the receiver thread object. + :param updater: the updater thread object. """ - event.set() - receiver.join() - socket.close() + stop.set() + try: + updater.join() + except RuntimeError: # pragma: no cover + pass + for _, socket in sockets.items(): + socket.disable_monitor() + socket.close() context.term() @staticmethod - def __receive__( - socket: zmq.Socket, - locks: dict[str, Lock], + def __update__( client_id: str, - d: dict[str, DISCOSNamespace], - event: Event + sockets: dict[str, zmq.Socket], + locks: dict[str, Lock], + namespaces: dict[str, DISCOSNamespace], + events: dict[str, Event] ) -> None: """ - Loops continuously waiting for new ZMQ messages. + Loops continuously waiting for new ZMQ messages and events. - :param socket: The ZMQ socket object. - :param locks: The locks dictionary, used for thread synchronization. :param client_id: The random string identifying the client. - :param d: The client __dict__ object. - :param event: The Event object that will break the receiver loop. + :param sockets: The dictionary containing the ZMQ sockets. + :param locks: The locks dictionary, used for thread synchronization. + :param namespaces: The client __dict__ object, containing the + DISCOSNamespaces. + :param events: The dictionary containing the Event objects for + synchronization. """ - while not event.is_set(): + sub = sockets.get("sub") + mon = sockets.get("mon") + stop = events.get("stop") + online = events.get("online") + + poller = zmq.Poller() + poller.register(sub, zmq.POLLIN) + if mon is not None: + poller.register(mon, zmq.POLLIN) + while not stop.is_set(): + zmq_events = {} try: - t, p = socket.recv_multipart() # noqa - t = t.decode("ascii") - if t.startswith(client_id): - socket.unsubscribe(t) - t = t[len(client_id):] - socket.subscribe(t) - p = json.loads(p) - with locks[t]: - d[t] <<= p - except zmq.Again: - # No data received, cycle again - pass + zmq_events = dict(poller.poll(timeout=200)) + except zmq.ZMQError: # pragma: no cover + break + + if sub in zmq_events: + try: + t, p = sub.recv_multipart(flags=zmq.DONTWAIT) # noqa + t = t.decode("ascii") + if t.startswith(client_id): + sub.unsubscribe(t) + t = t[len(client_id):] + sub.subscribe(t) + p = json.loads(p) + with locks[t]: + namespaces[t] <<= p + except zmq.Again: # pragma: no cover + # We should never get here since there will always be + # some data to recover from the socket + pass + + if mon is not None and mon in zmq_events: + while True: + try: + event = recv_monitor_message(mon, flags=zmq.DONTWAIT) + except zmq.Again: + break + + event = event["event"] + if event == zmq.EVENT_CONNECTED: + online.set() + elif event in \ + (zmq.EVENT_DISCONNECTED, zmq.EVENT_CLOSED): + online.clear() + + def __command__(self, cmd: str, *args) -> dict[str, Any]: + if self._online.is_set(): + message = {"command": cmd} + if args: + message["args"] = args + payload = json.dumps(message, separators=(",", ":")) + self._req.send_string(payload) + answer = json.loads(self._req.recv_string()) + else: + answer = { + "executed": False, + "error": { + "type": 2101, # ClientErrors + "code": 14, # DISCOSUnreachableError + "description": "DISCOS is unreachable" + } + } + return answer def __repr__(self) -> str: """ diff --git a/discos_client/servers/srt/server.key b/discos_client/servers/srt/server.key new file mode 100644 index 0000000..a68a264 --- /dev/null +++ b/discos_client/servers/srt/server.key @@ -0,0 +1,9 @@ +# **** Generated on 2025-12-18 10:39:47.317361 by pyzmq **** +# ZeroMQ CURVE Public Certificate +# Exchange securely, or use a secure mechanism to verify the contents +# of this file after exchange. Store public certificates in your home +# directory, in the .curve subdirectory. + +metadata +curve + public-key = "c2u6x}C-+P{P0.K(Fp+Qjj0*x.)IeX}Z}1v" diff --git a/discos_client/utils.py b/discos_client/utils.py index d070c1e..69168cd 100644 --- a/discos_client/utils.py +++ b/discos_client/utils.py @@ -3,6 +3,10 @@ import secrets import string from typing import Any, Callable +from importlib.resources import files +from pathlib import Path +from zmq.auth import load_certificate +from platformdirs import user_config_dir __all__ = [ @@ -10,7 +14,8 @@ "rand_id", "delegated_operations", "delegated_comparisons", - "public_dict" + "public_dict", + "get_auth_keys" ] META_KEYS = ("type", "title", "description", "format", "unit", "enum") @@ -145,3 +150,22 @@ def __unwrap(value: Any, is_fn, get_value_fn) -> Any: while is_fn(value): value = get_value_fn(value) return list(value) if isinstance(value, (list, tuple)) else value + + +def get_auth_keys(telescope: str) -> tuple[bytes]: + """Retrieves the CURVE authentication keys, both for the client and + the desired server. + + :param telescope: The telescope for which the server public key will be + retrieved. + :return: The client's public and secret keys, followed by the server's + public key, as a tuple. + """ + config_base = Path(user_config_dir("discos")) + curve_directory = config_base / "rpc" / "client" + client_pair = curve_directory / "identity.key_secret" + server_pair = files("discos_client") / "servers" \ + / telescope.lower() / "server.key" + client_public, client_secret = load_certificate(client_pair) + server_public, _ = load_certificate(server_pair) + return client_public, client_secret, server_public diff --git a/pyproject.toml b/pyproject.toml index 86ad494..0755857 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,13 +16,15 @@ classifiers = [ ] dependencies = [ "pyzmq", + "platformdirs" ] [tool.setuptools] packages = ["discos_client"] +script-files = ["scripts/discos-keygen"] [tool.setuptools.package-data] -"discos_client" = ["schemas/**"] +"discos_client" = ["schemas/**", "servers/**"] [project.optional-dependencies] test = ["coverage", "prospector", "jsonschema", "referencing"] diff --git a/scripts/discos-keygen b/scripts/discos-keygen new file mode 100644 index 0000000..db9d4ea --- /dev/null +++ b/scripts/discos-keygen @@ -0,0 +1,69 @@ +#!/usr/bin/env python +import os +import sys +from pathlib import Path +from argparse import ArgumentParser +from platformdirs import user_config_dir +from zmq.auth import create_certificates + +base_config = Path(user_config_dir("discos")) +target_dir = base_config / "rpc" / "client" +key_filename = "identity" +full_path_public = target_dir / f"{key_filename}.key" +full_path_secret = target_dir / f"{key_filename}.key_secret" + +def create_discos_keys(overwrite): + + if full_path_secret.exists() and not overwrite: + print("Kept previously created key pair. Use --overwrite to replace it.\n") + return + + try: + target_dir.mkdir(parents=True, exist_ok=True) + except OSError as e: + print(f"Error creating the configuration directory: {e}") + sys.exit(1) + + create_certificates(str(target_dir), key_filename) + + if os.name == 'posix': + full_path_secret.chmod(0o600) + (target_dir / f"{key_filename}.key").chmod(0o644) + print(f"Key pair created in: '{target_dir}'.") + +def print_discos_keys(): + if not full_path_public.exists(): + print("No key was generated yet.") + return + + with open(full_path_public, "r") as f: + print(f.read()) + print(f"\nPath of the public key file: {full_path_public}") + print(f"Remember to never share the '{key_filename}.key_secret' file with anyone.") + print( + "In order to be authorized to send command to any of the telescopes, " \ + f"remember to send a copy of the '{key_filename}.key' file to the " \ + "DISCOS team, asking for authorization. Your request will be taken " \ + "into consideration and you will hear back from the team." + ) + +if __name__ == "__main__": + parser = ArgumentParser( + "DISCOS CURVE key pairs generator." + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="Overwrite existing keys. Dafaults to False." + ) + parser.add_argument( + "--show-only", + action="store_true", + help="Only prints the public key and its path without generating a new pair." + ) + args = parser.parse_args() + + if not args.show_only: + create_discos_keys(args.overwrite) + + print_discos_keys() diff --git a/tests/test_client.py b/tests/test_client.py index fd119ac..d49112e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,26 +2,57 @@ import unittest import time import re +from unittest.mock import patch from pathlib import Path from threading import Thread, Event import zmq -from discos_client.client import DISCOSClient, DEFAULT_PORT +from zmq.auth import load_certificate +from zmq.auth.thread import ThreadAuthenticator +from discos_client.client import DISCOSClient, \ + DEFAULT_SUB_PORT, DEFAULT_REQ_PORT + +keys_path = Path(__file__).resolve().parent / "test_keys" +dummy_public, dummy_secret = load_certificate( + keys_path / "dummy.key_secret" +) class TestPublisher: - def __init__(self, telescope=None): + def __init__(self, telescope=None, router=False): self.context = zmq.Context() - self.socket = self.context.socket(zmq.XPUB) - self.socket.setsockopt(zmq.LINGER, 0) - self.socket.setsockopt(zmq.SNDHWM, 10) + self.pub = self.context.socket(zmq.XPUB) + self.pub.setsockopt(zmq.LINGER, 0) + self.pub.setsockopt(zmq.SNDHWM, 10) + self.router = self.context.socket(zmq.ROUTER) + self.router.curve_publickey = dummy_public + self.router.curve_secretkey = dummy_secret + self.router.curve_server = True + self.auth = ThreadAuthenticator(self.context) + self.auth.configure_curve(domain="*", location=keys_path) # This loop is necessary to wait for the client to close between tests while True: try: - self.socket.bind(f"tcp://127.0.0.1:{DEFAULT_PORT}") + self.pub.bind(f"tcp://127.0.0.1:{DEFAULT_SUB_PORT}") break except zmq.ZMQError: - pass + continue + if router: + while True: + try: + self.auth.start() + break + except zmq.ZMQError: + continue + while True: + try: + self.router.bind(f"tcp://127.0.0.1:{DEFAULT_REQ_PORT}") + break + except zmq.ZMQError: + continue + self.poller = zmq.Poller() + self.poller.register(self.pub, zmq.POLLIN) + self.poller.register(self.router, zmq.POLLIN) messages_dir = Path(__file__).resolve().parent / "messages" message_files = list(messages_dir.glob("common/*.json")) if telescope: @@ -46,34 +77,32 @@ def recurse(obj): recurse(item) for payload in self.messages.values(): recurse(payload) - self.t = Thread(target=self.publish) + self.t = Thread(target=self.loop) self.event = Event() self.t.start() def __enter__(self): return self - def _handle_subscription(self): - while True: - try: - event = self.socket.recv(flags=zmq.DONTWAIT) - except zmq.Again: - break - if not event: - continue + def _handle_events(self): + zmq_events = {} + try: + zmq_events = dict(self.poller.poll(timeout=200)) + except zmq.ZMQError: + return + + if self.pub in zmq_events: + event = self.pub.recv(flags=zmq.DONTWAIT) op = event[0] topic = event[1:].decode(errors="ignore") - if op != 1: - continue - - if re.match(r"^[0-9A-Za-z]{4}_.+$", topic): + if op == 1 and re.match(r"^[0-9A-Za-z]{4}_.+$", topic): t = topic.split("_", 1)[1] if t in self.messages: message = json.dumps( self.messages[t], separators=(",", ":") ).encode("utf-8") - self.socket.send_multipart([ + self.pub.send_multipart([ topic.encode("ascii"), message ]) @@ -88,10 +117,24 @@ def _handle_subscription(self): subparts, separators=(",", ":") ).encode("utf-8") - self.socket.send_multipart([ + self.pub.send_multipart([ topic.encode("ascii"), message ]) + if self.router in zmq_events: + req = self.router.recv_multipart(copy=False) + routing_id, sep, payload = (req + [None])[:3] # noqa + payload = json.loads(payload.bytes) + answer = { + "executed": True, + "command": payload["command"] + } + self.router.send_multipart([ + routing_id, + b"", + json.dumps(answer, separators=(",", ":")).encode() + ]) + def _send_periodic_messages(self): for timestamp in self.timestamps: timestamp["unix_time"] = time.time() @@ -103,21 +146,23 @@ def _send_periodic_messages(self): payload, separators=(",", ":") ).encode("utf-8") - self.socket.send_multipart([ + self.pub.send_multipart([ topic.encode("ascii"), payload ]) - def publish(self): + def loop(self): while not self.event.is_set(): - self._handle_subscription() + self._handle_events() self._send_periodic_messages() time.sleep(0.1) def __exit__(self, exc_type, exc_value, traceback): self.event.set() self.t.join() - self.socket.close() + self.pub.close() + self.router.close() + self.auth.stop() self.context.term() @@ -126,16 +171,28 @@ class TestDISCOSClient(unittest.TestCase): def test_no_topics(self): DISCOSClient( address="127.0.0.1", - port=DEFAULT_PORT, + sub_port=DEFAULT_SUB_PORT, telescope="SRT" ) + def test_unknown_telescope(self): + with self.assertRaises(ValueError) as ex: + DISCOSClient( + address="127.0.0.1", + sub_port=DEFAULT_SUB_PORT, + telescope="Unknown" + ) + self.assertEqual( + "Unknown telescope: 'Unknown'", + ex.exception.args[0] + ) + def test_unknown_topic(self): with self.assertRaises(ValueError) as ex: DISCOSClient( "foo", address="127.0.0.1", - port=DEFAULT_PORT + sub_port=DEFAULT_SUB_PORT ) self.assertTrue( "Topic 'foo' is not known" in ex.exception.args[0] @@ -144,28 +201,37 @@ def test_unknown_topic(self): DISCOSClient( "foo", "bar", address="127.0.0.1", - port=DEFAULT_PORT, + sub_port=DEFAULT_SUB_PORT, ) self.assertTrue( "Topics 'foo' and 'bar' are not known" in ex.exception.args[0] ) def test_repr(self): - client = DISCOSClient(address="127.0.0.1", port=DEFAULT_PORT) + client = DISCOSClient( + address="127.0.0.1", + sub_port=DEFAULT_SUB_PORT + ) self.assertTrue( repr(client).startswith("") ) def test_str(self): - client = DISCOSClient(address="127.0.0.1", port=DEFAULT_PORT) + client = DISCOSClient( + address="127.0.0.1", + sub_port=DEFAULT_SUB_PORT + ) self.assertTrue( str(client).startswith("{") and str(client).endswith("}") ) def test_format(self): - client = DISCOSClient(address="127.0.0.1", port=DEFAULT_PORT) + client = DISCOSClient( + address="127.0.0.1", + sub_port=DEFAULT_SUB_PORT + ) self.assertTrue( f"{client:}".startswith("{") and f"{client:}".endswith("}") @@ -216,7 +282,10 @@ def test_format(self): def test_bind(self): with TestPublisher("SRT"): - client = DISCOSClient(address="127.0.0.1", port=DEFAULT_PORT) + client = DISCOSClient( + address="127.0.0.1", + sub_port=DEFAULT_SUB_PORT + ) s = set() called = set() s.add(id(client.antenna.timestamp.unix_time)) @@ -240,7 +309,10 @@ def callback(value): def test_wait(self): with TestPublisher(): - client = DISCOSClient(address="127.0.0.1", port=DEFAULT_PORT) + client = DISCOSClient( + address="127.0.0.1", + sub_port=DEFAULT_SUB_PORT + ) unix_time = client.antenna.timestamp.unix_time.copy() antenna = client.antenna.copy() self.assertNotEqual( @@ -252,6 +324,71 @@ def test_wait(self): client.antenna.wait(timeout=5) ) + @patch("discos_client.utils.load_certificate") + def test_command(self, mock_load_cert): + mock_load_cert.return_value = (dummy_public, dummy_secret) + with TestPublisher(router=True): + client = DISCOSClient( + address="127.0.0.1", + sub_port=DEFAULT_SUB_PORT, + req_port=DEFAULT_REQ_PORT, + telescope="SRT" + ) + self.assertTrue(hasattr(client, "command")) + self.assertTrue(hasattr(client, "_online")) + while not client._online.is_set(): + time.sleep(0.01) + answer = client.command("dummy") + self.assertTrue(answer["executed"]) + + @patch("discos_client.utils.load_certificate") + def test_command_with_args(self, mock_load_cert): + mock_load_cert.return_value = (dummy_public, dummy_secret) + with TestPublisher(router=True): + client = DISCOSClient( + address="127.0.0.1", + sub_port=DEFAULT_SUB_PORT, + req_port=DEFAULT_REQ_PORT, + telescope="SRT" + ) + self.assertTrue(hasattr(client, "command")) + self.assertTrue(hasattr(client, "_online")) + while not client._online.is_set(): + time.sleep(0.01) + answer = client.command("dummy", 1, 2, 3) + self.assertTrue(answer["executed"]) + + @patch("discos_client.utils.load_certificate") + def test_command_unreachable(self, mock_load_cert): + mock_load_cert.return_value = (dummy_public, dummy_secret) + with TestPublisher(): + client = DISCOSClient( + address="127.0.0.1", + sub_port=DEFAULT_SUB_PORT, + req_port=DEFAULT_REQ_PORT, + telescope="SRT" + ) + self.assertTrue(hasattr(client, "command")) + self.assertFalse(client.command("dummy")["executed"]) + + def test_command_not_present(self): + client = DISCOSClient( + address="127.0.0.1", + sub_port=DEFAULT_SUB_PORT + ) + self.assertFalse(hasattr(client, "command")) + + @patch("discos_client.utils.load_certificate") + def test_command_keys_not_present(self, mock_load_cert): + mock_load_cert.side_effect = OSError + client = DISCOSClient( + address="127.0.0.1", + sub_port=DEFAULT_SUB_PORT, + req_port=DEFAULT_REQ_PORT, + telescope="SRT" + ) + self.assertFalse(hasattr(client, "command")) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_keys/dummy.key b/tests/test_keys/dummy.key new file mode 100644 index 0000000..a2598fa --- /dev/null +++ b/tests/test_keys/dummy.key @@ -0,0 +1,7 @@ +# **** Generated on 2026-01-05 22:54:36.363259 by pyzmq **** +# ZeroMQ CURVE **Secret** Certificate +# DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions. + +metadata +curve + public-key = "f!G:wZysatP7c3Pbicu42Ng]ttQ}HWeqeQ(TJ(rq" diff --git a/tests/test_keys/dummy.key_secret b/tests/test_keys/dummy.key_secret new file mode 100644 index 0000000..c0e0d4e --- /dev/null +++ b/tests/test_keys/dummy.key_secret @@ -0,0 +1,8 @@ +# **** Generated on 2026-01-05 22:54:36.363259 by pyzmq **** +# ZeroMQ CURVE **Secret** Certificate +# DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions. + +metadata +curve + public-key = "f!G:wZysatP7c3Pbicu42Ng]ttQ}HWeqeQ(TJ(rq" + secret-key = "?bBw]p8$%gJPUKi3Kf#CSkuq63u6YrUY Date: Wed, 7 Jan 2026 14:01:32 +0100 Subject: [PATCH 2/6] Issue #19, fixed documentation issue --- docs/conf.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 0d415b3..0d38709 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -64,31 +64,6 @@ html_static_path = ['_static'] -class SkipMembersClassDocumenter(ClassDocumenter): - objtype = "class" - - option_spec = ClassDocumenter.option_spec.copy() - option_spec["skip-members"] = lambda arg: [name.strip() for name in arg.split(",")] if arg else [] - - def parse_name(self): - return super().parse_name() - - -def skip_special_members(app, what, name, obj, skip, options): - global_skips = app.config.autoclass_skip_members_default or set() - directive_skips = set(options.get("skip-members", [])) - - if name in global_skips or name in directive_skips: - return True - return None - - -def setup(app): - app.add_config_value("autoclass_skip_members_default", set(), "env") - app.add_autodocumenter(SkipMembersClassDocumenter, override=True) - app.connect("autodoc-skip-member", skip_special_members) - - sjs_wide_format = importlib.import_module("sphinx-jsonschema.wide_format") sjs_wide_format.WideFormat._simpletype = _simpletype sjs_wide_format.WideFormat._complexstructures = _complexstructures From 983837940053f19c580adcdf8705c11ad73bf21c Mon Sep 17 00:00:00 2001 From: Giuseppe Carboni Date: Wed, 7 Jan 2026 14:06:38 +0100 Subject: [PATCH 3/6] Issue #19, fixed another documentation issue --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0755857..a0e8ba9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,4 +28,4 @@ script-files = ["scripts/discos-keygen"] [project.optional-dependencies] test = ["coverage", "prospector", "jsonschema", "referencing"] -docs = ["sphinx", "sphinx-rtd-theme", "sphinx-autodoc-typehints", "sphinx-jsonschema"] +docs = ["sphinx<9,>=6", "sphinx-rtd-theme", "sphinx-autodoc-typehints", "sphinx-jsonschema"] From d4581c5f67d8276ce92e2b6e981b52862c319920 Mon Sep 17 00:00:00 2001 From: Giuseppe Carboni Date: Wed, 7 Jan 2026 15:19:42 +0100 Subject: [PATCH 4/6] Issue #19, fixed for windows --- .coveragerc | 1 + scripts/discos-keygen => discos_client/cli.py | 6 ++---- pyproject.toml | 8 +++++--- 3 files changed, 8 insertions(+), 7 deletions(-) rename scripts/discos-keygen => discos_client/cli.py (96%) diff --git a/.coveragerc b/.coveragerc index 473d62c..5a0bb03 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,6 +1,7 @@ [run] concurrency = thread source = discos_client +omit = discos_client/cli.py [paths] discos_client = diff --git a/scripts/discos-keygen b/discos_client/cli.py similarity index 96% rename from scripts/discos-keygen rename to discos_client/cli.py index db9d4ea..6dea55f 100644 --- a/scripts/discos-keygen +++ b/discos_client/cli.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python import os import sys from pathlib import Path @@ -47,7 +46,7 @@ def print_discos_keys(): "into consideration and you will hear back from the team." ) -if __name__ == "__main__": +def main(): parser = ArgumentParser( "DISCOS CURVE key pairs generator." ) @@ -65,5 +64,4 @@ def print_discos_keys(): if not args.show_only: create_discos_keys(args.overwrite) - - print_discos_keys() + print_discos_keys() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a0e8ba9..6123871 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,11 +21,13 @@ dependencies = [ [tool.setuptools] packages = ["discos_client"] -script-files = ["scripts/discos-keygen"] + +[project.scripts] +discos-keygen = "discos_client.cli:main" [tool.setuptools.package-data] -"discos_client" = ["schemas/**", "servers/**"] +discos_client = ["schemas/**", "servers/**"] [project.optional-dependencies] -test = ["coverage", "prospector", "jsonschema", "referencing"] +test = ["coverage", "prospector", "jsonschema", "referencing", "tornado"] docs = ["sphinx<9,>=6", "sphinx-rtd-theme", "sphinx-autodoc-typehints", "sphinx-jsonschema"] From 0701192ade7442406f3792827ca005816f559aab Mon Sep 17 00:00:00 2001 From: Giuseppe Carboni Date: Wed, 7 Jan 2026 16:02:04 +0100 Subject: [PATCH 5/6] Issue #19, attempt to fix windows tests --- discos_client/cli.py | 32 +++++++++++++++++++------------- pyproject.toml | 2 +- tests/test_client.py | 6 ++++++ 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/discos_client/cli.py b/discos_client/cli.py index 6dea55f..f2274a1 100644 --- a/discos_client/cli.py +++ b/discos_client/cli.py @@ -7,14 +7,16 @@ base_config = Path(user_config_dir("discos")) target_dir = base_config / "rpc" / "client" -key_filename = "identity" -full_path_public = target_dir / f"{key_filename}.key" -full_path_secret = target_dir / f"{key_filename}.key_secret" +KEY_FILENAME = "identity" +full_path_public = target_dir / f"{KEY_FILENAME}.key" +full_path_secret = target_dir / f"{KEY_FILENAME}.key_secret" + def create_discos_keys(overwrite): if full_path_secret.exists() and not overwrite: - print("Kept previously created key pair. Use --overwrite to replace it.\n") + print("Kept previously created key pair. " + "Use --overwrite to replace it.\n") return try: @@ -23,29 +25,32 @@ def create_discos_keys(overwrite): print(f"Error creating the configuration directory: {e}") sys.exit(1) - create_certificates(str(target_dir), key_filename) + create_certificates(str(target_dir), KEY_FILENAME) if os.name == 'posix': full_path_secret.chmod(0o600) - (target_dir / f"{key_filename}.key").chmod(0o644) + (target_dir / f"{KEY_FILENAME}.key").chmod(0o644) print(f"Key pair created in: '{target_dir}'.") + def print_discos_keys(): if not full_path_public.exists(): print("No key was generated yet.") return - with open(full_path_public, "r") as f: + with open(full_path_public, "r", encoding="utf-8") as f: print(f.read()) print(f"\nPath of the public key file: {full_path_public}") - print(f"Remember to never share the '{key_filename}.key_secret' file with anyone.") + print(f"Remember to never share the '{KEY_FILENAME}.key_secret' file with " + "anyone.") print( - "In order to be authorized to send command to any of the telescopes, " \ - f"remember to send a copy of the '{key_filename}.key' file to the " \ - "DISCOS team, asking for authorization. Your request will be taken " \ + "In order to be authorized to send command to any of the telescopes, " + f"remember to send a copy of the '{KEY_FILENAME}.key' file to the " + "DISCOS team, asking for authorization. Your request will be taken " "into consideration and you will hear back from the team." ) + def main(): parser = ArgumentParser( "DISCOS CURVE key pairs generator." @@ -58,10 +63,11 @@ def main(): parser.add_argument( "--show-only", action="store_true", - help="Only prints the public key and its path without generating a new pair." + help="Only prints the public key and its path without \ + generating a new pair." ) args = parser.parse_args() if not args.show_only: create_discos_keys(args.overwrite) - print_discos_keys() \ No newline at end of file + print_discos_keys() diff --git a/pyproject.toml b/pyproject.toml index 6123871..040e53a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,5 +29,5 @@ discos-keygen = "discos_client.cli:main" discos_client = ["schemas/**", "servers/**"] [project.optional-dependencies] -test = ["coverage", "prospector", "jsonschema", "referencing", "tornado"] +test = ["coverage", "prospector", "jsonschema", "referencing"] docs = ["sphinx<9,>=6", "sphinx-rtd-theme", "sphinx-autodoc-typehints", "sphinx-jsonschema"] diff --git a/tests/test_client.py b/tests/test_client.py index d49112e..71d4cca 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,6 +2,8 @@ import unittest import time import re +import asyncio +import sys from unittest.mock import patch from pathlib import Path from threading import Thread, Event @@ -11,6 +13,10 @@ from discos_client.client import DISCOSClient, \ DEFAULT_SUB_PORT, DEFAULT_REQ_PORT + +if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + keys_path = Path(__file__).resolve().parent / "test_keys" dummy_public, dummy_secret = load_certificate( keys_path / "dummy.key_secret" From 70b33216a3196454c9d4867ade06c6618465fea5 Mon Sep 17 00:00:00 2001 From: Giuseppe Carboni Date: Wed, 7 Jan 2026 23:47:47 +0100 Subject: [PATCH 6/6] Issue #19, added tests for discos-keygen --- .coveragerc | 1 - .gitignore | 6 +- discos_client/{cli.py => scripts.py} | 57 +++++++++------- pyproject.toml | 2 +- tests/test_scripts.py | 98 ++++++++++++++++++++++++++++ 5 files changed, 136 insertions(+), 28 deletions(-) rename discos_client/{cli.py => scripts.py} (52%) create mode 100644 tests/test_scripts.py diff --git a/.coveragerc b/.coveragerc index 5a0bb03..473d62c 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,7 +1,6 @@ [run] concurrency = thread source = discos_client -omit = discos_client/cli.py [paths] discos_client = diff --git a/.gitignore b/.gitignore index e14c9dc..69a2b47 100644 --- a/.gitignore +++ b/.gitignore @@ -182,9 +182,9 @@ cython_debug/ .abstra/ # Visual Studio Code -# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore -# and can be added to the global gitignore or merged into this file. However, if you prefer, +# and can be added to the global gitignore or merged into this file. However, if you prefer, # you could uncomment the following to ignore the entire vscode folder # .vscode/ @@ -206,4 +206,4 @@ marimo/_static/ marimo/_lsp/ __marimo__/ - +**.swp diff --git a/discos_client/cli.py b/discos_client/scripts.py similarity index 52% rename from discos_client/cli.py rename to discos_client/scripts.py index f2274a1..a7819cd 100644 --- a/discos_client/cli.py +++ b/discos_client/scripts.py @@ -1,57 +1,64 @@ import os -import sys from pathlib import Path from argparse import ArgumentParser from platformdirs import user_config_dir from zmq.auth import create_certificates -base_config = Path(user_config_dir("discos")) -target_dir = base_config / "rpc" / "client" -KEY_FILENAME = "identity" -full_path_public = target_dir / f"{KEY_FILENAME}.key" -full_path_secret = target_dir / f"{KEY_FILENAME}.key_secret" + +def get_config_paths(): + base_config = Path(user_config_dir("discos")) + config_dir = base_config / "rpc" / "client" + public = config_dir / "identity.key" + secret = config_dir / "identity.key_secret" + return config_dir, public, secret def create_discos_keys(overwrite): + config_dir, public, secret = get_config_paths() - if full_path_secret.exists() and not overwrite: + if secret.exists() and not overwrite: print("Kept previously created key pair. " "Use --overwrite to replace it.\n") - return + return 0 try: - target_dir.mkdir(parents=True, exist_ok=True) + config_dir.mkdir(parents=True, exist_ok=True) except OSError as e: print(f"Error creating the configuration directory: {e}") - sys.exit(1) + return 1 - create_certificates(str(target_dir), KEY_FILENAME) + create_certificates(str(config_dir), "identity") if os.name == 'posix': - full_path_secret.chmod(0o600) - (target_dir / f"{KEY_FILENAME}.key").chmod(0o644) - print(f"Key pair created in: '{target_dir}'.") + public.chmod(0o644) + secret.chmod(0o600) + print(f"Key pair created in: '{config_dir}'.") + return 0 def print_discos_keys(): - if not full_path_public.exists(): + _, public, _ = get_config_paths() + + if not public.exists(): print("No key was generated yet.") - return + return 0 - with open(full_path_public, "r", encoding="utf-8") as f: + with open(public, "r", encoding="utf-8") as f: print(f.read()) - print(f"\nPath of the public key file: {full_path_public}") - print(f"Remember to never share the '{KEY_FILENAME}.key_secret' file with " + + print(f"\nPath of the public key file: {public}") + print("Remember to never share the 'identity.key_secret' file with " "anyone.") print( "In order to be authorized to send command to any of the telescopes, " - f"remember to send a copy of the '{KEY_FILENAME}.key' file to the " + "remember to send a copy of the 'identity.key' file to the " "DISCOS team, asking for authorization. Your request will be taken " "into consideration and you will hear back from the team." ) + return 0 -def main(): +def keygen(): parser = ArgumentParser( "DISCOS CURVE key pairs generator." ) @@ -69,5 +76,9 @@ def main(): args = parser.parse_args() if not args.show_only: - create_discos_keys(args.overwrite) - print_discos_keys() + return_code = create_discos_keys(args.overwrite) + + if return_code != 0: + return return_code + + return print_discos_keys() diff --git a/pyproject.toml b/pyproject.toml index 040e53a..356a567 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ packages = ["discos_client"] [project.scripts] -discos-keygen = "discos_client.cli:main" +discos-keygen = "discos_client.scripts:keygen" [tool.setuptools.package-data] discos_client = ["schemas/**", "servers/**"] diff --git a/tests/test_scripts.py b/tests/test_scripts.py new file mode 100644 index 0000000..cf22c3d --- /dev/null +++ b/tests/test_scripts.py @@ -0,0 +1,98 @@ +import unittest +import sys +import shutil +import tempfile +from io import StringIO +from pathlib import Path +from unittest.mock import patch, MagicMock +from platformdirs import user_config_dir +from discos_client import scripts + + +class TestKeygen(unittest.TestCase): + + def setUp(self): + self.test_dir = tempfile.mkdtemp() + self.test_path = Path(self.test_dir) + self.mock_target_dir = self.test_path / "rpc" / "client" + self.mock_public = self.mock_target_dir / "identity.key" + self.mock_secret = self.mock_target_dir / "identity.key_secret" + + def tearDown(self): + shutil.rmtree(self.test_dir) + + def test_correct_paths(self): + config_dir, public, secret = scripts.get_config_paths() + expected_config_dir = \ + Path(user_config_dir("discos")) / "rpc" / "client" + expected_public = expected_config_dir / "identity.key" + expected_secret = expected_config_dir / "identity.key_secret" + self.assertEqual(config_dir, expected_config_dir) + self.assertEqual(public, expected_public) + self.assertEqual(secret, expected_secret) + + @patch("discos_client.scripts.get_config_paths") + @patch("sys.stdout", new_callable=StringIO) + @patch.object(sys, "argv", ["discos-keygen"]) + def test_keygen(self, mock_stdout, mock_paths): + mock_paths.return_value = ( + self.mock_target_dir, + self.mock_public, + self.mock_secret + ) + rc = scripts.keygen() + self.assertEqual(rc, 0) + self.assertTrue(self.mock_public.exists()) + self.assertTrue(self.mock_secret.exists()) + output = mock_stdout.getvalue() + self.assertIn("Key pair created in", output) + + @patch("discos_client.scripts.get_config_paths") + @patch("sys.stdout", new_callable=StringIO) + @patch.object(sys, "argv", ["discos-keygen"]) + def test_keygen_no_overwrite(self, mock_stdout, mock_paths): + mock_paths.return_value = ( + self.mock_target_dir, + self.mock_public, + self.mock_secret + ) + self.assertEqual(scripts.keygen(), 0) + self.assertTrue(self.mock_public.exists()) + self.assertTrue(self.mock_secret.exists()) + output = mock_stdout.getvalue() + self.assertIn("Key pair created in", output) + self.assertEqual(scripts.keygen(), 0) + output = mock_stdout.getvalue() + self.assertIn("Kept previously created key pair", output) + + @patch("discos_client.scripts.get_config_paths") + @patch("sys.stdout", new_callable=StringIO) + def test_print_keys(self, mock_stdout, mock_paths): + mock_paths.return_value = ( + self.mock_target_dir, + self.mock_public, + self.mock_secret + ) + scripts.print_discos_keys() + output = mock_stdout.getvalue() + self.assertIn("No key was generated yet.", output) + + @patch("discos_client.scripts.get_config_paths") + @patch("sys.stdout", new_callable=StringIO) + @patch.object(sys, "argv", ["discos-keygen"]) + def test_mkdir_error(self, mock_stdout, mock_paths): + mock_target_dir = MagicMock() + mock_target_dir.mkdir.side_effect = OSError("Test error") + mock_paths.return_value = ( + mock_target_dir, + self.mock_public, + self.mock_secret + ) + rc = scripts.keygen() + self.assertEqual(rc, 1) + output = mock_stdout.getvalue() + self.assertIn("Error creating the configuration directory", output) + + +if __name__ == "__main__": + unittest.main()