From 7d2c7181fcf8fa8796533136563ff5c4f872041c Mon Sep 17 00:00:00 2001 From: Yannick Marcon Date: Sat, 21 Mar 2026 18:54:27 +0100 Subject: [PATCH 1/2] feat: added catalog functions to list and search variables (#5) * feat: added catalog functions to list and search variables * chore: prepare next version * feat: doc updated * fix: search returned type * chore: code review --- Makefile | 4 ++ datashield/api.py | 132 +++++++++++++++++++++++++++------------- datashield/interface.py | 38 ++++++++++++ pyproject.toml | 2 +- uv.lock | 40 ++++++------ 5 files changed, 153 insertions(+), 63 deletions(-) diff --git a/Makefile b/Makefile index d48a08e..71cc7f4 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,10 @@ install: uv sync --all-extras +update: + rm -f uv.lock + uv sync + test: uv run --all-extras pytest diff --git a/datashield/api.py b/datashield/api.py index b2cf4b6..59ec3a0 100644 --- a/datashield/api.py +++ b/datashield/api.py @@ -140,7 +140,7 @@ def close(self, save: str = None) -> None: for conn in self.conns: try: if save: - conn.save_workspace(f"{conn.name}:{save}") + conn.save_workspace(f"{conn.get_name()}:{save}") conn.disconnect() except DSError: # silently fail @@ -162,7 +162,7 @@ def get_connection_names(self) -> list[str]: :return: The list of opened connection names """ if self.conns: - return [conn.name for conn in self.conns] + return [conn.get_name() for conn in self.conns] else: return [] @@ -194,7 +194,53 @@ def tables(self) -> dict: """ rval = {} for conn in self.conns: - rval[conn.name] = conn.list_tables() + rval[conn.get_name()] = conn.list_tables() + return rval + + def variables(self, table: str = None, tables: dict = None) -> dict: + """ + List available variables from the data repository, for a given table. + + :param table: The default name of the table to list variables for + :param tables: The name of the table to list variables for, per server name. If not defined, 'table' is used. + :return: The available variables from the data repository, for a given table, per remote server name + """ + rval = {} + for conn in self.conns: + name = table + if tables and conn.get_name() in tables: + name = tables[conn.get_name()] + if name: + rval[conn.get_name()] = conn.list_table_variables(name) + else: + rval[conn.get_name()] = None + return rval + + def taxonomies(self) -> dict: + """ + List available taxonomies from the data repository. A taxonomy is a hierarchical structure of vocabulary + terms that can be used to annotate variables in the data repository. + Depending on the data repository's capabilities, taxonomies can be used to perform structured + queries when searching for variables. + + :return: The available taxonomies from the data repository, per remote server name + """ + rval = {} + for conn in self.conns: + rval[conn.get_name()] = conn.list_taxonomies() + return rval + + def search_variables(self, query: str) -> dict: + """ + Search for variable names matching a given query across all tables in the data repository. + + :param query: The query to search for in variable names, e.g., a full-text search and/or structured + query (based on taxonomy terms), depending on the data repository's capabilities + :return: The matching variable names from the data repository, per remote server name + """ + rval = {} + for conn in self.conns: + rval[conn.get_name()] = conn.search_variables(query) return rval def resources(self) -> dict: @@ -205,7 +251,7 @@ def resources(self) -> dict: """ rval = {} for conn in self.conns: - rval[conn.name] = conn.list_resources() + rval[conn.get_name()] = conn.list_resources() return rval def profiles(self) -> dict: @@ -216,7 +262,7 @@ def profiles(self) -> dict: """ rval = {} for conn in self.conns: - rval[conn.name] = conn.list_profiles() + rval[conn.get_name()] = conn.list_profiles() return rval def packages(self) -> dict: @@ -227,7 +273,7 @@ def packages(self) -> dict: """ rval = {} for conn in self.conns: - rval[conn.name] = conn.list_packages() + rval[conn.get_name()] = conn.list_packages() return rval def methods(self, type: str = "aggregate") -> dict: @@ -239,7 +285,7 @@ def methods(self, type: str = "aggregate") -> dict: """ rval = {} for conn in self.conns: - rval[conn.name] = conn.list_methods(type) + rval[conn.get_name()] = conn.list_methods(type) return rval # @@ -254,7 +300,7 @@ def workspaces(self) -> dict: """ rval = {} for conn in self.conns: - rval[conn.name] = conn.list_workspaces() + rval[conn.get_name()] = conn.list_workspaces() return rval def workspace_save(self, name: str) -> None: @@ -264,7 +310,7 @@ def workspace_save(self, name: str) -> None: :param name: The name of the workspace """ for conn in self.conns: - conn.save_workspace(f"{conn.name}:{name}") + conn.save_workspace(f"{conn.get_name()}:{name}") def workspace_restore(self, name: str) -> None: """ @@ -274,7 +320,7 @@ def workspace_restore(self, name: str) -> None: :param name: The name of the workspace """ for conn in self.conns: - conn.restore_workspace(f"{conn.name}:{name}") + conn.restore_workspace(f"{conn.get_name()}:{name}") def workspace_rm(self, name: str) -> None: """ @@ -284,7 +330,7 @@ def workspace_rm(self, name: str) -> None: :param name: The name of the workspace """ for conn in self.conns: - conn.rm_workspace(f"{conn.name}:{name}") + conn.rm_workspace(f"{conn.get_name()}:{name}") # # R session @@ -321,17 +367,17 @@ def sessions(self) -> dict: if not conn.has_session(): conn.start_session(asynchronous=True) except Exception as e: - logging.warning(f"Failed to start session: {conn.name} - {e}") - excluded_conns.append(conn.name) + logging.warning(f"Failed to start session: {conn.get_name()} - {e}") + excluded_conns.append(conn.get_name()) # check for session status and wait until all are started - for conn in [c for c in self.conns if c.name not in excluded_conns]: + for conn in [c for c in self.conns if c.get_name() not in excluded_conns]: try: if conn.is_session_started(): - started_conns.append(conn.name) + started_conns.append(conn.get_name()) except Exception as e: - logging.warning(f"Failed to check session status: {conn.name} - {e}") - excluded_conns.append(conn.name) + logging.warning(f"Failed to check session status: {conn.get_name()} - {e}") + excluded_conns.append(conn.get_name()) # wait until all sessions are started, excluding those that have failed to start or check status start_time = time.time() @@ -340,23 +386,25 @@ def sessions(self) -> dict: raise DSError("Timed out waiting for R sessions to start") time.sleep(self.start_delay) remaining_conns = [ - conn for conn in self.conns if conn.name not in started_conns and conn.name not in excluded_conns + conn + for conn in self.conns + if conn.get_name() not in started_conns and conn.get_name() not in excluded_conns ] for conn in remaining_conns: try: if conn.is_session_started(): - started_conns.append(conn.name) + started_conns.append(conn.get_name()) except Exception as e: - logging.warning(f"Failed to check session status: {conn.name} - {e}") - excluded_conns.append(conn.name) + logging.warning(f"Failed to check session status: {conn.get_name()} - {e}") + excluded_conns.append(conn.get_name()) # at this point, all sessions that could be started have been started, and those that failed to start or check status have been excluded for conn in self.conns: - if conn.name in started_conns: - rval[conn.name] = conn.get_session() + if conn.get_name() in started_conns: + rval[conn.get_name()] = conn.get_session() if len(excluded_conns) > 0: logging.error(f"Some sessions have been excluded due to errors: {', '.join(excluded_conns)}") - self.conns = [conn for conn in self.conns if conn.name not in excluded_conns] + self.conns = [conn for conn in self.conns if conn.get_name() not in excluded_conns] if len(self.conns) == 0: raise DSError("No sessions could be started successfully.") return rval @@ -372,10 +420,10 @@ def ls(self) -> dict: rval = {} for conn in self.conns: try: - rval[conn.name] = conn.list_symbols() + rval[conn.get_name()] = conn.list_symbols() except Exception as e: self._append_error(conn, e) - rval[conn.name] = None + rval[conn.get_name()] = None self._check_errors() return rval @@ -418,12 +466,12 @@ def assign_table( cmd = {} for conn in self.conns: name = table - if tables and conn.name in tables: - name = tables[conn.name] + if tables and conn.get_name() in tables: + name = tables[conn.get_name()] if name: try: res = conn.assign_table(symbol, name, variables, missings, identifiers, id_name, asynchronous) - cmd[conn.name] = res + cmd[conn.get_name()] = res except Exception as e: self._append_error(conn, e) self._do_wait(cmd) @@ -445,12 +493,12 @@ def assign_resource( cmd = {} for conn in self.conns: name = resource - if resources and conn.name in resources: - name = resources[conn.name] + if resources and conn.get_name() in resources: + name = resources[conn.get_name()] if name: try: res = conn.assign_resource(symbol, name, asynchronous) - cmd[conn.name] = res + cmd[conn.get_name()] = res except Exception as e: self._append_error(conn, e) self._do_wait(cmd) @@ -470,7 +518,7 @@ def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> None for conn in self.conns: try: res = conn.assign_expr(symbol, expr, asynchronous) - cmd[conn.name] = res + cmd[conn.get_name()] = res except Exception as e: self._append_error(conn, e) self._do_wait(cmd) @@ -492,10 +540,10 @@ def aggregate(self, expr: str, asynchronous: bool = True) -> dict: for conn in self.conns: try: res = conn.aggregate(expr, asynchronous) - cmd[conn.name] = res + cmd[conn.get_name()] = res except Exception as e: self._append_error(conn, e) - rval[conn.name] = None + rval[conn.get_name()] = None rval = self._do_wait(cmd) self._check_errors() return rval @@ -511,15 +559,15 @@ def _do_wait(self, cmd: dict) -> dict: rval = {} while cmd: for conn in self.conns: - if conn.name in cmd: - res = cmd[conn.name] - # print(f"..checking {conn.name} -> {res.is_completed()}") + if conn.get_name() in cmd: + res = cmd[conn.get_name()] + # print(f"..checking {conn.get_name()} -> {res.is_completed()}") if res.is_completed(): try: - rval[conn.name] = res.fetch() + rval[conn.get_name()] = res.fetch() except Exception as e: self._append_error(conn, e) - cmd.pop(conn.name, None) + cmd.pop(conn.get_name(), None) else: conn.keep_alive() time.sleep(0.1) @@ -535,8 +583,8 @@ def _append_error(self, conn: DSConnection, error: Exception) -> None: """ Append an error. """ - logging.error(f"[{conn.name}] {error}") - self.errors[conn.name] = error + logging.error(f"[{conn.get_name()}] {error}") + self.errors[conn.get_name()] = error def _check_errors(self) -> None: """ diff --git a/datashield/interface.py b/datashield/interface.py index 084d741..354875f 100644 --- a/datashield/interface.py +++ b/datashield/interface.py @@ -194,6 +194,14 @@ class DSConnection: Connection class to a DataSHIELD server. """ + def get_name(self) -> str: + """ + Get the name of the connection, which is typically the name of the server or data repository. + + :return: The name of the connection + """ + raise NotImplementedError("DSConnection function not available") + # # Content listing # @@ -215,6 +223,36 @@ def has_table(self, name: str) -> bool: """ raise NotImplementedError("DSConnection function not available") + def list_table_variables(self, table: str) -> list: + """ + List available variables for a given table from the data repository. + + :param table: The name of the table to list variables for + :return: The list of available variables for the given table + """ + raise NotImplementedError("DSConnection function not available") + + def list_taxonomies(self) -> list: + """ + List available taxonomies from the data repository. A taxonomy is a hierarchical structure of vocabulary + terms that can be used to annotate variables in the data repository. + Depending on the data repository's capabilities, taxonomies can be used to perform structured + queries when searching for variables. + + :return: The list of available taxonomy names + """ + raise NotImplementedError("DSConnection function not available") + + def search_variables(self, query: str) -> dict: + """ + Search for variable names matching a given query across all tables in the data repository. + + :param query: The query to search for in variable names, e.g., a full-text search and/or structured + query (based on taxonomy terms), depending on the data repository's capabilities + :return: The search result for variables matching the given query across all tables + """ + raise NotImplementedError("DSConnection function not available") + def list_resources(self) -> list: """ List available resource names from the data repository. diff --git a/pyproject.toml b/pyproject.toml index f893173..529b7e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "datashield" -version = "0.3.0" +version = "0.4.0" description = "DataSHIELD Client Interface in Python." authors = [ {name = "Yannick Marcon", email = "yannick.marcon@obiba.org"} diff --git a/uv.lock b/uv.lock index 4f7a97e..b68e040 100644 --- a/uv.lock +++ b/uv.lock @@ -22,7 +22,7 @@ wheels = [ [[package]] name = "datashield" -version = "0.3.0" +version = "0.4.0" source = { editable = "." } dependencies = [ { name = "pydantic" }, @@ -311,27 +311,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.15.1" +version = "0.15.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/04/dc/4e6ac71b511b141cf626357a3946679abeba4cf67bc7cc5a17920f31e10d/ruff-0.15.1.tar.gz", hash = "sha256:c590fe13fb57c97141ae975c03a1aedb3d3156030cabd740d6ff0b0d601e203f", size = 4540855, upload-time = "2026-02-12T23:09:09.998Z" } +sdist = { url = "https://files.pythonhosted.org/packages/06/04/eab13a954e763b0606f460443fcbf6bb5a0faf06890ea3754ff16523dce5/ruff-0.15.2.tar.gz", hash = "sha256:14b965afee0969e68bb871eba625343b8673375f457af4abe98553e8bbb98342", size = 4558148, upload-time = "2026-02-19T22:32:20.271Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/23/bf/e6e4324238c17f9d9120a9d60aa99a7daaa21204c07fcd84e2ef03bb5fd1/ruff-0.15.1-py3-none-linux_armv6l.whl", hash = "sha256:b101ed7cf4615bda6ffe65bdb59f964e9f4a0d3f85cbf0e54f0ab76d7b90228a", size = 10367819, upload-time = "2026-02-12T23:09:03.598Z" }, - { url = "https://files.pythonhosted.org/packages/b3/ea/c8f89d32e7912269d38c58f3649e453ac32c528f93bb7f4219258be2e7ed/ruff-0.15.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:939c995e9277e63ea632cc8d3fae17aa758526f49a9a850d2e7e758bfef46602", size = 10798618, upload-time = "2026-02-12T23:09:22.928Z" }, - { url = "https://files.pythonhosted.org/packages/5e/0f/1d0d88bc862624247d82c20c10d4c0f6bb2f346559d8af281674cf327f15/ruff-0.15.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1d83466455fdefe60b8d9c8df81d3c1bbb2115cede53549d3b522ce2bc703899", size = 10148518, upload-time = "2026-02-12T23:08:58.339Z" }, - { url = "https://files.pythonhosted.org/packages/f5/c8/291c49cefaa4a9248e986256df2ade7add79388fe179e0691be06fae6f37/ruff-0.15.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9457e3c3291024866222b96108ab2d8265b477e5b1534c7ddb1810904858d16", size = 10518811, upload-time = "2026-02-12T23:09:31.865Z" }, - { url = "https://files.pythonhosted.org/packages/c3/1a/f5707440e5ae43ffa5365cac8bbb91e9665f4a883f560893829cf16a606b/ruff-0.15.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:92c92b003e9d4f7fbd33b1867bb15a1b785b1735069108dfc23821ba045b29bc", size = 10196169, upload-time = "2026-02-12T23:09:17.306Z" }, - { url = "https://files.pythonhosted.org/packages/2a/ff/26ddc8c4da04c8fd3ee65a89c9fb99eaa5c30394269d424461467be2271f/ruff-0.15.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fe5c41ab43e3a06778844c586251eb5a510f67125427625f9eb2b9526535779", size = 10990491, upload-time = "2026-02-12T23:09:25.503Z" }, - { url = "https://files.pythonhosted.org/packages/fc/00/50920cb385b89413f7cdb4bb9bc8fc59c1b0f30028d8bccc294189a54955/ruff-0.15.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66a6dd6df4d80dc382c6484f8ce1bcceb55c32e9f27a8b94c32f6c7331bf14fb", size = 11843280, upload-time = "2026-02-12T23:09:19.88Z" }, - { url = "https://files.pythonhosted.org/packages/5d/6d/2f5cad8380caf5632a15460c323ae326f1e1a2b5b90a6ee7519017a017ca/ruff-0.15.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6a4a42cbb8af0bda9bcd7606b064d7c0bc311a88d141d02f78920be6acb5aa83", size = 11274336, upload-time = "2026-02-12T23:09:14.907Z" }, - { url = "https://files.pythonhosted.org/packages/a3/1d/5f56cae1d6c40b8a318513599b35ea4b075d7dc1cd1d04449578c29d1d75/ruff-0.15.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ab064052c31dddada35079901592dfba2e05f5b1e43af3954aafcbc1096a5b2", size = 11137288, upload-time = "2026-02-12T23:09:07.475Z" }, - { url = "https://files.pythonhosted.org/packages/cd/20/6f8d7d8f768c93b0382b33b9306b3b999918816da46537d5a61635514635/ruff-0.15.1-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:5631c940fe9fe91f817a4c2ea4e81f47bee3ca4aa646134a24374f3c19ad9454", size = 11070681, upload-time = "2026-02-12T23:08:55.43Z" }, - { url = "https://files.pythonhosted.org/packages/9a/67/d640ac76069f64cdea59dba02af2e00b1fa30e2103c7f8d049c0cff4cafd/ruff-0.15.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:68138a4ba184b4691ccdc39f7795c66b3c68160c586519e7e8444cf5a53e1b4c", size = 10486401, upload-time = "2026-02-12T23:09:27.927Z" }, - { url = "https://files.pythonhosted.org/packages/65/3d/e1429f64a3ff89297497916b88c32a5cc88eeca7e9c787072d0e7f1d3e1e/ruff-0.15.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:518f9af03bfc33c03bdb4cb63fabc935341bb7f54af500f92ac309ecfbba6330", size = 10197452, upload-time = "2026-02-12T23:09:12.147Z" }, - { url = "https://files.pythonhosted.org/packages/78/83/e2c3bade17dad63bf1e1c2ffaf11490603b760be149e1419b07049b36ef2/ruff-0.15.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:da79f4d6a826caaea95de0237a67e33b81e6ec2e25fc7e1993a4015dffca7c61", size = 10693900, upload-time = "2026-02-12T23:09:34.418Z" }, - { url = "https://files.pythonhosted.org/packages/a1/27/fdc0e11a813e6338e0706e8b39bb7a1d61ea5b36873b351acee7e524a72a/ruff-0.15.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3dd86dccb83cd7d4dcfac303ffc277e6048600dfc22e38158afa208e8bf94a1f", size = 11227302, upload-time = "2026-02-12T23:09:36.536Z" }, - { url = "https://files.pythonhosted.org/packages/f6/58/ac864a75067dcbd3b95be5ab4eb2b601d7fbc3d3d736a27e391a4f92a5c1/ruff-0.15.1-py3-none-win32.whl", hash = "sha256:660975d9cb49b5d5278b12b03bb9951d554543a90b74ed5d366b20e2c57c2098", size = 10462555, upload-time = "2026-02-12T23:09:29.899Z" }, - { url = "https://files.pythonhosted.org/packages/e0/5e/d4ccc8a27ecdb78116feac4935dfc39d1304536f4296168f91ed3ec00cd2/ruff-0.15.1-py3-none-win_amd64.whl", hash = "sha256:c820fef9dd5d4172a6570e5721704a96c6679b80cf7be41659ed439653f62336", size = 11599956, upload-time = "2026-02-12T23:09:01.157Z" }, - { url = "https://files.pythonhosted.org/packages/2a/07/5bda6a85b220c64c65686bc85bd0bbb23b29c62b3a9f9433fa55f17cda93/ruff-0.15.1-py3-none-win_arm64.whl", hash = "sha256:5ff7d5f0f88567850f45081fac8f4ec212be8d0b963e385c3f7d0d2eb4899416", size = 10874604, upload-time = "2026-02-12T23:09:05.515Z" }, + { url = "https://files.pythonhosted.org/packages/2f/70/3a4dc6d09b13cb3e695f28307e5d889b2e1a66b7af9c5e257e796695b0e6/ruff-0.15.2-py3-none-linux_armv6l.whl", hash = "sha256:120691a6fdae2f16d65435648160f5b81a9625288f75544dc40637436b5d3c0d", size = 10430565, upload-time = "2026-02-19T22:32:41.824Z" }, + { url = "https://files.pythonhosted.org/packages/71/0b/bb8457b56185ece1305c666dc895832946d24055be90692381c31d57466d/ruff-0.15.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:a89056d831256099658b6bba4037ac6dd06f49d194199215befe2bb10457ea5e", size = 10820354, upload-time = "2026-02-19T22:32:07.366Z" }, + { url = "https://files.pythonhosted.org/packages/2d/c1/e0532d7f9c9e0b14c46f61b14afd563298b8b83f337b6789ddd987e46121/ruff-0.15.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e36dee3a64be0ebd23c86ffa3aa3fd3ac9a712ff295e192243f814a830b6bd87", size = 10170767, upload-time = "2026-02-19T22:32:13.188Z" }, + { url = "https://files.pythonhosted.org/packages/47/e8/da1aa341d3af017a21c7a62fb5ec31d4e7ad0a93ab80e3a508316efbcb23/ruff-0.15.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9fb47b6d9764677f8c0a193c0943ce9a05d6763523f132325af8a858eadc2b9", size = 10529591, upload-time = "2026-02-19T22:32:02.547Z" }, + { url = "https://files.pythonhosted.org/packages/93/74/184fbf38e9f3510231fbc5e437e808f0b48c42d1df9434b208821efcd8d6/ruff-0.15.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f376990f9d0d6442ea9014b19621d8f2aaf2b8e39fdbfc79220b7f0c596c9b80", size = 10260771, upload-time = "2026-02-19T22:32:36.938Z" }, + { url = "https://files.pythonhosted.org/packages/05/ac/605c20b8e059a0bc4b42360414baa4892ff278cec1c91fff4be0dceedefd/ruff-0.15.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2dcc987551952d73cbf5c88d9fdee815618d497e4df86cd4c4824cc59d5dd75f", size = 11045791, upload-time = "2026-02-19T22:32:31.642Z" }, + { url = "https://files.pythonhosted.org/packages/fd/52/db6e419908f45a894924d410ac77d64bdd98ff86901d833364251bd08e22/ruff-0.15.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:42a47fd785cbe8c01b9ff45031af875d101b040ad8f4de7bbb716487c74c9a77", size = 11879271, upload-time = "2026-02-19T22:32:29.305Z" }, + { url = "https://files.pythonhosted.org/packages/3e/d8/7992b18f2008bdc9231d0f10b16df7dda964dbf639e2b8b4c1b4e91b83af/ruff-0.15.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cbe9f49354866e575b4c6943856989f966421870e85cd2ac94dccb0a9dcb2fea", size = 11303707, upload-time = "2026-02-19T22:32:22.492Z" }, + { url = "https://files.pythonhosted.org/packages/d7/02/849b46184bcfdd4b64cde61752cc9a146c54759ed036edd11857e9b8443b/ruff-0.15.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7a672c82b5f9887576087d97be5ce439f04bbaf548ee987b92d3a7dede41d3a", size = 11149151, upload-time = "2026-02-19T22:32:44.234Z" }, + { url = "https://files.pythonhosted.org/packages/70/04/f5284e388bab60d1d3b99614a5a9aeb03e0f333847e2429bebd2aaa1feec/ruff-0.15.2-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:72ecc64f46f7019e2bcc3cdc05d4a7da958b629a5ab7033195e11a438403d956", size = 11091132, upload-time = "2026-02-19T22:32:24.691Z" }, + { url = "https://files.pythonhosted.org/packages/fa/ae/88d844a21110e14d92cf73d57363fab59b727ebeabe78009b9ccb23500af/ruff-0.15.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:8dcf243b15b561c655c1ef2f2b0050e5d50db37fe90115507f6ff37d865dc8b4", size = 10504717, upload-time = "2026-02-19T22:32:26.75Z" }, + { url = "https://files.pythonhosted.org/packages/64/27/867076a6ada7f2b9c8292884ab44d08fd2ba71bd2b5364d4136f3cd537e1/ruff-0.15.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:dab6941c862c05739774677c6273166d2510d254dac0695c0e3f5efa1b5585de", size = 10263122, upload-time = "2026-02-19T22:32:10.036Z" }, + { url = "https://files.pythonhosted.org/packages/e7/ef/faf9321d550f8ebf0c6373696e70d1758e20ccdc3951ad7af00c0956be7c/ruff-0.15.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1b9164f57fc36058e9a6806eb92af185b0697c9fe4c7c52caa431c6554521e5c", size = 10735295, upload-time = "2026-02-19T22:32:39.227Z" }, + { url = "https://files.pythonhosted.org/packages/2f/55/e8089fec62e050ba84d71b70e7834b97709ca9b7aba10c1a0b196e493f97/ruff-0.15.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:80d24fcae24d42659db7e335b9e1531697a7102c19185b8dc4a028b952865fd8", size = 11241641, upload-time = "2026-02-19T22:32:34.617Z" }, + { url = "https://files.pythonhosted.org/packages/23/01/1c30526460f4d23222d0fabd5888868262fd0e2b71a00570ca26483cd993/ruff-0.15.2-py3-none-win32.whl", hash = "sha256:fd5ff9e5f519a7e1bd99cbe8daa324010a74f5e2ebc97c6242c08f26f3714f6f", size = 10507885, upload-time = "2026-02-19T22:32:15.635Z" }, + { url = "https://files.pythonhosted.org/packages/5c/10/3d18e3bbdf8fc50bbb4ac3cc45970aa5a9753c5cb51bf9ed9a3cd8b79fa3/ruff-0.15.2-py3-none-win_amd64.whl", hash = "sha256:d20014e3dfa400f3ff84830dfb5755ece2de45ab62ecea4af6b7262d0fb4f7c5", size = 11623725, upload-time = "2026-02-19T22:32:04.947Z" }, + { url = "https://files.pythonhosted.org/packages/6d/78/097c0798b1dab9f8affe73da9642bb4500e098cb27fd8dc9724816ac747b/ruff-0.15.2-py3-none-win_arm64.whl", hash = "sha256:cabddc5822acdc8f7b5527b36ceac55cc51eec7b1946e60181de8fe83ca8876e", size = 10941649, upload-time = "2026-02-19T22:32:18.108Z" }, ] [[package]] From 48b08be301b9a7e4dd57ea5982e2f2d64567daa7 Mon Sep 17 00:00:00 2001 From: Yannick Marcon Date: Sat, 21 Mar 2026 19:48:34 +0100 Subject: [PATCH 2/2] feat: added filter by connection names to operations (#6) * feat: added filter by connection names to operations * chore: code review --- datashield/api.py | 127 +++++++++++++++++++++---------- tests/test_session_filters.py | 136 ++++++++++++++++++++++++++++++++++ 2 files changed, 223 insertions(+), 40 deletions(-) create mode 100644 tests/test_session_filters.py diff --git a/datashield/api.py b/datashield/api.py index 59ec3a0..08251ab 100644 --- a/datashield/api.py +++ b/datashield/api.py @@ -129,15 +129,19 @@ def open(self, restore: str = None, failSafe: bool = False) -> None: for name in self.errors: logging.error(f"Connection to {name} has failed") - def close(self, save: str = None) -> None: + def close(self, save: str = None, conn_names: list[str] = None) -> None: """ Close connections with remote servers. - :param cons: The list of connections to close. :param save: The name of the workspace to save before closing the connections. + :param conn_names: The optional list of connection names to close. If not defined, all opened connections are closed. """ self.errors = {} - for conn in self.conns: + if not self.conns: + return + selected_conns = self._get_selected_connections(conn_names) + selected_names = {conn.get_name() for conn in selected_conns} + for conn in selected_conns: try: if save: conn.save_workspace(f"{conn.get_name()}:{save}") @@ -145,7 +149,10 @@ def close(self, save: str = None) -> None: except DSError: # silently fail pass - self.conns = None + if conn_names is None: + self.conns = None + else: + self.conns = [conn for conn in self.conns if conn.get_name() not in selected_names] def has_connections(self) -> bool: """ @@ -153,7 +160,7 @@ def has_connections(self) -> bool: :return: True if some connections were opened, False otherwise """ - return len(self.conns) > 0 + return self.conns and len(self.conns) > 0 def get_connection_names(self) -> list[str]: """ @@ -186,27 +193,29 @@ def get_errors(self) -> dict: # Environment # - def tables(self) -> dict: + def tables(self, conn_names: list[str] = None) -> dict: """ List available table names from the data repository. + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The available table names from the data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.list_tables() return rval - def variables(self, table: str = None, tables: dict = None) -> dict: + def variables(self, table: str = None, tables: dict = None, conn_names: list[str] = None) -> dict: """ List available variables from the data repository, for a given table. :param table: The default name of the table to list variables for :param tables: The name of the table to list variables for, per server name. If not defined, 'table' is used. + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The available variables from the data repository, for a given table, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): name = table if tables and conn.get_name() in tables: name = tables[conn.get_name()] @@ -216,75 +225,81 @@ def variables(self, table: str = None, tables: dict = None) -> dict: rval[conn.get_name()] = None return rval - def taxonomies(self) -> dict: + def taxonomies(self, conn_names: list[str] = None) -> dict: """ List available taxonomies from the data repository. A taxonomy is a hierarchical structure of vocabulary terms that can be used to annotate variables in the data repository. Depending on the data repository's capabilities, taxonomies can be used to perform structured queries when searching for variables. + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The available taxonomies from the data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.list_taxonomies() return rval - def search_variables(self, query: str) -> dict: + def search_variables(self, query: str, conn_names: list[str] = None) -> dict: """ Search for variable names matching a given query across all tables in the data repository. :param query: The query to search for in variable names, e.g., a full-text search and/or structured query (based on taxonomy terms), depending on the data repository's capabilities + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The matching variable names from the data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.search_variables(query) return rval - def resources(self) -> dict: + def resources(self, conn_names: list[str] = None) -> dict: """ List available resource names from the data repository. + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The available resource names from the data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.list_resources() return rval - def profiles(self) -> dict: + def profiles(self, conn_names: list[str] = None) -> dict: """ List available DataSHIELD profile names in the data repository. + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The available DataSHIELD profile names in the data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.list_profiles() return rval - def packages(self) -> dict: + def packages(self, conn_names: list[str] = None) -> dict: """ Get the list of DataSHIELD packages with their version, that have been configured on the remote data repository. + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The list of DataSHIELD packages with their version, that have been configured on the remote data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.list_packages() return rval - def methods(self, type: str = "aggregate") -> dict: + def methods(self, type: str = "aggregate", conn_names: list[str] = None) -> dict: """ Get the list of DataSHIELD methods that have been configured on the remote data repository. :param type: The type of method, either "aggregate" (default) or "assign" + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The list of DataSHIELD methods that have been configured on the remote data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.list_methods(type) return rval @@ -292,44 +307,48 @@ def methods(self, type: str = "aggregate") -> dict: # Workspaces # - def workspaces(self) -> dict: + def workspaces(self, conn_names: list[str] = None) -> dict: """ Get the list of DataSHIELD workspaces, that have been saved on the remote data repository. + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The list of DataSHIELD workspaces, that have been saved on the remote data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.list_workspaces() return rval - def workspace_save(self, name: str) -> None: + def workspace_save(self, name: str, conn_names: list[str] = None) -> None: """ Save the DataSHIELD R session in a workspace on the remote data repository. :param name: The name of the workspace + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. """ - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): conn.save_workspace(f"{conn.get_name()}:{name}") - def workspace_restore(self, name: str) -> None: + def workspace_restore(self, name: str, conn_names: list[str] = None) -> None: """ Restore a saved DataSHIELD R session from the remote data repository. When restoring a workspace, any existing symbol or file with same name will be overridden. :param name: The name of the workspace + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. """ - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): conn.restore_workspace(f"{conn.get_name()}:{name}") - def workspace_rm(self, name: str) -> None: + def workspace_rm(self, name: str, conn_names: list[str] = None) -> None: """ Remove a DataSHIELD workspace from the remote data repository. Ignored if no such workspace exists. :param name: The name of the workspace + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. """ - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): conn.rm_workspace(f"{conn.get_name()}:{name}") # @@ -358,6 +377,9 @@ def sessions(self) -> dict: """ rval = {} self._init_errors() + if not self.conns or len(self.conns) == 0: + return rval + started_conns = [] excluded_conns = [] @@ -409,7 +431,7 @@ def sessions(self) -> dict: raise DSError("No sessions could be started successfully.") return rval - def ls(self) -> dict: + def ls(self, conn_names: list[str] = None) -> dict: """ After assignments have been performed, list the symbols that live in the DataSHIELD R session on the server side. @@ -418,7 +440,7 @@ def ls(self) -> dict: # ensure sessions are started and available self.sessions() rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): try: rval[conn.get_name()] = conn.list_symbols() except Exception as e: @@ -427,15 +449,16 @@ def ls(self) -> dict: self._check_errors() return rval - def rm(self, symbol: str) -> None: + def rm(self, symbol: str, conn_names: list[str] = None) -> None: """ Remove a symbol from remote servers. :param symbol: The name of the symbol to remove + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. """ # ensure sessions are started and available self.sessions() - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): try: conn.rm_symbol(symbol) except Exception as e: @@ -452,6 +475,7 @@ def assign_table( identifiers: str = None, id_name: str = None, asynchronous: bool = True, + conn_names: list[str] = None, ) -> None: """ Assign a data table from the data repository to a symbol in the DataSHIELD R session. @@ -460,11 +484,12 @@ def assign_table( :param table: The default name of the table to assign :param tables: The name of the table to assign, per server name. If not defined, 'table' is used. :param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server) + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. """ # ensure sessions are started and available self.sessions() cmd = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): name = table if tables and conn.get_name() in tables: name = tables[conn.get_name()] @@ -478,7 +503,12 @@ def assign_table( self._check_errors() def assign_resource( - self, symbol: str, resource: str = None, resources: dict = None, asynchronous: bool = True + self, + symbol: str, + resource: str = None, + resources: dict = None, + asynchronous: bool = True, + conn_names: list[str] = None, ) -> None: """ Assign a resource from the data repository to a symbol in the DataSHIELD R session. @@ -487,11 +517,12 @@ def assign_resource( :param resource: The default name of the resource to assign :param resources: The name of the resource to assign, per server name. If not defined, 'resource' is used. :param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server) + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. """ # ensure sessions are started and available self.sessions() cmd = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): name = resource if resources and conn.get_name() in resources: name = resources[conn.get_name()] @@ -504,18 +535,19 @@ def assign_resource( self._do_wait(cmd) self._check_errors() - def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> None: + def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True, conn_names: list[str] = None) -> None: """ Assign the result of the evaluation of an expression to a symbol in the DataSHIELD R session. :param symbol: The name of the destination symbol :param expr: The R expression to evaluate and which result will be assigned :param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server) + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. """ # ensure sessions are started and available self.sessions() cmd = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): try: res = conn.assign_expr(symbol, expr, asynchronous) cmd[conn.get_name()] = res @@ -524,20 +556,21 @@ def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> None self._do_wait(cmd) self._check_errors() - def aggregate(self, expr: str, asynchronous: bool = True) -> dict: + def aggregate(self, expr: str, asynchronous: bool = True, conn_names: list[str] = None) -> dict: """ Aggregate some data from the DataSHIELD R session using a valid R expression. The aggregation expression must satisfy the data repository's DataSHIELD configuration. :param expr: The R expression to evaluate and which result will be returned :param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server) + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. :return: The result of the aggregation expression evaluation, per remote server name """ # ensure sessions are started and available self.sessions() cmd = {} rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): try: res = conn.aggregate(expr, asynchronous) cmd[conn.get_name()] = res @@ -573,6 +606,20 @@ def _do_wait(self, cmd: dict) -> dict: time.sleep(0.1) return rval + def _get_selected_connections(self, conn_names: list[str] = None) -> list[DSConnection]: + """ + Get the list of opened connections, optionally filtered by connection names. + + :param conn_names: The optional list of connection names to select. + :return: The list of selected opened connections + """ + if not self.conns: + return [] + if conn_names is None: + return self.conns + selected_names = set(conn_names) + return [conn for conn in self.conns if conn.get_name() in selected_names] + def _init_errors(self) -> None: """ Prepare for storing errors. diff --git a/tests/test_session_filters.py b/tests/test_session_filters.py new file mode 100644 index 0000000..be71e4c --- /dev/null +++ b/tests/test_session_filters.py @@ -0,0 +1,136 @@ +from datashield import DSSession + + +class FakeResult: + def __init__(self, value): + self.value = value + + def is_completed(self) -> bool: + return True + + def fetch(self): + return self.value + + +class FakeConn: + def __init__(self, name: str): + self._name = name + self.started = False + self.disconnected = False + self.saved_workspaces = [] + self.restored_workspaces = [] + self.removed_workspaces = [] + self.rm_symbols = [] + self.assign_expr_calls = [] + self.keep_alive_calls = 0 + + def get_name(self) -> str: + return self._name + + def list_tables(self) -> list: + return [f"{self._name}_table"] + + def has_session(self) -> bool: + return self.started + + def start_session(self, asynchronous: bool = True): + self.started = True + return {"started": True, "async": asynchronous} + + def is_session_started(self) -> bool: + return self.started + + def get_session(self): + return {"name": self._name} + + def list_symbols(self) -> list: + return [f"{self._name}_symbol"] + + def rm_symbol(self, name: str) -> None: + self.rm_symbols.append(name) + + def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> FakeResult: + self.assign_expr_calls.append((symbol, expr, asynchronous)) + return FakeResult({"symbol": symbol, "expr": expr, "conn": self._name}) + + def aggregate(self, expr: str, asynchronous: bool = True) -> FakeResult: + return FakeResult({"expr": expr, "conn": self._name, "async": asynchronous}) + + def save_workspace(self, name: str) -> list: + self.saved_workspaces.append(name) + return self.saved_workspaces + + def restore_workspace(self, name: str) -> list: + self.restored_workspaces.append(name) + return self.restored_workspaces + + def rm_workspace(self, name: str) -> list: + self.removed_workspaces.append(name) + return self.removed_workspaces + + def keep_alive(self) -> None: + self.keep_alive_calls += 1 + + def disconnect(self) -> None: + self.disconnected = True + + +def make_session() -> tuple[DSSession, FakeConn, FakeConn]: + conn1 = FakeConn("server1") + conn2 = FakeConn("server2") + session = DSSession([]) + session.conns = [conn1, conn2] + session.errors = {} + return session, conn1, conn2 + + +def test_tables_filters_connections(): + session, _, _ = make_session() + + result = session.tables(conn_names=["server2", "unknown"]) + + assert result == {"server2": ["server2_table"]} + + +def test_assign_expr_filters_connections(): + session, conn1, conn2 = make_session() + + session.assign_expr("x", "1+1", conn_names=["server1", "unknown"]) + + assert conn1.assign_expr_calls == [("x", "1+1", True)] + assert conn2.assign_expr_calls == [] + + +def test_aggregate_filters_connections(): + session, conn1, conn2 = make_session() + + result = session.aggregate("2+2", conn_names=["server2"]) + + assert result == {"server2": {"expr": "2+2", "conn": "server2", "async": True}} + assert conn1.assign_expr_calls == [] + assert conn2.assign_expr_calls == [] + + +def test_workspace_methods_filter_connections(): + session, conn1, conn2 = make_session() + + session.workspace_save("wk", conn_names=["server1"]) + session.workspace_restore("wk", conn_names=["server2"]) + session.workspace_rm("wk", conn_names=["server2", "missing"]) + + assert conn1.saved_workspaces == ["server1:wk"] + assert conn2.saved_workspaces == [] + assert conn1.restored_workspaces == [] + assert conn2.restored_workspaces == ["server2:wk"] + assert conn1.removed_workspaces == [] + assert conn2.removed_workspaces == ["server2:wk"] + + +def test_close_filters_connections_and_keeps_others_open(): + session, conn1, conn2 = make_session() + + session.close(conn_names=["server1", "unknown"]) + + assert conn1.disconnected is True + assert conn2.disconnected is False + assert session.get_connection_names() == ["server2"]