diff --git a/docs/source/conf.py b/docs/source/conf.py index 2486799..ded3bcd 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -7,10 +7,10 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = 'lyscripts' -copyright = '2024, Roman Ludwig' -author = 'Roman Ludwig' -gh_username = 'rmnldwg' +project = "lyscripts" +copyright = "2024, Roman Ludwig" +author = "Roman Ludwig" +gh_username = "rmnldwg" version = lyscripts.__version__ release = lyscripts.__version__ @@ -18,36 +18,36 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration extensions = [ - 'sphinx.ext.intersphinx', - 'sphinx.ext.autodoc', - 'sphinx.ext.mathjax', - 'sphinx.ext.viewcode', - 'sphinxcontrib.programoutput', - 'sphinx.ext.napoleon', - 'm2r', + "sphinx.ext.intersphinx", + "sphinx.ext.autodoc", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", + "sphinxcontrib.programoutput", + "sphinx.ext.napoleon", + "m2r", ] # markdown to reST -source_suffix = ['.rst', '.md'] +source_suffix = [".rst", ".md"] -templates_path = ['_templates'] +templates_path = ["_templates"] exclude_patterns = [] # document classes and their constructors -autoclass_content = 'class' +autoclass_content = "class" # sort members by source -autodoc_member_order = 'bysource' +autodoc_member_order = "bysource" # show type hints -autodoc_typehints = 'signature' +autodoc_typehints = "signature" # create links to other projects intersphinx_mapping = { - 'python': ('https://docs.python.org/3.10', None), - 'lymph': ('https://lymph-model.readthedocs.io/en/latest/', None), - 'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None), - 'numpy': ('https://numpy.org/doc/stable/', None), + "python": ("https://docs.python.org/3.10", None), + "lymph": ("https://lymph-model.readthedocs.io/en/latest/", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), + "numpy": ("https://numpy.org/doc/stable/", None), } @@ -57,7 +57,7 @@ # a list of builtin themes. # -html_theme = 'sphinx_book_theme' +html_theme = "sphinx_book_theme" html_theme_options = { "repository_url": f"https://github.com/{gh_username}/{project}", "repository_branch": "main", @@ -76,7 +76,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['./_static'] +html_static_path = ["./_static"] html_css_files = [ "css/custom.css", ] diff --git a/lyscripts/__init__.py b/lyscripts/__init__.py index 0fe8e89..53e9276 100644 --- a/lyscripts/__init__.py +++ b/lyscripts/__init__.py @@ -1,10 +1,10 @@ -""" -This is the top-level module of the `lyscripts` package. It contains the +"""This is the top-level module of the `lyscripts` package. It contains the :py:func:`.main` function that is used to start the command line interface (CLI) for the package. Also, it configures the logging system and sets the metadate of the package. """ + import argparse import logging import re @@ -13,7 +13,16 @@ import rich from rich_argparse import RichHelpFormatter -from lyscripts import app, compute, data, evaluate, plot, sample, temp_schedule +from lyscripts import ( + app, + compute, + data, + evaluate, + mixture_fit, + mixture_sample, + plot, + temp_schedule, +) from lyscripts._version import version from lyscripts.utils import CustomRichHandler, console @@ -41,6 +50,7 @@ class RichDefaultHelpFormatter( .. _rich: https://rich.readthedocs.io/en/stable/introduction.html """ + def _rich_fill_text( self, text: rich.text.Text, @@ -67,17 +77,11 @@ def _rich_fill_text( RichDefaultHelpFormatter.styles["argparse.syntax"] = "red" RichDefaultHelpFormatter.styles["argparse.formula"] = "green" -RichDefaultHelpFormatter.highlights.append( - r"\$(?P[^$]*)\$" -) +RichDefaultHelpFormatter.highlights.append(r"\$(?P[^$]*)\$") RichDefaultHelpFormatter.styles["argparse.bold"] = "bold" -RichDefaultHelpFormatter.highlights.append( - r"\*(?P[^*]*)\*" -) +RichDefaultHelpFormatter.highlights.append(r"\*(?P[^*]*)\*") RichDefaultHelpFormatter.styles["argparse.italic"] = "italic" -RichDefaultHelpFormatter.highlights.append( - r"_(?P[^_]*)_" -) +RichDefaultHelpFormatter.highlights.append(r"_(?P[^_]*)_") def exit_cli(args: argparse.Namespace): @@ -97,11 +101,11 @@ def main(): ) parser.set_defaults(run_main=exit_cli) parser.add_argument( - "-v", "--version", action="store_true", - help="Display the version of lyscripts" + "-v", "--version", action="store_true", help="Display the version of lyscripts" ) parser.add_argument( - "--log-level", default="INFO", + "--log-level", + default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], ) @@ -114,8 +118,9 @@ def main(): data._add_parser(subparsers, help_formatter=parser.formatter_class) evaluate._add_parser(subparsers, help_formatter=parser.formatter_class) plot._add_parser(subparsers, help_formatter=parser.formatter_class) - sample._add_parser(subparsers, help_formatter=parser.formatter_class) + mixture_fit._add_parser(subparsers, help_formatter=parser.formatter_class) temp_schedule._add_parser(subparsers, help_formatter=parser.formatter_class) + mixture_sample._add_parser(subparsers, help_formatter=parser.formatter_class) args = parser.parse_args() diff --git a/lyscripts/__main__.py b/lyscripts/__main__.py index d9ba03c..1bc2e10 100644 --- a/lyscripts/__main__.py +++ b/lyscripts/__main__.py @@ -1,7 +1,7 @@ -""" -Utility for performing common tasks w.r.t. the inference and prediction tasks one +"""Utility for performing common tasks w.r.t. the inference and prediction tasks one can use the `lymph` package for. """ + from lyscripts import main # I need another __main__ guard here, because otherwise pdoc tries to run this diff --git a/lyscripts/app/__init__.py b/lyscripts/app/__init__.py index 7c3db75..3ed505e 100644 --- a/lyscripts/app/__init__.py +++ b/lyscripts/app/__init__.py @@ -1,8 +1,8 @@ -""" -Module containing scripts to run different `streamlit`_ applications. +"""Module containing scripts to run different `streamlit`_ applications. .. _streamlit: https://streamlit.io/ """ + import argparse from pathlib import Path diff --git a/lyscripts/app/prevalence.py b/lyscripts/app/prevalence.py index 2ee0c4d..83a9b0c 100644 --- a/lyscripts/app/prevalence.py +++ b/lyscripts/app/prevalence.py @@ -1,5 +1,4 @@ -""" -A `streamlit`_ app for computing, displaying, and reproducing prevalence estimates. +"""A `streamlit`_ app for computing, displaying, and reproducing prevalence estimates. The primary goal with this little GUI is that one can quickly draft some data & prediction comparisons visually and then copy & paste the configuration in YAML format @@ -7,6 +6,7 @@ .. _streamlit: https://streamlit.io/ """ + import argparse import sys from pathlib import Path @@ -51,10 +51,7 @@ def _add_arguments(parser: argparse.ArgumentParser): .. _streamlit: https://streamlit.io/ """ - parser.add_argument( - "--message", type=str, - help="Print our this little message." - ) + parser.add_argument("--message", type=str, help="Print our this little message.") parser.set_defaults(run_main=launch_streamlit) @@ -122,7 +119,7 @@ def interactive_load(streamlit): type=["csv"], help="CSV spreadsheet containing lymphatic patterns of progression", ) - header_rows = [0,1] if is_unilateral else [0,1,2] + header_rows = [0, 1] if is_unilateral else [0, 1, 2] patient_data = load_patient_data(data_file, header=header_rows) streamlit.write("---") @@ -130,7 +127,7 @@ def interactive_load(streamlit): samples_file = streamlit.file_uploader( label="HDF5 sample file", type=["hdf5", "hdf", "h5"], - help="HDF5 file containing the samples." + help="HDF5 file containing the samples.", ) samples = load_model_samples(samples_file) @@ -138,10 +135,7 @@ def interactive_load(streamlit): def interactive_pattern( - streamlit, - is_unilateral: bool, - lnls: list[str], - side: str + streamlit, is_unilateral: bool, lnls: list[str], side: str ) -> dict[str, bool]: """Create a `streamlit`_ panel for specifying an involvement pattern. @@ -179,7 +173,7 @@ def interactive_additional_params( The respective controls are presented next to each other in three dedicated columns. """ - control_cols = streamlit.columns([1,2,1,1,1]) + control_cols = streamlit.columns([1, 2, 1, 1, 1]) t_stage = control_cols[0].selectbox( label="T-category", options=model.diag_time_dists.keys(), @@ -270,15 +264,16 @@ def add_current_scenario( session_state["contents"].append(beta_posterior) session_state["contents"].append(histogram) - session_state["scenarios"].append({ - "pattern": reduce_pattern(pattern), **prevs_kwargs - }) + session_state["scenarios"].append( + {"pattern": reduce_pattern(pattern), **prevs_kwargs} + ) def main(args: argparse.Namespace): """The main function that contains the `streamlit`_ code and functionality. - .. _streamlit: https://streamlit.io/""" + .. _streamlit: https://streamlit.io/ + """ import streamlit as st st.title("Prevalence") @@ -335,7 +330,7 @@ def main(args: argparse.Namespace): ) fig, ax = plt.subplots() - draw(axes=ax, contents=st.session_state.get("contents", []), xlims=(0., 100.)) + draw(axes=ax, contents=st.session_state.get("contents", []), xlims=(0.0, 100.0)) ax.legend() st.pyplot(fig) diff --git a/lyscripts/compute/__init__.py b/lyscripts/compute/__init__.py index ccc12e9..29a678f 100644 --- a/lyscripts/compute/__init__.py +++ b/lyscripts/compute/__init__.py @@ -1,8 +1,8 @@ -""" -With the commands of this module, a user may compute prior and posterior state +"""With the commands of this module, a user may compute prior and posterior state distributions from drawn samples of a model. This can in turn speed up the computation of risks and prevalences. """ + import argparse from pathlib import Path diff --git a/lyscripts/compute/prevalences.py b/lyscripts/compute/prevalences.py index 913b26c..c6179ac 100644 --- a/lyscripts/compute/prevalences.py +++ b/lyscripts/compute/prevalences.py @@ -15,10 +15,13 @@ observed involvement pattern. Warning: +------- The command skips the computation of the priors if it finds them in the cache. But this cache only accounts for the scenario, *NOT* the samples. So, if the samples change, you need to force a recomputation of the priors (e.g., by deleting them). + """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -35,7 +38,6 @@ from lyscripts import utils from lyscripts.compute.priors import compute_priors_using_cache from lyscripts.compute.utils import HDF5FileCache, get_modality_subset -from lyscripts.data import accessor # nopycln: import from lyscripts.scenario import Scenario, add_scenario_arguments logger = logging.getLogger(__name__) @@ -58,41 +60,45 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments needed to run this script to a ``subparsers`` instance.""" parser.add_argument( - "--priors", type=Path, required=True, + "--priors", + type=Path, + required=True, help=( "Path to the prior state distributions (HDF5 file). If samples are " "provided, this will be used as output to store the computed posteriors. " "If no samples are provided, this will be used as input to load the priors." - ) + ), ) parser.add_argument( - "--prevalences", type=Path, required=True, - help="Path to the HDF5 file for storing the computed prevalences." + "--prevalences", + type=Path, + required=True, + help="Path to the HDF5 file for storing the computed prevalences.", ) parser.add_argument( - "--data", type=Path, required=False, - help="Path to the patient data (CSV file)." + "--data", type=Path, required=False, help="Path to the patient data (CSV file)." ) parser.add_argument( - "--params", default="./params.yaml", type=Path, - help="Path to parameter file defining the model (YAML)." + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file defining the model (YAML).", ) parser.add_argument( - "--scenarios", type=Path, required=False, + "--scenarios", + type=Path, + required=False, help=( "Path to a YAML file containing a `scenarios` key with a list of " "involvement scenarios to compute the posteriors for." - ) + ), ) add_scenario_arguments(parser, for_comp="prevalences") parser.set_defaults(run_main=main) -def does_midext_match( - data: pd.DataFrame, - midext: bool | None = None -) -> pd.Index: +def does_midext_match(data: pd.DataFrame, midext: bool | None = None) -> pd.Index: """Return indices of ``data`` where ``midline_ext`` of the patients matches.""" midext_col = data["tumor", "1", "extension"] if midext is None: @@ -112,8 +118,10 @@ def compute_observed_prevalence( T-stages defined in the ``scenario``. Warning: + ------- When computing prevalences for unilateral models, the contralateral diagnosis will still be considered for computing the prevalence in the *data*. + """ # when looking at the data, we always consider both sides is_uni = scenario.is_uni @@ -147,7 +155,7 @@ def observe_prevalence_using_cache( data: pd.DataFrame, scenario: Scenario, cache: HDF5FileCache, - mapping: dict[int, Any] | Callable[[int], Any] = None + mapping: dict[int, Any] | Callable[[int], Any] = None, ): """Compute and cache the observed prevalence for a given ``scenario``.""" num_match, num_total = compute_observed_prevalence( @@ -239,12 +247,14 @@ def main(args: argparse.Namespace): if args.scenarios is None: # create a single scenario from the stdin arguments... - scenarios = [Scenario.from_namespace( - namespace=args, - lnls=lnls, - is_uni=isinstance(model, models.Unilateral), - side=params["model"].get("side", "ipsi"), - )] + scenarios = [ + Scenario.from_namespace( + namespace=args, + lnls=lnls, + is_uni=isinstance(model, models.Unilateral), + side=params["model"].get("side", "ipsi"), + ) + ] num_scens = len(scenarios) else: # ...or load the scenarios from a YAML file diff --git a/lyscripts/compute/priors.py b/lyscripts/compute/priors.py index 4f8283e..59e98be 100644 --- a/lyscripts/compute/priors.py +++ b/lyscripts/compute/priors.py @@ -1,5 +1,4 @@ -""" -Given samples drawn during an MCMC round, compute the (prior) state distribution for +"""Given samples drawn during an MCMC round, compute the (prior) state distribution for each sample. This may then later on be used to compute risks and prevalences more quickly. @@ -8,6 +7,7 @@ distribution that was used to marginalize over them, as well as the model's computation mode (hidden Markov model or Bayesian network). """ + import argparse import logging from pathlib import Path @@ -40,23 +40,31 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments needed to run this script to a ``subparsers`` instance.""" parser.add_argument( - "--samples", type=Path, required=True, - help="Path to the drawn samples (HDF5 file)." + "--samples", + type=Path, + required=True, + help="Path to the drawn samples (HDF5 file).", ) parser.add_argument( - "--priors", type=Path, required=True, - help="Path to file for storing the computed prior distributions." + "--priors", + type=Path, + required=True, + help="Path to file for storing the computed prior distributions.", ) parser.add_argument( - "--params", type=Path, required=True, - help="Path to parameter file defining the model (YAML)." + "--params", + type=Path, + required=True, + help="Path to parameter file defining the model (YAML).", ) parser.add_argument( - "--scenarios", type=Path, required=False, + "--scenarios", + type=Path, + required=False, help=( "Path to a YAML file containing a `scenarios` key with a list of " "diagnosis scenarios to compute the posteriors for." - ) + ), ) add_scenario_arguments(parser, for_comp="priors") @@ -98,10 +106,12 @@ def compute_priors_using_cache( total=len(samples), ): model.set_params(*sample) - priors.append(sum( - model.state_dist(t_stage=t, mode=scenario.mode) * p - for t, p in zip(scenario.t_stages, scenario.t_stages_dist) - )) + priors.append( + sum( + model.state_dist(t_stage=t, mode=scenario.mode) * p + for t, p in zip(scenario.t_stages, scenario.t_stages_dist, strict=False) + ) + ) priors = np.stack(priors) cache[priors_hash] = (priors, scenario.as_dict("priors")) @@ -109,7 +119,7 @@ def compute_priors_using_cache( def main(args: argparse.Namespace): - """compute the prior state distribution for each sample.""" + """Compute the prior state distribution for each sample.""" params = utils.load_yaml_params(args.params) if args.scenarios is None: diff --git a/lyscripts/compute/risks.py b/lyscripts/compute/risks.py index 4e6cf54..d5b4cce 100644 --- a/lyscripts/compute/risks.py +++ b/lyscripts/compute/risks.py @@ -1,10 +1,10 @@ -""" -Predict risks of involvements using the posteriors that were computed using the +"""Predict risks of involvements using the posteriors that were computed using the :py:mod:`.compute.posteriors` command. The structure of these scenarios is similar to how scenarios are defined for the :py:mod:`.compute.prevalences` script. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -39,24 +39,32 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments needed to run this script to a `subparsers` instance.""" parser.add_argument( - "--posteriors", type=Path, required=True, - help="Path to the computed posteriors (HDF5 file)." + "--posteriors", + type=Path, + required=True, + help="Path to the computed posteriors (HDF5 file).", ) parser.add_argument( - "--risks", type=Path, required=True, - help="Path to file for storing the computed risks." + "--risks", + type=Path, + required=True, + help="Path to file for storing the computed risks.", ) parser.add_argument( - "--params", type=Path, required=True, - help="Path to parameter file defining the model (YAML)." + "--params", + type=Path, + required=True, + help="Path to parameter file defining the model (YAML).", ) parser.add_argument( - "--scenarios", type=Path, required=False, + "--scenarios", + type=Path, + required=False, help=( "Path to a YAML file containing a `scenarios` key with a list of " "diagnosis scenarios and involvement patterns of interest to compute the " "risks for." - ) + ), ) add_scenario_arguments(parser, for_comp="risks") @@ -100,11 +108,13 @@ def compute_risks_using_cache( description="[blue]INFO [/blue]" + progress_desc, total=len(posteriors), ): - risks.append(model.marginalize( - involvement=scenario.involvement, - given_state_dist=posterior, - **kwargs, - )) + risks.append( + model.marginalize( + involvement=scenario.involvement, + given_state_dist=posterior, + **kwargs, + ) + ) risks = np.stack(risks) risks_cache[risks_hash] = (risks, scenario.as_dict("risks")) @@ -119,12 +129,14 @@ def main(args: argparse.Namespace): if args.scenarios is None: # create a single scenario from the stdin arguments... - scenarios = [Scenario.from_namespace( - namespace=args, - lnls=lnls, - is_uni=isinstance(model, models.Unilateral), - side=params["model"].get("side", "ipsi"), - )] + scenarios = [ + Scenario.from_namespace( + namespace=args, + lnls=lnls, + is_uni=isinstance(model, models.Unilateral), + side=params["model"].get("side", "ipsi"), + ) + ] num_scens = len(scenarios) else: # ...or load the scenarios from a YAML file diff --git a/lyscripts/compute/utils.py b/lyscripts/compute/utils.py index c50e1c8..09c5e00 100644 --- a/lyscripts/compute/utils.py +++ b/lyscripts/compute/utils.py @@ -1,6 +1,5 @@ -""" -Utilities for precomputing the priors and posteriors. -""" +"""Utilities for precomputing the priors and posteriors.""" + from pathlib import Path from typing import Any @@ -51,6 +50,7 @@ def get_modality_subset(diagnosis: dict[str, Any]) -> set[str]: class HDF5FileCache: """HDF5 file acting as a cache for expensive arrays.""" + def __init__(self, file_path: Path) -> None: """Initialize the cache with the given ``file_path``.""" file_path.parent.mkdir(parents=True, exist_ok=True) @@ -86,13 +86,14 @@ def reduce_pattern(pattern: dict[str, dict[str, bool]]) -> dict[str, dict[str, b but be shorter to store. Example: - + ------- >>> full = { ... "ipsi": {"I": None, "II": True, "III": None}, ... "contra": {"I": None, "II": None, "III": None}, ... } >>> reduce_pattern(full) {'ipsi': {'II': True}} + """ tmp_pattern = pattern.copy() reduced_pattern = {} @@ -116,10 +117,12 @@ def complete_pattern( contain ``True``, ``False`` or ``None``. Example: + ------- >>> pattern = {"ipsi": {"II": True}} >>> lnls = ["II", "III"] >>> complete_pattern(pattern, lnls) {'ipsi': {'II': True, 'III': None}, 'contra': {'II': None, 'III': None}} + """ if pattern is None: pattern = {} diff --git a/lyscripts/data/__init__.py b/lyscripts/data/__init__.py index 00754db..bf0c240 100644 --- a/lyscripts/data/__init__.py +++ b/lyscripts/data/__init__.py @@ -1,5 +1,4 @@ -""" -Provide a range of commands and functions related to managing CSV datasets on patterns +"""Provide a range of commands and functions related to managing CSV datasets on patterns of lymphatic progression. It helps transform raw CSV data of any form to be converted into our `LyProX`_ format, @@ -7,10 +6,11 @@ .. _LyProX: https://lyprox.org """ + import argparse from pathlib import Path -from lyscripts.data import enhance, filter, generate, join, lyproxify, split +from lyscripts.data import bootstrap, enhance, filter, generate, join, lyproxify, split def _add_parser( @@ -31,3 +31,4 @@ def _add_parser( lyproxify._add_parser(subparsers, help_formatter=parser.formatter_class) split._add_parser(subparsers, help_formatter=parser.formatter_class) filter._add_parser(subparsers, help_formatter=parser.formatter_class) + bootstrap._add_parser(subparsers, help_formatter=parser.formatter_class) diff --git a/lyscripts/data/__main__.py b/lyscripts/data/__main__.py index 0c6f67c..a540dda 100644 --- a/lyscripts/data/__main__.py +++ b/lyscripts/data/__main__.py @@ -1,7 +1,7 @@ import argparse from lyscripts import RichDefaultHelpFormatter, exit_cli -from lyscripts.data import enhance, filter, generate, join, split +from lyscripts.data import bootstrap, enhance, filter, generate, join, split # I need another __main__ guard here, because otherwise pdoc tries to run this if __name__ == "__main__": @@ -20,6 +20,7 @@ join._add_parser(subparsers, help_formatter=parser.formatter_class) split._add_parser(subparsers, help_formatter=parser.formatter_class) filter._add_parser(subparsers, help_formatter=parser.formatter_class) + bootstrap._add_parser(subparsers, help_formatter=parser.formatter_class) args = parser.parse_args() args.run_main(args) diff --git a/lyscripts/data/accessor.py b/lyscripts/data/accessor.py index 40eeb56..1a04cbb 100644 --- a/lyscripts/data/accessor.py +++ b/lyscripts/data/accessor.py @@ -1,8 +1,8 @@ -""" -Create a custom pandas accessor to handle `LyProX`_ style data. +"""Create a custom pandas accessor to handle `LyProX`_ style data. .. _LyProX: https://lyprox.org """ + from collections.abc import Callable from typing import Any @@ -50,6 +50,7 @@ class LyProXAccessor: .. _LyProX: https://lyprox.org """ + def __init__(self, obj: pd.DataFrame) -> None: self._validate(obj) self._obj = obj diff --git a/lyscripts/data/bootstrap.py b/lyscripts/data/bootstrap.py new file mode 100644 index 0000000..c4c6907 --- /dev/null +++ b/lyscripts/data/bootstrap.py @@ -0,0 +1,114 @@ +"""Learn the spread probabilities of the HMM for lymphatic tumor progression using +the preprocessed data as input and the mixture model. +""" + +# pylint: disable=logging-fstring-interpolation +import argparse +import logging +import os +from pathlib import Path + +import numpy as np +import pandas as pd +from sklearn.utils import resample + +from lyscripts.utils import ( + load_patient_data, + load_yaml_params, +) + +logger = logging.getLogger(__name__) + + +def _add_parser( + subparsers: argparse._SubParsersAction, + help_formatter, +): + """Add a parser to the ``subparsers`` action.""" + parser = subparsers.add_parser( + Path(__file__).name.replace(".py", ""), + description=__doc__, + help=__doc__, + formatter_class=help_formatter, + ) + _add_arguments(parser) + + +def _add_arguments(parser: argparse.ArgumentParser): + """Add arguments to the parser.""" + parser.add_argument("--input", type=Path, help="Path to a LyProX-style CSV file") + parser.add_argument( + "--output", type=Path, help="Folder destination to save LyProX-style CSV files" + ) + parser.add_argument( + "-p", + "--params", + default="params.yaml", + type=Path, + help="Path to parameter file", + ) + + parser.set_defaults(run_main=main) + + +def proportional_bootstrap(df, group_col, n_bootstraps, folder_path=None): + """Produce n_bootstraps bootstrapped datasets from the original DataFrame. + this keeps the number of patients per subsite constant. + + Args: + ---- + df (pd.DataFrame): The original DataFrame to bootstrap from. + group_col (str): The name of the column to group by. + n_bootstraps (int): The number of bootstrapped datasets to create. + + Returns: + ------- + list[pd.DataFrame]: A list of bootstrapped DataFrames. + + """ + datasets = [] + group_sizes = df[group_col].value_counts(normalize=True) + total_n = len(df) + os.makedirs(folder_path, exist_ok=True) + for _ in range(n_bootstraps): + samples = [] + for group, proportion in group_sizes.items(): + n_samples = int(np.round(proportion * total_n)) + group_df = df[df[group_col] == group] + boot_group = resample(group_df, replace=True, n_samples=n_samples) + samples.append(boot_group) + boot_df = ( + pd.concat(samples).sample(frac=1).reset_index(drop=True) + ) # optional shuffle + datasets.append(boot_df) + if folder_path is not None: + for i, dataset in enumerate(datasets): + file_path = os.path.join(folder_path, f"dataset_resample_{i}.csv") + dataset.to_csv(file_path, index=False) + else: + return datasets + + +def main(args: argparse.Namespace) -> None: + """Main function to sample parameters for a mixture model""" + input_table = load_patient_data(args.input) + params = load_yaml_params(args.params) + if ("tumor", "core", "subsite") in input_table.columns: + group_col = ("tumor", "core", "subsite") + elif ("tumor", "1", "subsite") in input_table.columns: + group_col = ("tumor", "1", "subsite") + else: + logger.error("No 'subsite' column found in the input data.") + proportional_bootstrap( + input_table, + group_col=group_col, + n_bootstraps=params["sampling"]["n_bootstraps"], + folder_path=args.output, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + _add_arguments(parser) + args = parser.parse_args() + args.run_main(args) diff --git a/lyscripts/data/enhance.py b/lyscripts/data/enhance.py index 0abd858..03dbfd4 100644 --- a/lyscripts/data/enhance.py +++ b/lyscripts/data/enhance.py @@ -1,5 +1,4 @@ -""" -Enhance a LyProX-style CSV dataset in two ways: +"""Enhance a LyProX-style CSV dataset in two ways: 1. Add consensus diagnosis based on all available modalities using on of two methods: ``max_llh`` infers the most likely true state of involvement given only the available @@ -11,6 +10,7 @@ correct values. Conversely, if e.g. LNL II is reported to be healthy, we can assume the sublevels IIa and IIb would have been reported as healthy, too. """ + # pylint: disable=singleton-comparison,logging-fstring-interpolation import argparse import logging @@ -51,35 +51,44 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" + parser.add_argument("input", type=Path, help="Path to a LyProX-style CSV file") parser.add_argument( - "input", type=Path, - help="Path to a LyProX-style CSV file" - ) - parser.add_argument( - "output", type=Path, - help="Destination for LyProX-style output file including the consensus" + "output", + type=Path, + help="Destination for LyProX-style output file including the consensus", ) parser.add_argument( - "-c", "--consensus", nargs="+", default=["max_llh"], + "-c", + "--consensus", + nargs="+", + default=["max_llh"], choices=CONSENSUS_FUNCS.keys(), - help="Choose consensus method(s)" + help="Choose consensus method(s)", ) parser.add_argument( - "-p", "--params", default="params.yaml", type=Path, - help="Path to parameter file" + "-p", + "--params", + default="params.yaml", + type=Path, + help="Path to parameter file", ) parser.add_argument( - "--modalities", nargs="+", + "--modalities", + nargs="+", default=["CT", "MRI", "PET", "FNA", "diagnostic_consensus", "pathology", "pCT"], - help="List of modalities for enhancement. Must be defined in `params.yaml`" + help="List of modalities for enhancement. Must be defined in `params.yaml`", ) parser.add_argument( - "--sublvls", nargs="+", default=["a", "b"], - help="Indicate what kinds of sublevels exist" + "--sublvls", + nargs="+", + default=["a", "b"], + help="Indicate what kinds of sublevels exist", ) parser.add_argument( - "--lnls-with-sub", nargs="+", default=["I", "II", "V"], - help="List of LNLs where sublevel reporting has been performed or is common" + "--lnls-with-sub", + nargs="+", + default=["I", "II", "V"], + help="List of LNLs where sublevel reporting has been performed or is common", ) parser.set_defaults(run_main=main) @@ -90,13 +99,11 @@ def get_sublvl_values( lnl: str, sub_ids: list[str], ): - """ - Get values of sublevels (e.g. 'IIa' and 'IIb') for an ``lnl`` and ``data_frame``. - """ - has_sublvls = all(lnl+sub in data_frame for sub in sub_ids) + """Get values of sublevels (e.g. 'IIa' and 'IIb') for an ``lnl`` and ``data_frame``.""" + has_sublvls = all(lnl + sub in data_frame for sub in sub_ids) if not has_sublvls: return None - return data_frame[[lnl+sub for sub in sub_ids]].values + return data_frame[[lnl + sub for sub in sub_ids]].values @log_state() @@ -124,25 +131,25 @@ def infer_superlvl_from_sublvls( for mod in modalities: for side in ["ipsi", "contra"]: for lnl in lnls_with_sub: - sublvl_values = get_sublvl_values( - table[mod,side], lnl, sublvls - ) + sublvl_values = get_sublvl_values(table[mod, side], lnl, sublvls) if sublvl_values is None: continue # sometimes, the sublevels both report `False` (healthy) but the # superlevel is involved. In this case, we want to keep the superlevel # as involved. - if lnl in table[mod,side]: - is_superlvl_involved = table[mod,side,lnl] == True + if lnl in table[mod, side]: + is_superlvl_involved = table[mod, side, lnl] == True else: is_superlvl_involved = False - has_sublvl_involved = np.any(sublvl_values==True , axis=1) - all_sublvls_healthy = np.all(sublvl_values==False, axis=1) + has_sublvl_involved = np.any(sublvl_values == True, axis=1) + all_sublvls_healthy = np.all(sublvl_values == False, axis=1) - fixed_table.loc[has_sublvl_involved, (mod,side,lnl)] = True - fixed_table.loc[all_sublvls_healthy & ~is_superlvl_involved, (mod,side,lnl)] = False + fixed_table.loc[has_sublvl_involved, (mod, side, lnl)] = True + fixed_table.loc[ + all_sublvls_healthy & ~is_superlvl_involved, (mod, side, lnl) + ] = False return fixed_table @@ -161,7 +168,7 @@ def get_lnl_observations( for mod in modalities.keys(): try: - add_obs = patient[mod,side,lnl] + add_obs = patient[mod, side, lnl] add_obs = None if pd.isna(add_obs) else add_obs except KeyError: add_obs = None @@ -191,40 +198,32 @@ def and_consensus(obs_tuple: tuple[np.ndarray]): if has_all_none(obs_tuple): return None - return not( - any(not(obs) if obs is not None else None for obs in obs_tuple) - ) + return not (any(not (obs) if obs is not None else None for obs in obs_tuple)) @lru_cache -def maxllh_consensus( - obs_tuple: tuple[np.ndarray], - modalities_spsn: tuple[list[float]] -): +def maxllh_consensus(obs_tuple: tuple[np.ndarray], modalities_spsn: tuple[list[float]]): """Compute the maximum likelihood consensus of different diagnostic modalities.""" if has_all_none(obs_tuple): return None - healthy_llh = 1. - involved_llh = 1. - for obs, spsn in zip(obs_tuple, modalities_spsn): + healthy_llh = 1.0 + involved_llh = 1.0 + for obs, spsn in zip(obs_tuple, modalities_spsn, strict=False): if obs is None: continue spsn = np.array(spsn) obs = int(obs) - spsn2x2 = np.diag(spsn) + np.diag(1. - spsn)[::-1] - healthy_llh *= spsn2x2[obs,0] - involved_llh *= spsn2x2[obs,1] + spsn2x2 = np.diag(spsn) + np.diag(1.0 - spsn)[::-1] + healthy_llh *= spsn2x2[obs, 0] + involved_llh *= spsn2x2[obs, 1] healthy_vs_involved = np.array([healthy_llh, involved_llh]) return bool(np.argmax(healthy_vs_involved)) @lru_cache -def rank_consensus( - obs_tuple: tuple[np.ndarray], - modalities_spsn: tuple[list[float]] -): +def rank_consensus(obs_tuple: tuple[np.ndarray], modalities_spsn: tuple[list[float]]): """Compute the ranked consensus of different diagnostic modalities.""" if has_all_none(obs_tuple): return None @@ -232,12 +231,12 @@ def rank_consensus( modalities_spsn = list(modalities_spsn) healthy_sens = [ - modalities_spsn[i][1] for i,obs in enumerate(obs_tuple) if obs == False + modalities_spsn[i][1] for i, obs in enumerate(obs_tuple) if obs == False ] involved_spec = [ - modalities_spsn[i][0] for i,obs in enumerate(obs_tuple) if obs == True + modalities_spsn[i][0] for i, obs in enumerate(obs_tuple) if obs == True ] - if np.max([*healthy_sens, 0.]) > np.max([*involved_spec, 0.]): + if np.max([*healthy_sens, 0.0]) > np.max([*involved_spec, 0.0]): return False return True @@ -261,20 +260,18 @@ def main(args: argparse.Namespace): selection=args.modalities, ) - available_mod_keys = sorted(set( - input_table.columns.get_level_values(0) - ).intersection( - modalities.keys() - )) + available_mod_keys = sorted( + set(input_table.columns.get_level_values(0)).intersection(modalities.keys()) + ) available_mods = {key: modalities[key] for key in available_mod_keys} - lnl_union = sorted(set().union( - *[input_table[mod,"ipsi"].columns for mod in available_mod_keys] - )) + lnl_union = sorted( + set().union(*[input_table[mod, "ipsi"].columns for mod in available_mod_keys]) + ) consensus = pd.DataFrame( index=input_table.index, columns=pd.MultiIndex.from_product( [args.consensus, ["ipsi", "contra"], lnl_union] - ) + ), ) with CustomProgress(console=console) as report_progress: @@ -284,7 +281,7 @@ def main(args: argparse.Namespace): ) for side in ["ipsi", "contra"]: # go through patients and LNLs and compute consensus for each - for p,patient in input_table.iterrows(): + for p, patient in input_table.iterrows(): for lnl in lnl_union: observations = get_lnl_observations( patient, side, lnl, available_mods @@ -296,12 +293,11 @@ def main(args: argparse.Namespace): report_progress.update(enhance_task, advance=1) table_with_consensus = input_table.join(consensus) - - data_modalities = sorted(set( - table_with_consensus.columns.get_level_values(0) - ).intersection( - [*modalities.keys(), *args.consensus] - )) + data_modalities = sorted( + set(table_with_consensus.columns.get_level_values(0)).intersection( + [*modalities.keys(), *args.consensus] + ) + ) consensus_and_fixed_sublvlvs = infer_superlvl_from_sublvls( table_with_consensus, data_modalities, diff --git a/lyscripts/data/filter.py b/lyscripts/data/filter.py index 7eb7cb0..df5b6a7 100644 --- a/lyscripts/data/filter.py +++ b/lyscripts/data/filter.py @@ -1,7 +1,7 @@ -""" -Filter a datset according to some common criteria, like tumor location, subsite, +"""Filter a datset according to some common criteria, like tumor location, subsite, T-category, etc. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -44,19 +44,20 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "input", type=Path, - help="The path to the full dataset to split." + "input", type=Path, help="The path to the full dataset to split." ) parser.add_argument( - "output", type=Path, - help="Folder to store the split CSV files in." + "output", type=Path, help="Folder to store the split CSV files in." ) for prefix in ["include", "exclude"]: for filter_by in ["locations", "subsites", "t_categories"]: parser.add_argument( - f"--{prefix}-{filter_by}", default=None, type=str, nargs="+", - help=f"If provided, {prefix} patients with the given tumor {filter_by}." + f"--{prefix}-{filter_by}", + default=None, + type=str, + nargs="+", + help=f"If provided, {prefix} patients with the given tumor {filter_by}.", ) parser.set_defaults(run_main=main) diff --git a/lyscripts/data/generate.py b/lyscripts/data/generate.py index 36c5054..2ce6109 100644 --- a/lyscripts/data/generate.py +++ b/lyscripts/data/generate.py @@ -1,8 +1,8 @@ -""" -Calls the synthetic data generating methods of the `lymph`_ package models. +"""Calls the synthetic data generating methods of the `lymph`_ package models. .. _lymph: https://lymph-model.readthedocs.io """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -34,32 +34,42 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "num", type=int, + "num", + type=int, help="Number of synthetic patient records to generate", ) parser.add_argument( - "output", type=Path, + "output", + type=Path, help="Path where to store the generated synthetic data", ) parser.add_argument( - "--params", default="./params.yaml", type=Path, - help="Parameter file containing model specifications" + "--params", + default="./params.yaml", + type=Path, + help="Parameter file containing model specifications", ) group = parser.add_mutually_exclusive_group() group.add_argument( - "--set-theta", nargs="+", type=float, - help="Set the spread probs and parameters for time marginalization by hand" + "--set-theta", + nargs="+", + type=float, + help="Set the spread probs and parameters for time marginalization by hand", ) group.add_argument( - "--load-theta", choices=["mean", "max_llh"], default="mean", - help="Use either the mean or the maximum likelihood estimate from drawn samples" + "--load-theta", + choices=["mean", "max_llh"], + default="mean", + help="Use either the mean or the maximum likelihood estimate from drawn samples", ) parser.add_argument( - "--samples", default="./models/samples.hdf5", type=Path, - help="Path to the samples if a method to load them was chosen" + "--samples", + default="./models/samples.hdf5", + type=Path, + help="Path to the samples if a method to load them was chosen", ) parser.set_defaults(run_main=main) @@ -82,11 +92,7 @@ def main(args: argparse.Namespace): logger.info("Assigned given parameters to model") else: - backend = emcee.backends.HDFBackend( - args.samples, - read_only=True, - name="mcmc" - ) + backend = emcee.backends.HDFBackend(args.samples, read_only=True, name="mcmc") chain = backend.get_chain(flat=True) log_probs = backend.get_blobs(flat=True) diff --git a/lyscripts/data/join.py b/lyscripts/data/join.py index ef91e81..7c40d0d 100644 --- a/lyscripts/data/join.py +++ b/lyscripts/data/join.py @@ -1,6 +1,5 @@ -""" -Join datasets from different sources (but of the same format) into one. -""" +"""Join datasets from different sources (but of the same format) into one.""" + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -35,12 +34,19 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "-i", "--inputs", nargs='+', type=Path, required=True, - help="List of paths to inference-ready CSV datasets to concatenate." + "-i", + "--inputs", + nargs="+", + type=Path, + required=True, + help="List of paths to inference-ready CSV datasets to concatenate.", ) parser.add_argument( - "-o", "--output", type=Path, required=True, - help="Location to store the concatenated CSV file." + "-o", + "--output", + type=Path, + required=True, + help="Location to store the concatenated CSV file.", ) parser.set_defaults(run_main=main) @@ -52,9 +58,7 @@ def load_and_join_tables(input_paths: list[Path]): for path in input_paths: input_table = load_patient_data(path).convert_dtypes() concatenated_table = pd.concat( - [concatenated_table, input_table], - axis="index", - ignore_index=True + [concatenated_table, input_table], axis="index", ignore_index=True ) logger.info(f"+ concatenated data from {path}") return concatenated_table diff --git a/lyscripts/data/lyproxify.py b/lyscripts/data/lyproxify.py index ad11b46..8d0a17e 100644 --- a/lyscripts/data/lyproxify.py +++ b/lyscripts/data/lyproxify.py @@ -1,5 +1,4 @@ -""" -Consumes raw data and transforms it into a CSV of the format that `LyProX`_ understands. +"""Consumes raw data and transforms it into a CSV of the format that `LyProX`_ understands. To do so, it needs a dictionary that defines a mapping from raw columns to the LyProX style data format. See the documentation of the :py:func:`.transform_to_lyprox` function @@ -7,6 +6,7 @@ .. _LyProX: https://lyprox.org """ + # pylint: disable=logging-fstring-interpolation import argparse import importlib.util @@ -44,38 +44,54 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "-i", "--input", type=Path, required=True, - help="Location of raw CSV data." + "-i", "--input", type=Path, required=True, help="Location of raw CSV data." ) parser.add_argument( - "-r", "--header-rows", nargs="+", default=[0], type=int, - help="List with header row indices of raw file." + "-r", + "--header-rows", + nargs="+", + default=[0], + type=int, + help="List with header row indices of raw file.", ) parser.add_argument( - "-o", "--output", type=Path, required=True, - help="Location to store the lyproxified CSV file." + "-o", + "--output", + type=Path, + required=True, + help="Location to store the lyproxified CSV file.", ) parser.add_argument( - "-m", "--mapping", type=Path, required=True, + "-m", + "--mapping", + type=Path, + required=True, help=( "Location of the Python file that contains column mapping instructions. " "This must contain a dictionary with the name 'column_map'." - ) + ), ) parser.add_argument( - "--drop-rows", nargs="+", type=int, default=[], + "--drop-rows", + nargs="+", + type=int, + default=[], help=( "Delete rows of specified indices. Counting of rows start at 0 _after_ " "the `header-rows`." - ) + ), ) parser.add_argument( - "--drop-cols", nargs="+", type=int, default=[], + "--drop-cols", + nargs="+", + type=int, + default=[], help="Delete columns of specified indices.", ) parser.add_argument( - "--add-index", action="store_true", - help="If the data doesn't contain an index, add one by enumerating the patients" + "--add-index", + action="store_true", + help="If the data doesn't contain an index, add one by enumerating the patients", ) parser.set_defaults(run_main=main) @@ -170,8 +186,7 @@ def generate_markdown_docs( @log_state() def transform_to_lyprox( - raw: pd.DataFrame, - column_map: dict[tuple, dict[str, Any]] + raw: pd.DataFrame, column_map: dict[tuple, dict[str, Any]] ) -> pd.DataFrame: """Transform ``raw`` data frame into table that can be uploaded directly to `LyProX`_. @@ -249,21 +264,15 @@ def leftright_to_ipsicontra(data: pd.DataFrame): involvement. """ len_before = len(data) - left_data = data.loc[ - data["tumor", "1", "side"] != "right" - ] - right_data = data.loc[ - data["tumor", "1", "side"] == "right" - ] + left_data = data.loc[data["tumor", "1", "side"] != "right"] + right_data = data.loc[data["tumor", "1", "side"] == "right"] left_data = left_data.rename(columns={"left": "ipsi"}, level=1) left_data = left_data.rename(columns={"right": "contra"}, level=1) right_data = right_data.rename(columns={"left": "contra"}, level=1) right_data = right_data.rename(columns={"right": "ipsi"}, level=1) - data = pd.concat( - [left_data, right_data], ignore_index=True - ) + data = pd.concat([left_data, right_data], ignore_index=True) assert len_before == len(data), "Number of patients changed" return data @@ -295,7 +304,9 @@ def exclude_patients(raw: pd.DataFrame, exclude: list[tuple[str, Any]]): def main(args: argparse.Namespace): """Run the lyproxify main function.""" raw: pd.DataFrame = load_patient_data(args.input, header=args.header_rows) - raw = clean_header(raw, num_cols=raw.shape[1], num_header_rows=len(args.header_rows)) + raw = clean_header( + raw, num_cols=raw.shape[1], num_header_rows=len(args.header_rows) + ) cols_to_drop = raw.columns[args.drop_cols] trimmed = raw.drop(cols_to_drop, axis="columns") diff --git a/lyscripts/data/split.py b/lyscripts/data/split.py index 0f224ea..a7206d1 100644 --- a/lyscripts/data/split.py +++ b/lyscripts/data/split.py @@ -1,7 +1,7 @@ -""" -Split the full dataset into cross-validation folds according to the +"""Split the full dataset into cross-validation folds according to the content of the params.yaml file. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -34,17 +34,18 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "input", type=Path, - help="The path to the full dataset to split." + "input", type=Path, help="The path to the full dataset to split." ) parser.add_argument( - "output", type=Path, - help="Folder to store the split CSV files in." + "output", type=Path, help="Folder to store the split CSV files in." ) parser.add_argument( - "-p", "--params", default="./params.yaml", type=Path, - help="Path to parameter YAML file." + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter YAML file.", ) parser.set_defaults(run_main=main) @@ -58,25 +59,20 @@ def main(args: argparse.Namespace): args.output.mkdir(exist_ok=True) shuffled_df = concatenated_df.sample( - frac=1., - replace=False, - random_state=params["cross-validation"]["seed"] + frac=1.0, replace=False, random_state=params["cross-validation"]["seed"] ).reset_index(drop=True) - split_dfs = np.array_split( - shuffled_df, - len(params["cross-validation"]["folds"]) - ) + split_dfs = np.array_split(shuffled_df, len(params["cross-validation"]["folds"])) for fold_id, split_pattern in params["cross-validation"]["folds"].items(): # Concatenate training and evaluation DataFrames from the split DataFrames # using the split pattern defined in the params file. eval_df = pd.concat( - [split_dfs[k] for k,sym in enumerate(split_pattern) if sym == "e"], - ignore_index=True + [split_dfs[k] for k, sym in enumerate(split_pattern) if sym == "e"], + ignore_index=True, ) train_df = pd.concat( - [split_dfs[k] for k,sym in enumerate(split_pattern) if sym == "t"], - ignore_index=True + [split_dfs[k] for k, sym in enumerate(split_pattern) if sym == "t"], + ignore_index=True, ) eval_df.to_csv(args.output / f"{fold_id}_eval.csv", index=None) diff --git a/lyscripts/data/utils.py b/lyscripts/data/utils.py index a1e724c..32803b7 100644 --- a/lyscripts/data/utils.py +++ b/lyscripts/data/utils.py @@ -1,6 +1,5 @@ -""" -Utilities related to the commands for data cleaning and processing. -""" +"""Utilities related to the commands for data cleaning and processing.""" + from pathlib import Path import pandas as pd diff --git a/lyscripts/decorators.py b/lyscripts/decorators.py index bf6f292..11ed218 100644 --- a/lyscripts/decorators.py +++ b/lyscripts/decorators.py @@ -1,9 +1,9 @@ -""" -This module provides decorators that can be used to avoid repetitive snippets of code, +"""This module provides decorators that can be used to avoid repetitive snippets of code, e.g. safely opening files or logging the state of a function call. This is *not* a command line tool. """ + import functools import logging from collections.abc import Callable @@ -66,10 +66,12 @@ def log_state(log_level: int = logging.INFO) -> Callable: The log message will simply be the function name where underscores are replaced with spaces. The `log_level` can be set in the decorator call. """ + # pylint: disable=logging-fstring-interpolation # pylint: disable=logging-not-lazy def log_decorator(func: Callable): """The decorator wrapping the decorated function.""" + @functools.wraps(func) def wrapper(*args, **kwargs): """The wrapper around the decorated function.""" @@ -101,6 +103,7 @@ def wrapper(*args, **kwargs): def check_input_file_exists(loading_func: Callable) -> Callable: """Check if the file path provided to the `loading_func` exists.""" + @wraps(loading_func) def inner(file_path: str, *args, **kwargs) -> Any: """Wrapped loading function.""" @@ -115,6 +118,7 @@ def inner(file_path: str, *args, **kwargs) -> Any: def check_output_dir_exists(saving_func: Callable) -> Callable: """Make sure the parent directory of the saved file exists.""" + @wraps(saving_func) def inner(file_path: str, *args, **kwargs) -> Any: """Wrapped saving function.""" diff --git a/lyscripts/evaluate.py b/lyscripts/evaluate.py index bd1fbec..1fb8522 100644 --- a/lyscripts/evaluate.py +++ b/lyscripts/evaluate.py @@ -1,8 +1,8 @@ -""" -Evaluate the performance of the trained model by computing quantities like the +"""Evaluate the performance of the trained model by computing quantities like the Bayesian information criterion (BIC) or (if thermodynamic integration was performed) the actual evidence (with error) of the model. """ + # pylint: disable=logging-fstring-interpolation import argparse import json @@ -40,37 +40,29 @@ def _add_arguments(parser: argparse.ArgumentParser): This is called by the parent module that is called via the command line. """ parser.add_argument( - "data", type=Path, - help="Path to the tables of patient data (CSV)." - ) - parser.add_argument( - "model", type=Path, - help="Path to model output files (HDF5)." + "data", type=Path, help="Path to the tables of patient data (CSV)." ) + parser.add_argument("model", type=Path, help="Path to model output files (HDF5).") parser.add_argument( - "-p", "--params", default="./params.yaml", type=Path, - help="Path to parameter file" + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file", ) parser.add_argument( - "--plots", default="./plots", type=Path, - help="Directory for storing plots" + "--plots", default="./plots", type=Path, help="Directory for storing plots" ) parser.add_argument( - "--metrics", default="./metrics.json", type=Path, - help="Path to metrics file" + "--metrics", default="./metrics.json", type=Path, help="Path to metrics file" ) parser.set_defaults(run_main=main) -def comp_bic( - log_probs: np.ndarray, - num_params: int, - num_data: int -) -> float: - """ - Compute the negative one half of the Bayesian Information Criterion (BIC). +def comp_bic(log_probs: np.ndarray, num_params: int, num_data: int) -> float: + """Compute the negative one half of the Bayesian Information Criterion (BIC). The BIC is defined as [^1] $$ BIC = k \\ln{n} - 2 \\ln{\\hat{L}} $$ @@ -83,7 +75,8 @@ def comp_bic( [^1]: https://en.wikipedia.org/wiki/Bayesian_information_criterion """ - return np.max(log_probs) - num_params * np.log(num_data) / 2. + return np.max(log_probs) - num_params * np.log(num_data) / 2.0 + def compute_evidence( temp_schedule: np.ndarray, @@ -102,7 +95,7 @@ def compute_evidence( integrals = np.zeros(shape=num) for i in range(num): rand_idx = np.random.choice(log_probs.shape[1], size=log_probs.shape[0]) - drawn_accuracy = log_probs[np.arange(log_probs.shape[0]),rand_idx].copy() + drawn_accuracy = log_probs[np.arange(log_probs.shape[0]), rand_idx].copy() integrals[i] = trapezoid(y=drawn_accuracy, x=temp_schedule) return np.mean(integrals), np.std(integrals) @@ -167,17 +160,18 @@ def main(args: argparse.Namespace): args.plots.parent.mkdir(exist_ok=True) beta_vs_accuracy = pd.DataFrame( - np.array([ - temp_schedule, - np.mean(ti_log_probs, axis=1), - np.std(ti_log_probs, axis=1) - ]).T, + np.array( + [ + temp_schedule, + np.mean(ti_log_probs, axis=1), + np.std(ti_log_probs, axis=1), + ] + ).T, columns=["β", "accuracy", "std"], ) beta_vs_accuracy.to_csv(args.plots, index=False) logger.info(f"Plotted β vs accuracy at {args.plots}") - # use blobs, because also for TI, this is the unscaled log-prob backend = emcee.backends.HDFBackend(args.model, read_only=True, name="mcmc") final_log_probs = backend.get_blobs() @@ -188,7 +182,9 @@ def main(args: argparse.Namespace): args.metrics.touch(exist_ok=True) metrics["BIC"] = comp_bic( - final_log_probs, ndim, len(data), + final_log_probs, + ndim, + len(data), ) metrics["max_llh"] = np.max(final_log_probs) metrics["mean_llh"] = np.mean(final_log_probs) diff --git a/lyscripts/mixture_fit.py b/lyscripts/mixture_fit.py new file mode 100644 index 0000000..7dd30c3 --- /dev/null +++ b/lyscripts/mixture_fit.py @@ -0,0 +1,316 @@ +"""Learn the spread probabilities of the HMM for lymphatic tumor progression using +the preprocessed data as input and the mixture model. +""" + +# pylint: disable=logging-fstring-interpolation +import argparse +import logging +import os +from concurrent.futures import ProcessPoolExecutor +from pathlib import Path + +import numpy as np +import pandas as pd +from lymixture.em import expectation, maximization +from lymph import models + +from lyscripts.utils import ( + assign_modalities, + create_mixture, + load_patient_data, + load_yaml_params, + to_numpy, +) + +logger = logging.getLogger(__name__) + + +def _add_parser( + subparsers: argparse._SubParsersAction, + help_formatter, +): + """Add an ``ArgumentParser`` to the subparsers action.""" + parser = subparsers.add_parser( + Path(__file__).name.replace(".py", ""), + description=__doc__, + help=__doc__, + formatter_class=help_formatter, + ) + _add_arguments(parser) + + +def _add_arguments(parser: argparse.ArgumentParser): + """Add arguments to a ``subparsers`` instance and run its main function when chosen. + + This is called by the parent module that is called via the command line. + """ + parser.add_argument("-i", "--input", type=Path, help="Path to training data files") + parser.add_argument( + "--history", + type=Path, + nargs="?", + help="Path to store history in (as CSV file).", + ) + parser.add_argument( + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file.", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + default=42, + help="Seed value to reproduce the same sampling round.", + ) + parser.add_argument( + "-m", + "--multi_fit", + type=bool, + default=False, + help="Whether to fit multiple models for later uncertainty evaluation", + ) + parser.add_argument( + "-sp", + "--starting_point", + type=Path, + default=None, + help="Starting point for optimization if we do not want to start from a random point", + ) + + parser.set_defaults(run_main=main) + + +MIXTURE = None + + +def log_prob_fn() -> float: + """Log probability function using global variables because of pickling.""" + return MIXTURE.likelihood( + use_complete=True, given_resps=MIXTURE.get_resps(norm=True) + ) + + +def check_convergence( + params_history, likelihood_history, steps_back_list, absolute_tolerance=0.01 +): + current_params = params_history[-1] + current_likelihood = likelihood_history[-1] + for steps_back in steps_back_list: + previous_params = params_history[-steps_back - 1] + if np.allclose(to_numpy(current_params), to_numpy(previous_params)): + logger.info( + f"Converged after {len(params_history)} steps. due to parameter similarity" + ) + return True # Return True if any of the steps is close + elif np.isclose( + current_likelihood, + likelihood_history[-steps_back - 1], + rtol=0, + atol=absolute_tolerance, + ): + logger.info( + f"Converged after {len(params_history)} steps. due to likelihood similarity" + ) + return True + return False + + +def run_EM(tolerance, history_dir=None): + """Run the EM algorithm to determine the optimal parameters.""" + os.makedirs(history_dir, exist_ok=True) + is_converged = False + iteration = 0 + params = MIXTURE.get_params() + params_history = [] + likelihood_history = [] + params_history.append(params.copy()) + likelihood_history.append(MIXTURE.likelihood(use_complete=False)) + # Number of steps to look back for convergence + look_back_steps = 3 + + while not is_converged: + logger.info(f"Iteration: {iteration}") + logger.info(f"Likelihood: {likelihood_history[-1]}") + latent = expectation(MIXTURE, params, log=True) + MIXTURE.set_resps(np.exp(latent)) + params = maximization(MIXTURE, latent) + + # Append current params and likelihood to history + params_history.append(params.copy()) + likelihood_history.append(MIXTURE.likelihood(use_complete=False)) + if history_dir != None: + llh_history = pd.DataFrame(likelihood_history) + llh_history.columns = ["likelihoods"] + llh_history.to_csv(history_dir + "/llh.csv", index=False) + param_history = pd.DataFrame(params_history) + param_history.to_csv(history_dir + "/params.csv", index=False) + MIXTURE.get_mixture_coefs().to_csv( + history_dir + "/mixture_coef.csv", index=False + ) + # Check if converged + if iteration >= 3: # Ensure enough history is available + is_converged = check_convergence( + params_history, + likelihood_history, + list(range(1, look_back_steps + 1)), + tolerance, + ) + iteration += 1 + df = pd.DataFrame.from_dict(MIXTURE.get_params(), orient="index", columns=["value"]) + df.to_csv(history_dir + "/optimal_params.csv") + return params_history, likelihood_history + + +def process_dataset( + dataset, folder_path, initial_params, model_build_params, index, look_back_steps=3 +): + os.makedirs(folder_path, exist_ok=True) + subpath_optimal_params = "optimal_params" + os.makedirs(os.path.join(folder_path, subpath_optimal_params), exist_ok=True) + subpath_params_history = "params_history" + os.makedirs(os.path.join(folder_path, subpath_params_history), exist_ok=True) + subpath_likelihood_history = "likelihood_history" + os.makedirs(os.path.join(folder_path, subpath_likelihood_history), exist_ok=True) + + logger.info(f"Starting dataset {index}") + mixture = create_mixture(model_build_params) + mapping = model_build_params["model"].get("mapping", None) + if isinstance(mixture.components[0], models.Unilateral): + mixture.load_patient_data( + dataset, + split_by=model_build_params["model"].get( + "split_by", ("tumor", "1", "subsite") + ), + mapping=mapping, + ) + assign_modalities( + model=mixture, config=model_build_params.get("inference_modalities", {}) + ) + else: + raise ValueError("Only Unilateral has been implemented so far") + + mixture.set_params(**initial_params) + mixture.normalize_mixture_coefs() + tolerance = model_build_params["model"].get("likelihood_tolerance", 0.01) + mixture.set_params(**initial_params) + params = initial_params.copy() + + mixture.normalize_mixture_coefs() + params_history = [params.copy()] + likelihood_history = [mixture.likelihood(use_complete=False)] + + is_converged = False + count = 0 + logger.info(f"[Dataset {index}] started") + file_prefix = f"dataset_{index}" + + while not is_converged: + latent = expectation(mixture, params, log=True) + mixture.set_resps(np.exp(latent)) + params = maximization(mixture, latent) + + params_history.append(params.copy()) + likelihood_history.append(mixture.likelihood(use_complete=False)) + + llh_history = pd.DataFrame(likelihood_history) + llh_history.columns = ["likelihoods"] + llh_history.to_csv( + os.path.join( + folder_path, + subpath_likelihood_history, + f"{file_prefix}_likelihood_history.csv", + ), + index=False, + ) + param_history = pd.DataFrame(params_history) + param_history.to_csv( + os.path.join( + folder_path, subpath_params_history, f"{file_prefix}_param_history.csv" + ), + index=False, + ) + if count >= look_back_steps: + is_converged = check_convergence( + params_history, + likelihood_history, + list(range(1, look_back_steps + 1)), + tolerance, + ) + + count += 1 + + logger.info(f"[Dataset {index}] Converged after {count+1} steps") + + df = pd.DataFrame.from_dict(mixture.get_params(), orient="index", columns=["value"]) + df.to_csv( + os.path.join( + folder_path, subpath_optimal_params, f"{file_prefix}_optimal_params.csv" + ) + ) + + +def main(args: argparse.Namespace) -> None: + """Main function to run the EM algorithm for a mixture model""" + params = load_yaml_params(args.params) + global MIXTURE + MIXTURE = create_mixture(params) + original_data = load_patient_data(params["general"]["data"]) + + mapping = params["model"].get("mapping", None) + if isinstance(MIXTURE.components[0], models.Unilateral): + MIXTURE.load_patient_data( + original_data, + split_by=params["model"].get("split_by", ("tumor", "1", "subsite")), + mapping=mapping, + ) + assign_modalities(model=MIXTURE, config=params.get("inference_modalities", {})) + + else: + raise ValueError("Only Unilateral has been implemented so far") + + if args.starting_point is None: + rng = np.random.default_rng(params["em"].get("seed", 42)) + starting_values = {k: rng.uniform() for k in MIXTURE.get_params()} + else: + logger.info(f"Using starting point from {args.starting_point}") + starting_df = pd.read_csv( + args.starting_point, index_col=0 + ) # Use first column as index + starting_values = starting_df["value"].to_dict() + + if args.multi_fit: + datasets = [] + history_dir = params["sampling"]["output_path"] + os.makedirs(history_dir, exist_ok=True) + for i in range(params["sampling"]["n_bootstraps"]): + file_path = os.path.join(args.input, f"dataset_resample_{i}.csv") + if os.path.exists(file_path): + loaded_dataset = pd.read_csv(file_path, header=[0, 1, 2]) + datasets.append(loaded_dataset) + with ProcessPoolExecutor(max(1, os.cpu_count() - 2)) as executor: + futures = [ + executor.submit( + process_dataset, dataset, history_dir, starting_values, params, i + ) + for i, dataset in enumerate(datasets) + ] + else: + MIXTURE.set_params(**starting_values) + MIXTURE.normalize_mixture_coefs() + tolerance = params["model"].get("likelihood_tolerance", 0.01) + history_dir = params["fitting"]["folder_path"] + logger.info(f"Saving history to {history_dir}.") + params_history, likelihood_history = run_EM( + tolerance=tolerance, history_dir=history_dir + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + _add_arguments(parser) + + args = parser.parse_args() + args.run_main(args) diff --git a/lyscripts/mixture_sample_old.py b/lyscripts/mixture_sample_old.py new file mode 100644 index 0000000..afb00fb --- /dev/null +++ b/lyscripts/mixture_sample_old.py @@ -0,0 +1,142 @@ +"""Learn the spread probabilities of the HMM for lymphatic tumor progression using +the preprocessed data as input and the mixture model. +""" + +# pylint: disable=logging-fstring-interpolation +import argparse +import logging + +try: + from multiprocess import Pool +except ModuleNotFoundError: + pass + +from pathlib import Path + +import pandas as pd +from lymixture.em import expectation, sample_fixed_mixture, sample_model_params +from lymph import models + +from lyscripts.utils import ( + assign_modalities, + create_mixture, + load_patient_data, + load_yaml_params, +) + +logger = logging.getLogger(__name__) + + +def _add_parser( + subparsers: argparse._SubParsersAction, + help_formatter, +): + """Add an ``ArgumentParser`` to the subparsers action.""" + parser = subparsers.add_parser( + Path(__file__).name.replace(".py", ""), + description=__doc__, + help=__doc__, + formatter_class=help_formatter, + ) + _add_arguments(parser) + + +def _add_arguments(parser: argparse.ArgumentParser): + """Add arguments to a ``subparsers`` instance and run its main function when chosen. + + This is called by the parent module that is called via the command line. + """ + parser.add_argument( + "-m", + "--mixture_coefs", + type=Path, + required=True, + help="File path of mixture coefficients", + ) + parser.add_argument( + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file.", + ) + parser.add_argument( + "-mp", + "--model_params", + type=Path, + required=True, + help="File path of mixture coefficients", + ) + parser.add_argument( + "--mode", + type=str, + default="fixed_mixture", + help="Mode of sampling. Use either 'fixed_mixture' or 'fixed_latent'", + ) + parser.add_argument("-o", "--output", type=Path, help="Output path for samples") + parser.add_argument( + "-d", "--data", type=Path, required=True, help="Path to the data file." + ) + parser.add_argument( + "-c", + "--continue_sampling", + type=bool, + default=False, + help="Continue sampling from previous run stored in the output backend", + ) + + parser.set_defaults(run_main=main) + + +MIXTURE = None + + +def main(args: argparse.Namespace) -> None: + """Main function to sample parameters for a mixture model""" + params = load_yaml_params(args.params) + model_params = pd.read_csv(args.model_params, header=[0]) + inference_data = load_patient_data(args.data) + param_dict = dict(model_params.iloc[-1]) + # ugly, but necessary for pickling + global MIXTURE + MIXTURE = create_mixture(params) + + mapping = params["model"].get("mapping", None) + if isinstance(MIXTURE.components[0], models.Unilateral): + side = params["model"].get("side", "ipsi") + MIXTURE.load_patient_data( + inference_data, + split_by=params["model"].get("split_by", ("tumor", "1", "subsite")), + mapping=mapping, + ) + assign_modalities(model=MIXTURE, config=params.get("inference_modalities", {})) + + else: + raise "Only Unilateral has been implemented so far" + + MIXTURE.set_params(**param_dict) + MIXTURE.set_resps(expectation(MIXTURE, param_dict)) + if args.mode == "fixed_mixture": + backend, samples = sample_fixed_mixture( + MIXTURE, + steps=params["sampling"].get( + "steps", + ), + filename=str(args.output) + "/fixed_mixture.hdf5", + continue_sampling=args.continue_sampling, + ) + elif args.mode == "fixed_latent": + backend, samples = sample_model_params( + MIXTURE, + steps=params["sampling"].get("steps"), + filename=str(args.output) + "/fixed_latent.hdf5", + continue_sampling=args.continue_sampling, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + _add_arguments(parser) + + args = parser.parse_args() + args.run_main(args) diff --git a/lyscripts/plot/__init__.py b/lyscripts/plot/__init__.py index 6bc0748..22343d1 100644 --- a/lyscripts/plot/__init__.py +++ b/lyscripts/plot/__init__.py @@ -1,12 +1,20 @@ -""" -Provide various plotting utilities for displaying results of e.g. the inference +"""Provide various plotting utilities for displaying results of e.g. the inference or prediction process. At the moment, three subcommands are grouped under :py:mod:`.plot`. """ + import argparse from pathlib import Path -from lyscripts.plot import corner, histograms, thermo_int +from lyscripts.plot import ( + corner, + histograms, + mixture_comp_uncertainty, + mixture_plot, + mixture_sampling_plotter, + simplex_plot, + thermo_int, +) def _add_parser( @@ -24,3 +32,11 @@ def _add_parser( corner._add_parser(subparsers, help_formatter=parser.formatter_class) histograms._add_parser(subparsers, help_formatter=parser.formatter_class) thermo_int._add_parser(subparsers, help_formatter=parser.formatter_class) + mixture_plot._add_parser(subparsers, help_formatter=parser.formatter_class) + simplex_plot._add_parser(subparsers, help_formatter=parser.formatter_class) + mixture_sampling_plotter._add_parser( + subparsers, help_formatter=parser.formatter_class + ) + mixture_comp_uncertainty._add_parser( + subparsers, help_formatter=parser.formatter_class + ) diff --git a/lyscripts/plot/__main__.py b/lyscripts/plot/__main__.py index dc34a9f..9d89828 100644 --- a/lyscripts/plot/__main__.py +++ b/lyscripts/plot/__main__.py @@ -1,7 +1,15 @@ import argparse from lyscripts import RichDefaultHelpFormatter, exit_cli -from lyscripts.plot import corner, histograms, thermo_int +from lyscripts.plot import ( + corner, + histograms, + mixture_comp_uncertainty, + mixture_plot, + mixture_sampling_plotter, + simplex_plot, + thermo_int, +) # I need another __main__ guard here, because otherwise pdoc tries to run this if __name__ == "__main__": @@ -18,6 +26,14 @@ corner._add_parser(subparsers, help_formatter=parser.formatter_class) histograms._add_parser(subparsers, help_formatter=parser.formatter_class) thermo_int._add_parser(subparsers, help_formatter=parser.formatter_class) + mixture_plot._add_parser(subparsers, help_formatter=parser.formatter_class) + simplex_plot._add_parser(subparsers, help_formatter=parser.formatter_class) + mixture_sampling_plotter._add_parser( + subparsers, help_formatter=parser.formatter_class + ) + mixture_comp_uncertainty._add_parser( + subparsers, help_formatter=parser.formatter_class + ) args = parser.parse_args() args.run_main(args) diff --git a/lyscripts/plot/corner.py b/lyscripts/plot/corner.py index cffd5d4..1134edf 100644 --- a/lyscripts/plot/corner.py +++ b/lyscripts/plot/corner.py @@ -1,5 +1,4 @@ -""" -Generate a corner plot of the drawn samples. +"""Generate a corner plot of the drawn samples. A corner plot is a combination of 1D and 2D marginals of probability distributions. The library I use for this is built on `matplotlib` and is called @@ -7,6 +6,7 @@ .. _corner: https://github.com/dfm/corner.py """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -37,17 +37,14 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" + parser.add_argument("model", type=Path, help="Path to model output files (HDF5).") + parser.add_argument("output", type=Path, help="Path to output corner plot (SVG).") parser.add_argument( - "model", type=Path, - help="Path to model output files (HDF5)." - ) - parser.add_argument( - "output", type=Path, - help="Path to output corner plot (SVG)." - ) - parser.add_argument( - "-p", "--params", default="./params.yaml", type=Path, - help="Path to parameter file" + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file", ) parser.set_defaults(run_main=main) diff --git a/lyscripts/plot/histograms.py b/lyscripts/plot/histograms.py index 5b202fc..f090221 100644 --- a/lyscripts/plot/histograms.py +++ b/lyscripts/plot/histograms.py @@ -1,6 +1,5 @@ -""" -Plot computed risks and prevalences into a beautiful histogram. -""" +"""Plot computed risks and prevalences into a beautiful histogram.""" + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -38,29 +37,27 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "input", type=Path, - help="File path of the computed risks or prevalences (HDF5)" - ) - parser.add_argument( - "output", type=Path, - help="Output path for the plot" + "input", type=Path, help="File path of the computed risks or prevalences (HDF5)" ) + parser.add_argument("output", type=Path, help="Output path for the plot") parser.add_argument( - "--names", nargs="+", - help="List of names of computed risks/prevalences to combine into one plot" - ) - parser.add_argument( - "--title", type=str, - help="Title of the plot" + "--names", + nargs="+", + help="List of names of computed risks/prevalences to combine into one plot", ) + parser.add_argument("--title", type=str, help="Title of the plot") parser.add_argument( - "--bins", default=60, type=int, - help="Number of bins to put the computed values into" + "--bins", + default=60, + type=int, + help="Number of bins to put the computed values into", ) parser.add_argument( - "--mplstyle", default="./.mplstyle", type=Path, - help="Path to the MPL stylesheet" + "--mplstyle", + default="./.mplstyle", + type=Path, + help="Path to the MPL stylesheet", ) parser.set_defaults(run_main=main) @@ -73,18 +70,22 @@ def main(args: argparse.Namespace): contents = [] for name in args.names: color = next(COLOR_CYCLE) - contents.append(Histogram.from_hdf5( - filename=args.input, - dataname=name, - color=color, - )) - logger.info(f"Added histogram {name} to figure") - try: - contents.append(BetaPosterior.from_hdf5( + contents.append( + Histogram.from_hdf5( filename=args.input, dataname=name, color=color, - )) + ) + ) + logger.info(f"Added histogram {name} to figure") + try: + contents.append( + BetaPosterior.from_hdf5( + filename=args.input, + dataname=name, + color=color, + ) + ) except KeyError: logger.warning(f"No observation data available for dataset {name}") else: @@ -95,7 +96,7 @@ def main(args: argparse.Namespace): axes=ax, contents=contents, hist_kwargs={"nbins": args.bins}, - percent_lims=(5., 5.) + percent_lims=(5.0, 5.0), ) ax.legend() logger.info("Drawn figure") diff --git a/lyscripts/plot/mixture_comp_uncertainty.py b/lyscripts/plot/mixture_comp_uncertainty.py new file mode 100644 index 0000000..e027e73 --- /dev/null +++ b/lyscripts/plot/mixture_comp_uncertainty.py @@ -0,0 +1,177 @@ +import argparse +import logging +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from lymph import models + +from lyscripts.plot.utils import COLORS, save_figure +from lyscripts.utils import create_mixture, load_patient_data, load_yaml_params + +logger = logging.getLogger(__name__) + + +def _add_parser( + subparsers: argparse._SubParsersAction, + help_formatter, +): + """Add an ``ArgumentParser`` to the subparsers action.""" + parser = subparsers.add_parser( + Path(__file__).name.replace(".py", ""), + description=__doc__, + help=__doc__, + formatter_class=help_formatter, + ) + _add_arguments(parser) + + +def _add_arguments(parser: argparse.ArgumentParser): + """Add arguments to the parser.""" + parser.add_argument( + "--input", type=Path, help="File path of resampled optimal parameters" + ) + parser.add_argument( + "--optimal_params", + type=Path, + help="File path of the initial optimal parameters", + ) + parser.add_argument("--output", type=Path, help="Output path for the plot") + parser.add_argument( + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file.", + ) + + parser.set_defaults(run_main=main) + + +def get_stats(keys, best_params_list, optima=None, percentile_range=0.68): + means = [] + lower_errors = [] + upper_errors = [] + for index, key in enumerate(keys): + values = np.array([bp[key] for bp in best_params_list]) + values = sorted(values) + n_values = len(values) + n_within_range = int(np.ceil(percentile_range * n_values)) + if optima is not None: + mean = optima[index] + else: + mean = np.mean(values) + distances = [abs(v - mean) for v in values] + sorted_indices = np.argsort(distances) + closest_indices = sorted_indices[:n_within_range] + # Extract the 68% range + percentile_values = [values[i] for i in closest_indices] + percentile_values.sort() + low_err = max(mean - min(percentile_values), 0) + high_err = max(max(percentile_values) - mean, 0) + means.append(mean) + lower_errors.append(low_err) + upper_errors.append(high_err) + return means, [lower_errors, upper_errors] + + +def plot_mixture_comp_uncertainty(initial_params_path, resampling_path): + all_keys = MIXTURE.get_params().keys() + initial_params = pd.read_csv(initial_params_path, index_col=0)["value"].to_dict() + # List all files in the optimal_params_path directory + optimal_params_files = os.listdir(resampling_path) + + # Open each file and load as a dictionary + best_params_list = [] + for fname in optimal_params_files: + fpath = os.path.join(resampling_path, fname) + params = pd.read_csv(fpath, index_col=0)["value"].to_dict() + best_params_list.append(params) + + num_components = len(MIXTURE.components) + mixture_keys = [key for key in all_keys if "coef" in key] + # Dynamically generate keys for each component based on num_components + keys_per_component = [ + [key for key in mixture_keys if f"{i}_C" in key] for i in range(num_components) + ] + + # Get the optimal parameters as means + MIXTURE.set_params(**initial_params) + optima_list = [] + mixture_coefs = MIXTURE.get_mixture_coefs() + for i in range(num_components): + optima_list.append(mixture_coefs.loc[i].to_list()) + + # Compute stats for each component + means_list = [] + yerr_list = [] + for keys, optima in zip(keys_per_component, optima_list, strict=False): + means, yerr = get_stats(keys, best_params_list, optima) + means_list.append(means) + yerr_list.append(yerr) + + color_list = [] + for lister in optima_list: + extreme_subsite = list(MIXTURE.subgroups.keys())[np.argmax(lister)] + if extreme_subsite in ["C02", "C03", "C04", "C05", "C06"]: + color_list.append(COLORS["blue"]) + elif extreme_subsite in ["C01", "C09", "C10"]: + color_list.append(COLORS["green"]) + elif extreme_subsite in ["C12", "C13"]: + color_list.append(COLORS["red"]) + elif extreme_subsite in ["C32.0", "C32.1", "C32.2"]: + color_list.append(COLORS["orange"]) + # Unpack for plotting for a dynamic number of components + plt.figure(figsize=(10, 6)) + x = np.arange(len(keys_per_component[0])) + for idx, (means, yerr, keys, color) in enumerate( + zip(means_list, yerr_list, keys_per_component, color_list, strict=False) + ): + plt.errorbar( + x, + means, + yerr=yerr, + fmt="o", + color=color, + ecolor=color, + capsize=5, + markersize=8, + label=f"component {idx}", + ) + + plt.xticks(x, MIXTURE.subgroups.keys(), rotation=0, ha="right", fontsize=10) + plt.legend(fontsize=10) + plt.ylabel("Values", fontsize=12) + plt.title("mixture Values Assignment", fontsize=14) + plt.grid(axis="y", linestyle="--", alpha=0.7) + plt.tight_layout() + return plt + + +def main(args: argparse.Namespace): + params = load_yaml_params(args.params) + global MIXTURE + MIXTURE = create_mixture(params) + inference_data = load_patient_data(params["general"]["data"]) + + mapping = params["model"].get("mapping", None) + if isinstance(MIXTURE.components[0], models.Unilateral): + MIXTURE.load_patient_data( + inference_data, + split_by=params["model"].get("split_by", ("tumor", "1", "subsite")), + mapping=mapping, + ) + + plot = plot_mixture_comp_uncertainty(args.optimal_params, args.input) + save_figure(args.output, plot, formats=["png", "svg"]) + logger.info(f"Mixture component uncertainty plot saved to {args.output}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + _add_arguments(parser) + + args = parser.parse_args() + args.run_main(args) diff --git a/lyscripts/plot/mixture_plot.py b/lyscripts/plot/mixture_plot.py new file mode 100644 index 0000000..c574eff --- /dev/null +++ b/lyscripts/plot/mixture_plot.py @@ -0,0 +1,82 @@ +import argparse +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd +from matplotlib.colors import LinearSegmentedColormap + +from lyscripts.plot.utils import COLORS, SUBSITE_COLORS, save_figure + +logger = logging.getLogger(__name__) + + +def _add_parser( + subparsers: argparse._SubParsersAction, + help_formatter, +): + """Add an ``ArgumentParser`` to the subparsers action.""" + parser = subparsers.add_parser( + Path(__file__).name.replace(".py", ""), + description=__doc__, + help=__doc__, + formatter_class=help_formatter, + ) + _add_arguments(parser) + + +def _add_arguments(parser: argparse.ArgumentParser): + """Add arguments to the parser.""" + parser.add_argument("--input", type=Path, help="File path of mixture coefficients") + parser.add_argument("--output", type=Path, help="Output path for the plot") + + parser.set_defaults(run_main=main) + + +def main(args: argparse.Namespace): + tmp = LinearSegmentedColormap.from_list( + "tmp", [COLORS["green"], COLORS["red"]], N=128 + ) + mixture_df = pd.read_csv(args.input) + filtered_keys = [key for key in SUBSITE_COLORS if key in mixture_df.columns] + mixture_df = mixture_df[filtered_keys] + + # Transpose the matrix to rotate by 90° + matrix_rotated = mixture_df.T + + # Create the figure and axis + fig, ax = plt.subplots(figsize=(8, 10)) + + # Display the rotated matrix using imshow + cax = ax.imshow(matrix_rotated.values, cmap=tmp, origin="upper") + + # Loop over the data and create text annotations + for i in range(matrix_rotated.shape[0]): # Rows (previously columns) + for j in range(matrix_rotated.shape[1]): # Columns (previously rows) + value = matrix_rotated.iloc[i, j] + ax.text( + j, + i, + f"{value:.2f}", + ha="center", + va="center", + color="white", + fontsize=12, + ) + + # Optional: Set axis labels and title + ax.set_xticks(range(matrix_rotated.shape[1])) + ax.set_xticklabels(mixture_df.index, fontsize=12) # Original row labels + ax.set_yticks(range(matrix_rotated.shape[0])) + ax.set_yticklabels(mixture_df.columns, fontsize=12) # Original column labels + ax.set_title("Mixture Coefficients per subsite", fontsize=16) + save_figure(args.output, fig, formats=["png", "svg"]) + logger.info("Mixture parameter matrix saved") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + _add_arguments(parser) + + args = parser.parse_args() + args.run_main(args) diff --git a/lyscripts/plot/mixture_sampling_plotter.py b/lyscripts/plot/mixture_sampling_plotter.py new file mode 100644 index 0000000..1e7d232 --- /dev/null +++ b/lyscripts/plot/mixture_sampling_plotter.py @@ -0,0 +1,318 @@ +import argparse +import json +import logging +from pathlib import Path + +import emcee +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import scipy as sp +from cycler import cycler +from lymixture.em import _set_params, expectation +from lymph import models + +from lyscripts.plot.utils import COLORS, save_figure +from lyscripts.utils import ( + assign_modalities, + create_mixture, + load_patient_data, + load_yaml_params, +) + +logger = logging.getLogger(__name__) + + +def _add_parser( + subparsers: argparse._SubParsersAction, + help_formatter, +): + """Add an ``ArgumentParser`` to the subparsers action.""" + parser = subparsers.add_parser( + Path(__file__).name.replace(".py", ""), + description=__doc__, + help=__doc__, + formatter_class=help_formatter, + ) + _add_arguments(parser) + + +def _add_arguments(parser: argparse.ArgumentParser): + """Add arguments to the parser.""" + parser.add_argument( + "--input", type=Path, help="File path with emcee backend of samples" + ) + parser.add_argument("--output", type=Path, help="Output path for the plot") + parser.add_argument( + "-m", + "--mode", + type=str, + default="fixed_mixture", + help="Mode of sampling. Use either 'fixed_mixture' or 'fixed_latent'", + ) + parser.add_argument( + "-s", + "--size", + type=int, + default=200, + help="Number of samples to be used for plotting", + ) + parser.add_argument( + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file.", + ) + parser.add_argument( + "-d", "--data", type=Path, required=True, help="Path to patient data." + ) + parser.add_argument( + "-mp", + "--model_params", + type=Path, + required=True, + help="File path of mixture coefficients", + ) + + parser.set_defaults(run_main=main) + + +def multiple_plotter(dataset, risk_dictionary_extended, subsite, stage=""): + hist_cycl = cycler(histtype=["stepfilled", "step"]) * cycler( + color=list(COLORS.values()) + ) + line_cycl = cycler(linestyle=["-", "--"]) * cycler(color=list(COLORS.values())) + dataset_staging = dataset.copy() + dataset_staging["tumor", "1", "t_stage"] = dataset_staging[ + "tumor", "1", "t_stage" + ].replace([0, 1, 2], "early") + dataset_staging["tumor", "1", "t_stage"] = dataset_staging[ + "tumor", "1", "t_stage" + ].replace([3, 4], "late") + if stage == "early" or stage == "late": + data_selected = dataset_staging.loc[ + (dataset_staging["tumor"]["1"]["subsite"] == subsite) + & (dataset_staging["tumor"]["1"]["t_stage"] == stage) + ] + else: + data_selected = dataset_staging.loc[ + dataset_staging["tumor"]["1"]["subsite"] == subsite + ] + min_value = 0 + max_value = 100 + prevalence = {} + number_of_patients = {} + risks = {} + risk_dictionary = risk_dictionary_extended[subsite] + for key in risk_dictionary.keys(): + prevalence[key] = (data_selected["max_llh"]["ipsi"][key] == True).sum() + number_of_patients[key] = data_selected["max_llh"]["ipsi"][key].notna().sum() + risks[key] = np.array(risk_dictionary[key]) * 100 + + num_matches = [prevalence[key] for key in risk_dictionary.keys()] + num_totals = [number_of_patients[key] for key in risk_dictionary.keys()] + values = [risks[key] for key in risk_dictionary.keys()] + hist_kwargs = { + "bins": np.linspace(min_value, max_value, 80), + "density": True, + "alpha": 0.6, + "linewidth": 2.0, + } + fig, ax = plt.subplots(figsize=(12, 4)) + + x = np.linspace(min_value, max_value, 200) + zipper = zip( + values, + risk_dictionary.keys(), + num_matches, + num_totals, + hist_cycl, + line_cycl, + strict=False, + ) + for vals, label, a, n, hstyle, lstyle in zipper: + ax.hist(vals, label=label, **hist_kwargs, **hstyle) + if not np.isnan(a): + post = sp.stats.beta.pdf(x / 100.0, a + 1, n - a + 1) / 100.0 + ax.plot(x, post, label=f"{int(a)}/{int(n)}", **lstyle) + ax.legend() + ax.set_xlabel("probability [%]") + fig.suptitle(f"Risk distributions {stage} for subsite {subsite}", fontsize=16) + return fig + + +def multiple_plotter_component(risk_dictionary_extended, component, stage=""): + hist_cycl = cycler(histtype=["stepfilled", "step"]) * cycler( + color=list(COLORS.values()) + ) + line_cycl = cycler(linestyle=["-", "--"]) * cycler(color=list(COLORS.values())) + min_value = 0 + max_value = 100 + prevalence = {} + number_of_patients = {} + risks = {} + risk_dictionary = risk_dictionary_extended[component] + for key in risk_dictionary.keys(): + risks[key] = np.array(risk_dictionary[key]) * 100 + + values = [risks[key] for key in risk_dictionary.keys()] + hist_kwargs = { + "bins": np.linspace(min_value, max_value, 80), + "density": True, + "alpha": 0.6, + "linewidth": 2.0, + } + fig, ax = plt.subplots(figsize=(12, 4)) + + x = np.linspace(min_value, max_value, 200) + zipper = zip(values, risk_dictionary.keys(), hist_cycl, strict=False) + for vals, label, hstyle in zipper: + ax.hist(vals, label=label, **hist_kwargs, **hstyle) + ax.legend() + ax.set_xlabel("probability [%]") + fig.suptitle(f"Risk distributions {stage} for component {component}", fontsize=16) + return fig + + +def main(args: argparse.Namespace): + params = load_yaml_params(args.params) + inference_data = load_patient_data(args.data) + backend = emcee.backends.HDFBackend(args.input) + samples = backend.get_chain(flat=True) + model_params = pd.read_csv(args.model_params, header=[0]) + + # ugly, but necessary for pickling + global MIXTURE + MIXTURE = create_mixture(params) + + mapping = params["model"].get("mapping", None) + if isinstance(MIXTURE.components[0], models.Unilateral): + side = params["model"].get("side", "ipsi") + MIXTURE.load_patient_data( + inference_data, + split_by=params["model"].get("split_by", ("tumor", "1", "subsite")), + mapping=mapping, + ) + assign_modalities(model=MIXTURE, config=params.get("inference_modalities", {})) + + else: + raise "Only Unilateral has been implemented so far" + + param_dict = dict(model_params.iloc[-1]) + MIXTURE.set_params(**param_dict) + MIXTURE.set_resps(expectation(MIXTURE, param_dict)) + + print(MIXTURE.get_params()) + lnls = list(MIXTURE.components[0].graph.lnls.keys()) + component_list = list(range(len(MIXTURE.components))) + component_dictionary_early_full_sampling = {} + component_dictionary_late_full_sampling = {} + + for component in component_list: + component_dictionary_early_full_sampling[str(component)] = { + lnl: [] for lnl in lnls + } + component_dictionary_late_full_sampling[str(component)] = { + lnl: [] for lnl in lnls + } + + subsite_dictionary_early_full_sampling = {} + subsite_dictionary_late_full_sampling = {} + subsite_list = list(MIXTURE.subgroups.keys()) + + for subsite in subsite_list: + subsite_dictionary_early_full_sampling[subsite] = { + lnl: [] + for lnl in lnls # Dynamically generate keys from the lnls list + } + subsite_dictionary_late_full_sampling[subsite] = {lnl: [] for lnl in lnls} + + involvement_dict = {lnl: {lnl: True} for lnl in lnls} + samples_thinned = samples[:: int(np.round(len(samples) / args.size, 0))] + for round, sample in enumerate(samples_thinned): + if args.mode == "fixed_latent": + MIXTURE.set_params(*sample) + elif args.mode == "fixed_mixture": + _set_params(MIXTURE, sample) + else: + raise ValueError("Invalid mode") + component_dictionary_early_full_sampling["0"]["II"].append( + MIXTURE.components[0].risk( + involvement=involvement_dict["II"], t_stage="early" + ) + ) + for component in component_list: + for lnl in lnls: + component_dictionary_early_full_sampling[str(component)][lnl].append( + MIXTURE.components[component].risk( + involvement=involvement_dict[lnl], t_stage="early" + ) + ) + component_dictionary_late_full_sampling[str(component)][lnl].append( + MIXTURE.components[component].risk( + involvement=involvement_dict[lnl], t_stage="late" + ) + ) + + for subsite in subsite_list: + for lnl in lnls: + subsite_dictionary_early_full_sampling[subsite][lnl].append( + MIXTURE.risk( + subgroup=subsite, + involvement=involvement_dict[lnl], + t_stage="early", + ) + ) + subsite_dictionary_late_full_sampling[subsite][lnl].append( + MIXTURE.risk( + subgroup=subsite, + involvement=involvement_dict[lnl], + t_stage="late", + ) + ) + print(round, " done") + + for component in component_list: + fig = multiple_plotter_component( + component_dictionary_early_full_sampling, str(component), stage="early" + ) + save_figure( + args.output / f"component_{component}_early", fig, formats=["png", "svg"] + ) + fig = multiple_plotter_component( + component_dictionary_late_full_sampling, str(component), stage="late" + ) + save_figure( + args.output / f"component_{component}_late", fig, formats=["png", "svg"] + ) + for subsite in subsite_list: + fig = multiple_plotter( + inference_data, + subsite_dictionary_early_full_sampling, + subsite, + stage="early", + ) + save_figure(args.output / f"{subsite}_early", fig, formats=["png", "svg"]) + fig = multiple_plotter( + inference_data, subsite_dictionary_late_full_sampling, subsite, stage="late" + ) + save_figure(args.output / f"{subsite}_late", fig, formats=["png", "svg"]) + + # Save dictionary to a JSON file + with open("subsite_early.json", "w") as file: + json.dump(subsite_dictionary_early_full_sampling, file) + with open("subsite_late.json", "w") as file: + json.dump(subsite_dictionary_late_full_sampling, file) + with open("component_early.json", "w") as file: + json.dump(component_dictionary_early_full_sampling, file) + with open("component_late.json", "w") as file: + json.dump(component_dictionary_late_full_sampling, file) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + _add_arguments(parser) + + args = parser.parse_args() + args.run_main(args) diff --git a/lyscripts/plot/simplex_plot.py b/lyscripts/plot/simplex_plot.py new file mode 100644 index 0000000..a67b5f2 --- /dev/null +++ b/lyscripts/plot/simplex_plot.py @@ -0,0 +1,323 @@ +"""Visualize the component assignments of the trained mixture model.""" + +import argparse +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib.lines import Line2D +from matplotlib.ticker import StrMethodFormatter + +from lyscripts.plot.utils import ( + COLORS, + SUBSITE_COLORS, + add_perpendicular_ticks, + save_figure, +) +from lyscripts.utils import load_yaml_params + +logger = logging.getLogger(__name__) + + +def _add_parser( + subparsers: argparse._SubParsersAction, + help_formatter, +): + """Add an ``ArgumentParser`` to the subparsers action.""" + parser = subparsers.add_parser( + Path(__file__).name.replace(".py", ""), + description=__doc__, + help=__doc__, + formatter_class=help_formatter, + ) + _add_arguments(parser) + + +def _add_arguments(parser: argparse.ArgumentParser): + """Add arguments to the parser.""" + parser.add_argument( + "--input", type=Path, help="File path of the mixture coefficients" + ) + parser.add_argument("--output", type=Path, help="Output path for the plot") + parser.add_argument( + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file.", + ) + parser.set_defaults(run_main=main) + + +def plot_2d_simplex(mixture_df, data, output, component_names=False): + data = pd.read_csv(data, header=[0, 1, 2]) + subsites = list(mixture_df.columns) + colors_ordered = [SUBSITE_COLORS[subsite] for subsite in subsites] + num_patients = dict(data["tumor", "1", "subsite"].value_counts()) + annotations = [f"{key}\n({num_patients[key]})" for key in mixture_df.columns] + sizes = [num_patients[key] for key in mixture_df.columns] + + fig, bottom_ax = plt.subplots(figsize=(8, 2.5)) + bottom_ax.scatter( + mixture_df.loc[0], + np.zeros(len(mixture_df.loc[0])), + s=sizes, + c=colors_ordered, + alpha=0.7, + linewidths=0.0, + zorder=10, + ) + + # Sort annotations by x-position (left to right) + x_positions = mixture_df.loc[0].values + sorted_indices = np.argsort(x_positions) + + for i, idx in enumerate(sorted_indices): + x = x_positions[idx] + num = list(num_patients.values())[idx] + annotation = annotations[idx] + + offset = 0.025 * (-1) ** i + bottom_ax.annotate( + annotation, + xy=(x, np.sqrt(0.0000003 * num) * (-1) ** i), + xytext=(x, offset), + ha="center", + va="bottom" if i % 2 == 0 else "top", + fontsize="small", + arrowprops={ + "arrowstyle": "-", + "color": COLORS["gray"], + "linewidth": 1.0, + }, + ) + + # X-axis formatting + bottom_ax.set_xlabel("assignment to component 1") + bottom_ax.xaxis.set_major_formatter(StrMethodFormatter("{x:.0%}")) + + # Add secondary top axis + top_ax = bottom_ax.secondary_xaxis( + "top", functions=(lambda x: 1.0 - x, lambda x: 1.0 - x) + ) + top_ax.set_xlabel("assignment to component 2") + top_ax.xaxis.set_major_formatter(StrMethodFormatter("{x:.0%}")) + + # Clean up + bottom_ax.set_yticks([]) + bottom_ax.grid(axis="x", alpha=0.5, color=COLORS["gray"], linestyle=":") + + # Adjust layout to prevent clipping of labels + plt.tight_layout() + + save_figure(output, fig, formats=["png", "svg"]) + logger.info("Simplex plot saved") + + +def plot_3d_simplex(mixture_df, data, output, component_names=False): + data = pd.read_csv(data, header=[0, 1, 2]) + subsites = list(mixture_df.columns) + colors_ordered = [SUBSITE_COLORS[subsite] for subsite in subsites] + + # Define the plane's normal vector + normal_vector = np.array([1, 1, 1]) / np.sqrt(3) + + v1 = np.array([1, -1, 0]) / np.sqrt(2) + + # Calculate the second orthogonal vector using the cross product + v2 = np.cross(normal_vector, v1) * -1 + + # Project the point onto the new coordinate system + origin = np.array([0, 1, 0]) + x_origin = origin @ v1 + y_origin = origin @ v2 + + x_vals = mixture_df.T @ v1 - x_origin + y_vals = mixture_df.T @ v2 - y_origin + + extremes = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + extremes_x = extremes @ v1 - x_origin + extremes_y = extremes @ v2 - y_origin + + # Plot the point in 2D + import matplotlib.pyplot as plt + + odered_value_counts = data["tumor"]["1"]["subsite"].value_counts()[subsites] + sizes = odered_value_counts * 3 + + fig, ax = plt.subplots(figsize=(8, 6.8)) + + # Plot the points with varying sizes and colors + for i in range(len(x_vals)): + ax.scatter( + x_vals[i], y_vals[i], s=sizes[i], color=colors_ordered[i], label=subsites[i] + ) + ax.text(x_vals[i], y_vals[i], subsites[i], fontsize=9, ha="center", va="center") + + legend_text = [] + for index in range(len(subsites)): + legend_text.append( + subsites[index] + ", " + str(odered_value_counts[index]) + " patients" + ) + + legend_elements = [ + Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor=color, + markersize=10, + label=subsite, + ) + for color, subsite in zip(colors_ordered, legend_text, strict=False) + ] + + # Add a legend with fixed dot sizes + ax.legend( + handles=legend_elements, loc="upper right", title="Subsites", fontsize="small" + ) + + # Connect the points + ax.plot(extremes_x, extremes_y, color="black", alpha=0.5) + + # Close the triangle by connecting the last point to the first + ax.plot( + [extremes_x[-1], extremes_x[0]], + [extremes_y[-1], extremes_y[0]], + color="black", + alpha=0.5, + ) + + # Calculate midpoints of each side of the triangle + midpoints_x = ( + (extremes_x[0] + extremes_x[1]) / 2, + (extremes_x[1] + extremes_x[2]) / 2, + (extremes_x[2] + extremes_x[0]) / 2, + ) + midpoints_y = ( + (extremes_y[0] + extremes_y[1]) / 2, + (extremes_y[1] + extremes_y[2]) / 2, + (extremes_y[2] + extremes_y[0]) / 2, + ) + + # Draw lines from each vertex to the midpoint of the opposite side + ax.plot( + [extremes_x[0], midpoints_x[1]], + [extremes_y[0], midpoints_y[1]], + color="gray", + linestyle="--", + linewidth=1, + ) + ax.plot( + [extremes_x[1], midpoints_x[2]], + [extremes_y[1], midpoints_y[2]], + color="gray", + linestyle="--", + linewidth=1, + ) + ax.plot( + [extremes_x[2], midpoints_x[0]], + [extremes_y[2], midpoints_y[0]], + color="gray", + linestyle="--", + linewidth=1, + ) + + # Add perpendicular ticks to each line with adjusted length + add_perpendicular_ticks( + extremes_x[0], extremes_y[0], midpoints_x[1], midpoints_y[1] + ) + add_perpendicular_ticks( + extremes_x[1], extremes_y[1], midpoints_x[2], midpoints_y[2] + ) + add_perpendicular_ticks( + extremes_x[2], extremes_y[2], midpoints_x[0], midpoints_y[0] + ) + + # Scaling factor to move the text farther from the vertices + scaling_factor = 1.1 + + # Calculate the centroid of the triangle + centroid_x = np.mean(extremes_x) + centroid_y = np.mean(extremes_y) + + # Scale the label positions away from the centroid + scaled_extremes_x = centroid_x + scaling_factor * (extremes_x - centroid_x) + scaled_extremes_y = centroid_y + scaling_factor * (extremes_y - centroid_y) + + if component_names: + # Plot the text labels farther away from the triangle + if "C32.0" in subsites: + component_larynx = mixture_df["C32.0"].argmax() + ax.text( + scaled_extremes_x[component_larynx], + scaled_extremes_y[component_larynx], + "Larynx like", + fontsize=10, + ha="right", + va="top", + c=COLORS["orange"], + ) + + if "C01" in subsites: + component_oropharynx = mixture_df["C01"].argmax() + ax.text( + scaled_extremes_x[component_oropharynx], + scaled_extremes_y[component_oropharynx], + "Oropharynx like", + fontsize=10, + ha="left", + va="top", + c=COLORS["green"], + ) + + if "C03" in subsites: + component_oral_cavity = mixture_df["C03"].argmax() + ax.text( + scaled_extremes_x[component_oral_cavity], + scaled_extremes_y[component_oral_cavity], + "Oral cavity like", + fontsize=10, + ha="left", + va="top", + c=COLORS["blue"], + ) + + if "C13" in subsites: + component_hypopharynx = mixture_df["C13"].argmax() + ax.text( + scaled_extremes_x[component_hypopharynx], + scaled_extremes_y[component_hypopharynx], + "Hypopharynx like", + fontsize=10, + ha="left", + va="top", + c=COLORS["red"], + ) + + save_figure(output, fig, formats=["png", "svg"]) + logger.info("Simplex plot saved") + + +def main(args: argparse.Namespace): + mixture_df = pd.read_csv(args.input) + nr_components = len(mixture_df) + params = load_yaml_params(args.params) + data = params["general"]["data"] + if nr_components == 2: + plot_2d_simplex(mixture_df, data, output=args.output) + elif nr_components == 3: + plot_3d_simplex(mixture_df, data, output=args.output) + else: + logger.info(f"Simplex not supported for {nr_components} components") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + _add_arguments(parser) + + args = parser.parse_args() + args.run_main(args) diff --git a/lyscripts/plot/thermo_int.py b/lyscripts/plot/thermo_int.py index e20f507..208840d 100644 --- a/lyscripts/plot/thermo_int.py +++ b/lyscripts/plot/thermo_int.py @@ -1,9 +1,9 @@ -""" -Plot how the accuracy develops over the course of a thermodynamic integration run. +"""Plot how the accuracy develops over the course of a thermodynamic integration run. This can also be used to compare how the accuracy of different models develops during thermdynamic integration. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -43,35 +43,39 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "inputs", type=Path, nargs="+", - help="Paths to the CSV files containing the stored TI runs" + "inputs", + type=Path, + nargs="+", + help="Paths to the CSV files containing the stored TI runs", ) group = parser.add_mutually_exclusive_group() group.add_argument( - "-o", "--output", type=Path, - help="Path to where the plot should be stored (PNG and SVG)" + "-o", + "--output", + type=Path, + help="Path to where the plot should be stored (PNG and SVG)", ) group.add_argument( - "--show", action="store_true", - help="Show the plot instead of saving it" + "--show", action="store_true", help="Show the plot instead of saving it" ) + parser.add_argument("--title", default=None, help="Title of the plot") parser.add_argument( - "--title", default=None, - help="Title of the plot" - ) - parser.add_argument( - "--labels", type=str, nargs="+", default=[], - help="Labels for the individual data series" + "--labels", + type=str, + nargs="+", + default=[], + help="Labels for the individual data series", ) parser.add_argument( - "--power", default=5., type=float, - help="Scale the x-axis with this power" + "--power", default=5.0, type=float, help="Scale the x-axis with this power" ) parser.add_argument( - "--mplstyle", default="./.mplstyle", type=Path, - help="Path to the MPL stylesheet" + "--mplstyle", + default="./.mplstyle", + type=Path, + help="Path to the MPL stylesheet", ) parser.set_defaults(run_main=main) @@ -92,27 +96,25 @@ def main(args: argparse.Namespace): logger.info(f"+ read in {input}") logger.info("Loaded CSV file(s)") - fig, ax = plt.subplots(figsize=get_size()) if args.title is not None: fig.suptitle(args.title) ax.set_xlabel("inverse temperature $\\beta$") - xticks = np.linspace(0., 1., 7) + xticks = np.linspace(0.0, 1.0, 7) xticklabels = [f"{x**args.power:.2g}" for x in xticks] ax.set_xticks(ticks=xticks, labels=xticklabels) - ax.set_xlim(left=0., right=1.) + ax.set_xlim(left=0.0, right=1.0) ax.set_ylabel("accuracy $\\mathcal{A}(\\beta)$") ax.set_yscale("symlog") ax.get_yaxis().set_major_locator(matplotlib.ticker.MultipleLocator(800)) ax.get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter()) - ax.ticklabel_format(axis="y", style="sci", scilimits=(2,2)) + ax.ticklabel_format(axis="y", style="sci", scilimits=(2, 2)) logger.info("Prepared figure") - for i, series in enumerate(accuracy_series): - last_acc = series['accuracy'].values[-1] + last_acc = series["accuracy"].values[-1] try: label = args.labels[i] + " $\\mathcal{A}(1)$ = " + f"{last_acc:g}" except IndexError: @@ -120,14 +122,14 @@ def main(args: argparse.Namespace): if "stddev" in series: ax.errorbar( - series["β"]**(1./args.power), + series["β"] ** (1.0 / args.power), series["accuracy"], yerr=series["stddev"], label=label, ) else: ax.plot( - series["β"]**(1./args.power), + series["β"] ** (1.0 / args.power), series["accuracy"], label=label, ) @@ -136,7 +138,6 @@ def main(args: argparse.Namespace): ax.legend() logger.info("Plotted series") - if args.show: plt.show() logger.info("Showed the plot") diff --git a/lyscripts/plot/utils.py b/lyscripts/plot/utils.py index acf78f6..f0d1da9 100644 --- a/lyscripts/plot/utils.py +++ b/lyscripts/plot/utils.py @@ -2,6 +2,7 @@ from __future__ import annotations +import math from abc import abstractmethod from collections.abc import Mapping from dataclasses import dataclass, field @@ -13,6 +14,7 @@ import matplotlib.pyplot as plt import numpy as np import scipy as sp +from matplotlib.colors import LinearSegmentedColormap from lyscripts.decorators import ( check_input_file_exists, @@ -31,9 +33,38 @@ "green": "#00afa5", "red": "#ae0060", "gray": "#c5d5db", + "light_blue": "#00A8D8", + "dark_grey_experimental": "#404756", } COLOR_CYCLE = cycle(COLORS.values()) CM_PER_INCH = 2.54 +blue_to_white = LinearSegmentedColormap.from_list( + "blue to white", [COLORS["blue"], "#ffffff"], N=256 +) +green_to_white = LinearSegmentedColormap.from_list( + "green_to_white", [COLORS["green"], "#ffffff"], N=256 +) +red_to_white = LinearSegmentedColormap.from_list( + "red_to_white", [COLORS["red"], "#ffffff"], N=256 +) +orange_to_white = LinearSegmentedColormap.from_list( + "orange_to_white", [COLORS["orange"], "#ffffff"], N=256 +) +SUBSITE_COLORS = { + "C03": blue_to_white(0), + "C04": blue_to_white(0.15), + "C06": blue_to_white(0.3), + "C02": blue_to_white(0.45), + "C05": blue_to_white(0.6), + "C10": green_to_white(0), + "C09": green_to_white(0.3), + "C01": green_to_white(0.6), + "C12": red_to_white(0), + "C13": red_to_white(0.5), + "C32.1": orange_to_white(0), + "C32.2": orange_to_white(0.3), + "C32.0": orange_to_white(0.6), +} def floor_at_decimal(value: float, decimal: int) -> float: @@ -78,6 +109,7 @@ def clean_and_check(filename: str | Path) -> Path: AbstractDistributionT = TypeVar("AbstractDistributionT", bound="AbstractDistribution") + @dataclass(kw_only=True) class AbstractDistribution: """Abstract class for distributions that should be plotted.""" @@ -407,3 +439,109 @@ def save_figure( """Save a ``figure`` to ``output_path`` in every one of the provided ``formats``.""" for frmt in formats: figure.savefig(output_path.with_suffix(f".{frmt}")) + + +def p_to_xyz(p): + # Project 4D representation down to 3D representation + s3 = 1 / math.sqrt(3.0) + s6 = 1 / math.sqrt(6.0) + x = -1 * p[0] + 1 * p[1] + 0 * p[2] + 0 * p[3] + y = -s3 * p[0] - s3 * p[1] + 2 * s3 * p[2] + 0 * p[3] + z = -s3 * p[0] - s3 * p[1] - s3 * p[2] + 3 * s6 * p[3] + return x, y, z + + +def add_perpendicular_crosses_3d(ax, x1, y1, z1, x2, y2, z2, tick_length=0.03): + num_ticks = 6 # Number of ticks + for i in range(num_ticks): + t = i / (num_ticks - 1) + + # Interpolation point (x, y, z) on the line + x_tick = x1 + t * (x2 - x1) + y_tick = y1 + t * (y2 - y1) + z_tick = z1 + t * (z2 - z1) + + # Vector along the line (direction of the line) + line_vec = np.array([x2 - x1, y2 - y1, z2 - z1]) + + # First perpendicular vector (cross product with z-axis) + perp_vec1 = np.cross(line_vec, [0, 0, 1]) + length1 = np.linalg.norm(perp_vec1) + if ( + length1 == 0 + ): # Prevent division by zero (in rare cases when parallel to z-axis) + perp_vec1 = np.cross(line_vec, [1, 0, 0]) # Cross with x-axis instead + length1 = np.linalg.norm(perp_vec1) + perp_vec1 /= length1 # Normalize + + # Second perpendicular vector (cross product with line_vec and perp_vec1) + perp_vec2 = np.cross(line_vec, perp_vec1) + perp_vec2 /= np.linalg.norm(perp_vec2) # Normalize + + # Scale the perpendicular vectors by tick length + perp_vec1 *= tick_length + perp_vec2 *= tick_length + + # Draw the cross (two perpendicular lines) + ax.plot( + [x_tick - perp_vec1[0], x_tick + perp_vec1[0]], + [y_tick - perp_vec1[1], y_tick + perp_vec1[1]], + [z_tick - perp_vec1[2], z_tick + perp_vec1[2]], + color="gray", + linewidth=0.8, + ) + + ax.plot( + [x_tick - perp_vec2[0], x_tick + perp_vec2[0]], + [y_tick - perp_vec2[1], y_tick + perp_vec2[1]], + [z_tick - perp_vec2[2], z_tick + perp_vec2[2]], + color="gray", + linewidth=0.8, + ) + + ax.text( + x_tick, + y_tick, + z_tick, + f"{int(100 - t * 100)}%", + fontsize=6, + ha="right", + va="bottom", + ) + + +def add_perpendicular_ticks(x1, y1, x2, y2, tick_length=0.01): + num_ticks = 6 # Number of ticks including 0% and 100% + for i in range(num_ticks): + t = i / (num_ticks - 1) + x_tick = x1 + t * (x2 - x1) + y_tick = y1 + t * (y2 - y1) + + # Vector along the line + dx = x2 - x1 + dy = y2 - y1 + + # Perpendicular vector + perp_dx = -dy + perp_dy = dx + + # Normalize the perpendicular vector + length = np.sqrt(perp_dx**2 + perp_dy**2) + perp_dx /= length + perp_dy /= length + + # Draw tick as a short perpendicular line + plt.plot( + [x_tick - tick_length * perp_dx, x_tick + tick_length * perp_dx], + [y_tick - tick_length * perp_dy, y_tick + tick_length * perp_dy], + color="gray", + linewidth=0.8, + ) + plt.text( + x_tick, + y_tick, + f"{int(100 - t * 100)}%", + fontsize=8, + ha="right", + va="bottom", + ) diff --git a/lyscripts/sample.py b/lyscripts/sample.py index 939ce13..ce2a77e 100644 --- a/lyscripts/sample.py +++ b/lyscripts/sample.py @@ -1,5 +1,4 @@ -""" -Learn the spread probabilities of the HMM for lymphatic tumor progression using +"""Learn the spread probabilities of the HMM for lymphatic tumor progression using the preprocessed data as input and MCMC as sampling method. This is the central script performing for our project on modelling lymphatic spread @@ -9,6 +8,7 @@ objective decisions with respect to defining the *elective clinical target volume* (CTV-N) in radiotherapy. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -58,60 +58,87 @@ def _add_arguments(parser: argparse.ArgumentParser): This is called by the parent module that is called via the command line. """ parser.add_argument( - "-i", "--input", type=Path, required=True, - help="Path to training data files" + "-i", "--input", type=Path, required=True, help="Path to training data files" ) parser.add_argument( - "-o", "--output", type=Path, required=True, - help="Path to the HDF5 file to store the results in" + "-o", + "--output", + type=Path, + required=True, + help="Path to the HDF5 file to store the results in", ) parser.add_argument( - "--history", type=Path, nargs="?", - help="Path to store the burnin history in (as CSV file)." + "--history", + type=Path, + nargs="?", + help="Path to store the burnin history in (as CSV file).", ) parser.add_argument( - "-w", "--walkers-per-dim", type=int, default=10, + "-w", + "--walkers-per-dim", + type=int, + default=10, help="Number of walkers per dimension", ) parser.add_argument( - "-b", "--burnin", type=int, nargs="?", - help="Number of burnin steps. If not provided, sampler runs until convergence." + "-b", + "--burnin", + type=int, + nargs="?", + help="Number of burnin steps. If not provided, sampler runs until convergence.", ) parser.add_argument( - "--check-interval", type=int, default=100, - help="Check convergence every `check_interval` steps." + "--check-interval", + type=int, + default=100, + help="Check convergence every `check_interval` steps.", ) parser.add_argument( - "--trust-fac", type=float, default=50., - help="Factor to trust the autocorrelation time for convergence." + "--trust-fac", + type=float, + default=50.0, + help="Factor to trust the autocorrelation time for convergence.", ) parser.add_argument( - "--rel-thresh", type=float, default=0.05, - help="Relative threshold for convergence." + "--rel-thresh", + type=float, + default=0.05, + help="Relative threshold for convergence.", ) parser.add_argument( - "-n", "--nsteps", type=int, default=100, - help="Number of MCMC samples to draw, irrespective of thinning." + "-n", + "--nsteps", + type=int, + default=100, + help="Number of MCMC samples to draw, irrespective of thinning.", ) parser.add_argument( - "-t", "--thin", type=int, default=10, - help="Thinning factor for the MCMC chain." + "-t", "--thin", type=int, default=10, help="Thinning factor for the MCMC chain." ) parser.add_argument( - "-p", "--params", default="./params.yaml", type=Path, - help="Path to parameter file." + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file.", ) parser.add_argument( - "-c", "--cores", type=int, nargs="?", + "-c", + "--cores", + type=int, + nargs="?", help=( "Number of parallel workers (CPU cores/threads) to use. If not provided, " "it will use all cores. If set to zero, multiprocessing will not be used." - ) + ), ) parser.add_argument( - "-s", "--seed", type=int, default=42, - help="Seed value to reproduce the same sampling round." + "-s", + "--seed", + type=int, + default=42, + help="Seed value to reproduce the same sampling round.", ) parser.set_defaults(run_main=main) @@ -119,8 +146,9 @@ def _add_arguments(parser: argparse.ArgumentParser): MODEL = None + def log_prob_fn(theta: np.array) -> float: - """log probability function using global variables because of pickling.""" + """Log probability function using global variables because of pickling.""" return MODEL.likelihood(given_params=theta) @@ -183,11 +211,12 @@ def run_burnin( progress.update(task, advance=1) new_acor_time = sampler.get_autocorr_time(tol=0).mean() - old_acor_time = history.acor_times[-1] if len(history.acor_times) > 0 else np.inf + old_acor_time = ( + history.acor_times[-1] if len(history.acor_times) > 0 else np.inf + ) - new_accept_frac = ( - (np.sum(sampler.backend.accepted) - num_accepted) - / (sampler.nwalkers * check_interval) + new_accept_frac = (np.sum(sampler.backend.accepted) - num_accepted) / ( + sampler.nwalkers * check_interval ) num_accepted = np.sum(sampler.backend.accepted) @@ -198,7 +227,9 @@ def run_burnin( is_converged = burnin is None is_converged &= new_acor_time * trust_fac < sampler.iteration - is_converged &= np.abs(new_acor_time - old_acor_time) / new_acor_time < rel_thresh + is_converged &= ( + np.abs(new_acor_time - old_acor_time) / new_acor_time < rel_thresh + ) if is_converged: break @@ -236,6 +267,7 @@ def run_sampling( class DummyPool: """Dummy class to allow for no multiprocessing.""" + def __enter__(self): return None @@ -277,7 +309,9 @@ def main(args: argparse.Namespace) -> None: with real_or_dummy_pool as pool: sampler = emcee.EnsembleSampler( - nwalkers, ndim, log_prob_fn, + nwalkers, + ndim, + log_prob_fn, moves=moves_mix, backend=hdf5_backend, pool=pool, diff --git a/lyscripts/scenario.py b/lyscripts/scenario.py index f16a370..9955484 100644 --- a/lyscripts/scenario.py +++ b/lyscripts/scenario.py @@ -1,5 +1,4 @@ -""" -This module implements helpers and classes that help us deal with what we call a +"""This module implements helpers and classes that help us deal with what we call a *scenario*. A scenario is a set of parameters that determine how we compute priors, posteriors, prevalences, and risks. @@ -9,6 +8,7 @@ model) are relevant. But e.g. posteriors and risks also require us to provide a diagnosis, given which to compute the quantities of interest. """ + import argparse import hashlib import inspect @@ -33,8 +33,10 @@ class UninitializedProperty(Exception): in the getter when no private attribute is found. """ + ScenarioT = TypeVar("ScenarioT", bound="Scenario") + @dataclass class Scenario: """Dataclass for storing configuration of a scenario. @@ -52,7 +54,6 @@ class Scenario: is_uni: bool = False side: str = "ipsi" - @staticmethod def _defaults(property_name: str) -> Any: """Return the default value for a property. @@ -67,12 +68,11 @@ def _defaults(property_name: str) -> Any: {} """ return { - "t_stages_dist": np.array([1.]), + "t_stages_dist": np.array([1.0]), "involvement": {"ipsi": {}, "contra": {}}, "diagnosis": {"ipsi": {}, "contra": {}}, }[property_name] - def __post_init__(self) -> None: """Declate default value of properties. @@ -89,12 +89,11 @@ def __post_init__(self) -> None: if not self.is_uni: for side in ["ipsi", "contra"]: - if not side in self.diagnosis: + if side not in self.diagnosis: self.diagnosis[side] = {} - if not side in self.involvement: + if side not in self.involvement: self.involvement[side] = {} - @classmethod def fields(cls) -> dict[str, Any]: """Return a list of fields that may make up a scenario.""" @@ -132,11 +131,11 @@ def t_stages_dist(self) -> np.ndarray: self._t_stages_dist = self._defaults("t_stages_dist") if len(self._t_stages_dist) != len(self.t_stages): - new_x = np.linspace(0., 1., len(self.t_stages)) - old_x = np.linspace(0., 1., len(self._t_stages_dist)) + new_x = np.linspace(0.0, 1.0, len(self.t_stages)) + old_x = np.linspace(0.0, 1.0, len(self._t_stages_dist)) self._t_stages_dist = np.interp(new_x, old_x, self._t_stages_dist) - if not np.isclose(np.sum(self._t_stages_dist), 1.): + if not np.isclose(np.sum(self._t_stages_dist), 1.0): self._t_stages_dist /= np.sum(self._t_stages_dist) return np.array(self._t_stages_dist) @@ -146,7 +145,6 @@ def t_stages_dist(self, value: Iterable[float]) -> None: if not isinstance(value, property): self._t_stages_dist = value - @classmethod def from_namespace( cls, @@ -182,12 +180,16 @@ def from_namespace( scenario = cls(**kwargs) for side in ["ipsi", "contra"]: - pattern = getattr(namespace, f"{side}_involvement", None) or [None] * len(lnls) - tmp = {lnl: val for lnl, val in zip(lnls, pattern)} + pattern = getattr(namespace, f"{side}_involvement", None) or [None] * len( + lnls + ) + tmp = {lnl: val for lnl, val in zip(lnls, pattern, strict=False)} scenario._involvement[side] = tmp - pattern = getattr(namespace, f"{side}_diagnosis", None) or [None] * len(lnls) - tmp = {lnl: val for lnl, val in zip(lnls, pattern)} + pattern = getattr(namespace, f"{side}_diagnosis", None) or [None] * len( + lnls + ) + tmp = {lnl: val for lnl, val in zip(lnls, pattern, strict=False)} mod_name = getattr(namespace, "modality", "max_llh") scenario._diagnosis[side] = {mod_name: tmp} @@ -246,7 +248,6 @@ def list_from_params( return res - def as_dict( self, for_comp: Literal["priors", "posteriors", "prevalences", "risks"], @@ -260,21 +261,24 @@ def as_dict( if for_comp == "priors": return res - res.update({ - "midext": self.midext, - "diagnosis": self.diagnosis, - "side": self.side, - "is_uni": self.is_uni, - }) + res.update( + { + "midext": self.midext, + "diagnosis": self.diagnosis, + "side": self.side, + "is_uni": self.is_uni, + } + ) if for_comp == "risks": res["involvement"] = self.involvement return res - @property - def diagnosis(self) -> dict[str, dict[str, types.PatternType]] | dict[str, types.PatternType]: + def diagnosis( + self, + ) -> dict[str, dict[str, types.PatternType]] | dict[str, types.PatternType]: """Get bi- or unilateral diagosis, depending on attrs ``side`` and ``is_uni``.""" if not hasattr(self, "_diagnosis"): raise UninitializedProperty("diagnosis") @@ -289,7 +293,6 @@ def diagnosis(self, value: dict[str, dict[str, types.PatternType]]) -> None: if not isinstance(value, property): self._diagnosis = value - @property def involvement(self) -> dict[str, types.PatternType] | types.PatternType: """Get bi- or unilateral involvement, depending on attrs ``side`` and ``is_uni``.""" @@ -306,7 +309,6 @@ def involvement(self, value: dict[str, types.PatternType]) -> None: if not isinstance(value, property): self._involvement = value - def get_pattern( self, get_from: Literal["involvement", "diagnosis"], @@ -326,7 +328,6 @@ def get_pattern( return pattern - def md5_hash( self, for_comp: Literal["priors", "posteriors", "prevalences", "risks"], @@ -360,18 +361,24 @@ def add_scenario_arguments( {'mode': 'BN', 't_stages': ['early'], 't_stages_dist': array([1.])} """ parser.add_argument( - "--t-stages", nargs="+", default=["early"], + "--t-stages", + nargs="+", + default=["early"], help="T-stages to consider.", ) parser.add_argument( - "--t-stages-dist", nargs="+", type=float, + "--t-stages-dist", + nargs="+", + type=float, help=( "Distribution over T-stages. Prior distribution over hidden states will " "be marginalized over T-stages using this distribution." - ) + ), ) parser.add_argument( - "--mode", choices=["BN", "HMM"], default="HMM", + "--mode", + choices=["BN", "HMM"], + default="HMM", help="Mode to use for computing the scenario.", ) @@ -379,7 +386,9 @@ def add_scenario_arguments( return parser.add_argument( - "--midext", type=optional_bool, required=False, + "--midext", + type=optional_bool, + required=False, help=( "Use midline extention for computing the scenario. Only used with " "midline model." @@ -398,31 +407,41 @@ def add_scenario_arguments( if for_comp == "risks": parser.add_argument( - "--ipsi-involvement", nargs="+", type=optional_bool, + "--ipsi-involvement", + nargs="+", + type=optional_bool, help="Involvement to compute quantitty for (ipsilateral side).", ) parser.add_argument( - "--contra-involvement", nargs="+", type=optional_bool, + "--contra-involvement", + nargs="+", + type=optional_bool, help="Involvement to compute quantitty for (contralateral side).", ) if for_comp == "prevalences": parser.add_argument( - "--modality", default="max_llh", + "--modality", + default="max_llh", help="Modality name to compute predicted and observed prevalence for.", ) parser.add_argument( - "--ipsi-diagnosis", nargs="+", type=optional_bool, + "--ipsi-diagnosis", + nargs="+", + type=optional_bool, help="Diagnosis of ipsilateral side.", ) parser.add_argument( - "--contra-diagnosis", nargs="+", type=optional_bool, + "--contra-diagnosis", + nargs="+", + type=optional_bool, help="Diagnosis of contralateral side.", ) if __name__ == "__main__": - scenario = Scenario(t_stages=['a', 'b'], t_stages_dist=[0.2, 0.8]) + scenario = Scenario(t_stages=["a", "b"], t_stages_dist=[0.2, 0.8]) import doctest + doctest.testmod() diff --git a/lyscripts/temp_schedule.py b/lyscripts/temp_schedule.py index 09483e2..553c5b6 100644 --- a/lyscripts/temp_schedule.py +++ b/lyscripts/temp_schedule.py @@ -1,5 +1,4 @@ -""" -Generate inverse temperature schedules for thermodynamic integration using various +"""Generate inverse temperature schedules for thermodynamic integration using various different methods. Thermodynamic integration is quite sensitive to the specific schedule which is used. @@ -11,6 +10,7 @@ the interval $[0, 1]$ and then transform each point by computing $\\beta_i^k$ where $k$ could e.g. be 5. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -27,9 +27,7 @@ def _add_parser( subparsers: argparse._SubParsersAction, help_formatter, ): - """ - Add an `ArgumentParser` to the subparsers action. - """ + """Add an `ArgumentParser` to the subparsers action.""" parser = subparsers.add_parser( Path(__file__).name.replace(".py", ""), description=__doc__, @@ -40,21 +38,26 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): - """ - Add arguments needed to run this script to a `subparsers` instance + """Add arguments needed to run this script to a `subparsers` instance and run the respective main function when chosen. """ parser.add_argument( - "--method", choices=SCHEDULES.keys(), default=list(SCHEDULES.keys())[0], - help="Choose the method to distribute the inverse temperature." + "--method", + choices=SCHEDULES.keys(), + default=list(SCHEDULES.keys())[0], + help="Choose the method to distribute the inverse temperature.", ) parser.add_argument( - "--num", default=32, type=int, - help="Number of inverse temperatures in the schedule" + "--num", + default=32, + type=int, + help="Number of inverse temperatures in the schedule", ) parser.add_argument( - "--pow", default=4, type=float, - help="If a power schedule is chosen, use this as power" + "--pow", + default=4, + type=float, + help="If a power schedule is chosen, use this as power", ) parser.set_defaults(run_main=main) @@ -62,36 +65,39 @@ def _add_arguments(parser: argparse.ArgumentParser): def tolist(func: Callable) -> Callable: """Decorator to make sure the returned value is a list of floats.""" + def inner(*args) -> np.ndarray | list[float]: res = func(*args) if isinstance(res, np.ndarray): return res.tolist() return res + return inner @tolist def geometric_schedule(n: int, *_a) -> np.ndarray: """Create a geometric sequence of `n` numbers from 0. to 1.""" - log_seq = np.logspace(0., 1., n) - shifted_seq = log_seq - 1. - geom_seq = shifted_seq / 9. + log_seq = np.logspace(0.0, 1.0, n) + shifted_seq = log_seq - 1.0 + geom_seq = shifted_seq / 9.0 return geom_seq @tolist def linear_schedule(n: int, *_a) -> np.ndarray: """Create a linear sequence of `n` numbers from 0. to 1.""" - return np.linspace(0., 1., n) + return np.linspace(0.0, 1.0, n) @tolist def power_schedule(n: int, power: float, *_a) -> np.ndarray: """Create a power sequence of `n` numbers from 0. to 1.""" - lin_seq = np.linspace(0., 1., n) + lin_seq = np.linspace(0.0, 1.0, n) power_seq = lin_seq**power return power_seq + SCHEDULES = { "geometric": geometric_schedule, "linear": linear_schedule, diff --git a/lyscripts/utils.py b/lyscripts/utils.py index 57648ec..cd43f2e 100644 --- a/lyscripts/utils.py +++ b/lyscripts/utils.py @@ -1,11 +1,11 @@ -""" -This module contains frequently used functions and decorators that are used throughout +"""This module contains frequently used functions and decorators that are used throughout the subcommands to load e.g. YAML specifications or model definitions. It also contains helpers for reporting the script's progress via a slightly customized `rich` console and a custom `Exception` called `LyScriptsWarning` that can propagate occuring issues to the right place. """ + import warnings from logging import LogRecord from pathlib import Path @@ -16,6 +16,7 @@ import yaml from deprecated import deprecated from emcee.backends import HDFBackend +from lymixture import LymphMixture from lymph import diagnosis_times, models, types from rich.console import Console from rich.logging import RichHandler @@ -31,8 +32,10 @@ try: import streamlit from streamlit.runtime.scriptrunner import get_script_run_ctx + streamlit.status = streamlit.spinner except ImportError: + def get_script_run_ctx() -> bool: """A mock for the `get_script_run_ctx` function of `streamlit`.""" return None @@ -48,13 +51,13 @@ def get_script_run_ctx() -> bool: class LyScriptsWarning(Warning): - """ - Exception that can be raised by methods if they want the `LyScriptsReport` instance + """Exception that can be raised by methods if they want the `LyScriptsReport` instance to not stop and print a traceback, but display some message appropriately. Essentially, this is a way for decorated functions to propagate messages through the `report_state` decorator. """ + def __init__(self, *args: object, level: str = "info") -> None: """Extract the `level` of the message (can be "info", "warning" or "error").""" self.level = level @@ -69,7 +72,8 @@ def is_streamlit_running() -> bool: class CustomProgress(Progress): """Small wrapper around rich's `Progress` initializing my custom columns.""" - def __init__( self, **kwargs: dict): + + def __init__(self, **kwargs: dict): columns = [ SpinnerColumn(finished_text=CHECK), *Progress.get_default_columns(), @@ -80,6 +84,7 @@ def __init__( self, **kwargs: dict): class CustomRichHandler(RichHandler): """Uses `func_filepath` from the `extra` dict to modify `pathname`.""" + def emit(self, record: LogRecord) -> None: """Emit a log record.""" if ( @@ -99,11 +104,13 @@ def emit(self, record: LogRecord) -> None: def binom_pmf(support: list[int] | np.ndarray, p: float = 0.5): """Binomial PMF""" max_time = len(support) - 1 - if p > 1. or p < 0.: + if p > 1.0 or p < 0.0: raise ValueError("Binomial prob must be btw. 0 and 1") - q = 1. - p - binom_coeff = factorial(max_time) / (factorial(support) * factorial(max_time - support)) - return binom_coeff * p**support * q**(max_time - support) + q = 1.0 - p + binom_coeff = factorial(max_time) / ( + factorial(support) * factorial(max_time - support) + ) + return binom_coeff * p**support * q ** (max_time - support) FUNCS = { @@ -115,7 +122,7 @@ def graph_from_config(graph_params: dict) -> dict[tuple[str, str], list[str]]: """Build graph dictionary for the `lymph` models from the YAML params.""" lymph_graph = {} - if not "tumor" in graph_params and "lnl" in graph_params: + if "tumor" not in graph_params and "lnl" in graph_params: raise KeyError("Parameters must define tumors and LNLs") for node_type, node_dict in graph_params.items(): @@ -150,7 +157,7 @@ def _create_model_from_v0(params: dict[str, Any]) -> types.Model: if "model" in params: model_cls = getattr(models, params["model"]["class"]) - if not "is_symmetric" in params["model"]["kwargs"]: + if "is_symmetric" not in params["model"]["kwargs"]: warnings.warn( "The keywords `base_symmetric`, `trans_symmetric`, and `use_mixing` " "have been deprecated. Please use `is_symmetric` instead.", @@ -200,7 +207,7 @@ def assign_modalities( to the ``model``. Example: - + ------- >>> from_config = { ... "CT": {"spec": 0.76, "sens": 0.81}, ... "MRI": [0.63, 0.86, "pathological"], @@ -217,6 +224,7 @@ def assign_modalities( >>> assign_modalities(model, from_config, subset=["CT"]) >>> model.get_all_modalities() # doctest: +NORMALIZE_WHITESPACE {'CT': Clinical(spec=0.76, sens=0.81, is_trinary=False)} + """ if clear: model.clear_modalities() @@ -241,7 +249,7 @@ def create_distribution(config: dict[str, Any]) -> diagnosis_times.Distribution: kwargs = config.get("kwargs", {}) if (type_ := config.get("frozen")) is not None: - kwargs.update({"support": np.arange(max_time+1)}) + kwargs.update({"support": np.arange(max_time + 1)}) distribution = diagnosis_times.Distribution(FUNCS[type_](**kwargs)) elif (type_ := config.get("parametric")) is not None: distribution = diagnosis_times.Distribution(FUNCS[type_], max_time, **kwargs) @@ -274,7 +282,7 @@ def create_model(config: dict[str, Any], config_version: int = 0) -> types.Model model_kwargs = model_config.get("kwargs", {}) model = model_cls(graph_dict, **model_kwargs) - assign_modalities(model=model, config=config.get("modalities", {})) + assign_modalities(model=model, config=config.get("inference_modalities", {})) for t_stage, dist_config in model_config.get("distributions", {}).items(): distribution = create_distribution(dist_config) @@ -283,6 +291,48 @@ def create_model(config: dict[str, Any], config_version: int = 0) -> types.Model return model +@log_state() +def create_mixture(config: dict[str, Any], config_version: int = 0) -> types.Model: + """Create a model instance as defined by a ``config`` dictionary.""" + if version := config.get("version", config_version) < 0: + raise LyScriptsWarning(f"{version=} unsupported", level="error") + + if (graph_config := config.get("graph")) is None: + raise LyScriptsWarning("No graph definition found in YAML file", level="error") + + if (model_config := config.get("model")) is None: + raise LyScriptsWarning( + "No mixture definition found in YAML file", level="error" + ) + + graph_dict = graph_from_config(graph_config) + model_cls_name, _, cls_meth_name = model_config["class"].partition(".") + if model_cls_name != "Unilateral": + raise LyScriptsWarning( + "The mixture model has only been implemented for Unilateral so far", + level="error", + ) + model_cls = getattr(models, model_cls_name) + model_kwargs = model_config.get("kwargs", {}) + if not isinstance(model_kwargs, dict): + model_kwargs = {} + model_num_components = model_config.get("num_components") + model_kwargs["graph_dict"] = graph_dict + mixture = LymphMixture( + model_cls=model_cls, + model_kwargs=model_kwargs, + num_components=model_num_components, + ) + + # note: modalities can't be set here, as we need to add the data first to define the number of subgroups + + for t_stage, dist_config in model_config.get("distributions", {}).items(): + distribution = create_distribution(dist_config) + mixture.set_distribution(t_stage, distribution) + + return mixture + + def get_dict_depth(nested: dict) -> int: """Get the depth of a nested dictionary. @@ -404,7 +454,7 @@ def load_patient_data( ) -> pd.DataFrame: """Load patient data from a CSV file stored at ``file``.""" if header is None: - header = [0,1,2] + header = [0, 1, 2] return pd.read_csv(file_path, header=header) @@ -458,6 +508,7 @@ def initialize_backend( FalseChoices = Literal["false", "f", "no", "n", "healthy", "benign"] """Type alias for what is interpreted as healthy/benign involvement of an LNL.""" + def optional_bool(value: NoneChoices | TrueChoices | FalseChoices) -> bool | None: """Convert a string to a boolean or ``None``. @@ -481,4 +532,8 @@ def make_pattern( lnls: list[str], ) -> dict[str, bool | None]: """Create a dictionary from a list of bools and Nones.""" - return dict(zip(lnls, from_list or [None] * len(lnls))) + return dict(zip(lnls, from_list or [None] * len(lnls), strict=False)) + + +def to_numpy(params: dict[str, float]) -> np.ndarray: + return np.array([p for p in params.values()]) diff --git a/pyproject.toml b/pyproject.toml index 4b6ad6a..2667905 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ app = [ lyscripts = "lyscripts:main" [tool.setuptools.packages.find] -include = ["lyscripts"] +include = ["lyscripts","lyscripts.*"] [tool.setuptools_scm] write_to = "lyscripts/_version.py" diff --git a/tests/data/join_test.py b/tests/data/join_test.py index 879c96f..628cfa9 100644 --- a/tests/data/join_test.py +++ b/tests/data/join_test.py @@ -1,6 +1,5 @@ -""" -Test the correct joining of datasets. -""" +"""Test the correct joining of datasets.""" + from pathlib import Path from lyscripts.data.join import load_and_join_tables @@ -8,12 +7,11 @@ def test_load_and_join_tables(): """Test the correct joining of datasets.""" - input_paths = [Path("./tests/data/b.csv"), Path("./tests/data/a.csv")] joined = load_and_join_tables(input_paths) - assert joined.shape == (13,3), "Wrong concatenation shape." - assert joined['x','z','b'].isna().sum() == 6, "Wrong number of NaNs." - assert (joined['x','z','b'] == True).sum() == 4, "Wrong number of True values." - assert (joined['x','z','b'] == False).sum() == 3, "Wrong number of False values." - assert ('x', 'z', 'c') in joined.columns, "Column 'c' not found in joined table." + assert joined.shape == (13, 3), "Wrong concatenation shape." + assert joined["x", "z", "b"].isna().sum() == 6, "Wrong number of NaNs." + assert (joined["x", "z", "b"] == True).sum() == 4, "Wrong number of True values." + assert (joined["x", "z", "b"] == False).sum() == 3, "Wrong number of False values." + assert ("x", "z", "c") in joined.columns, "Column 'c' not found in joined table." diff --git a/tests/plot/plot_utils_test.py b/tests/plot/plot_utils_test.py index e1794f7..d5b350f 100644 --- a/tests/plot/plot_utils_test.py +++ b/tests/plot/plot_utils_test.py @@ -1,6 +1,5 @@ -""" -Testing of the utilities implemented for the plotting routines. -""" +"""Testing of the utilities implemented for the plotting routines.""" + from pathlib import Path import matplotlib.pyplot as plt @@ -21,20 +20,18 @@ @pytest.fixture def beta_samples(): - """ - Filename of an HDF5 file where some samples from a Beta distribution are stored - """ + """Filename of an HDF5 file where some samples from a Beta distribution are stored""" return "./tests/plot/data/beta_samples.hdf5" def test_floor_to_step(): """Check correct rounding down to a given step size.""" - numbers = np.array([0., 3., 7.4, 2.01, np.pi, 12.7, 12.7, 17.3 ]) - steps = np.array([2 , 2 , 5 , 2 , 3 , 3 , 5 , 0.17 ]) - exp_res = np.array([0., 2., 5. , 2. , 3. , 12. , 10. , 17.17]) + numbers = np.array([0.0, 3.0, 7.4, 2.01, np.pi, 12.7, 12.7, 17.3]) + steps = np.array([2, 2, 5, 2, 3, 3, 5, 0.17]) + exp_res = np.array([0.0, 2.0, 5.0, 2.0, 3.0, 12.0, 10.0, 17.17]) comp_res = np.zeros_like(exp_res) - for i, (num, step) in enumerate(zip(numbers, steps)): + for i, (num, step) in enumerate(zip(numbers, steps, strict=False)): comp_res[i] = floor_to_step(num, step) assert all(np.isclose(comp_res, exp_res)), "Floor to step did not work properly." @@ -42,12 +39,12 @@ def test_floor_to_step(): def test_ceil_to_step(): """Check correct rounding up to a given step size.""" - numbers = np.array([0., 3., 7.4, 2.01, np.pi, 12.7, 12.7, 17.3 ]) - steps = np.array([2 , 2 , 5 , 2 , 3 , 3 , 5 , 0.17 ]) - exp_res = np.array([2., 4., 10., 4. , 6. , 15. , 15. , 17.34]) + numbers = np.array([0.0, 3.0, 7.4, 2.01, np.pi, 12.7, 12.7, 17.3]) + steps = np.array([2, 2, 5, 2, 3, 3, 5, 0.17]) + exp_res = np.array([2.0, 4.0, 10.0, 4.0, 6.0, 15.0, 15.0, 17.34]) comp_res = np.zeros_like(exp_res) - for i, (num, step) in enumerate(zip(numbers, steps)): + for i, (num, step) in enumerate(zip(numbers, steps, strict=False)): comp_res[i] = ceil_to_step(num, step) assert all(np.isclose(comp_res, exp_res)), "Ceil to step did not work properly." @@ -64,30 +61,34 @@ def test_histogram_cls(beta_samples): hist_from_path = Histogram.from_hdf5( filename=path_filename, dataname="beta", - scale=10., + scale=10.0, label=custom_label, ) with pytest.raises(FileNotFoundError): Histogram.from_hdf5(filename=non_existent_filename, dataname="does_not_matter") - assert np.all(np.isclose(hist_from_str.values, 10. * hist_from_path.values)), ( - "Scaling of data does not work correclty" - ) - assert np.all(np.isclose( - hist_from_str.left_percentile(50.), - hist_from_str.right_percentile(50.), - )), "50% percentiles should be the same from the left and from the right." - assert np.all(np.isclose( - hist_from_path.left_percentile(10.), - hist_from_path.right_percentile(90.), - )), "10% from the left is not the same as 90% from the right" - assert hist_from_str.kwargs["label"] == "beta | mega scan | 100 | ext", ( - "Label extraction did not work" - ) - assert hist_from_path.kwargs["label"] == custom_label, ( - "Keyword override did not work" - ) + assert np.all( + np.isclose(hist_from_str.values, 10.0 * hist_from_path.values) + ), "Scaling of data does not work correclty" + assert np.all( + np.isclose( + hist_from_str.left_percentile(50.0), + hist_from_str.right_percentile(50.0), + ) + ), "50% percentiles should be the same from the left and from the right." + assert np.all( + np.isclose( + hist_from_path.left_percentile(10.0), + hist_from_path.right_percentile(90.0), + ) + ), "10% from the left is not the same as 90% from the right" + assert ( + hist_from_str.kwargs["label"] == "beta | mega scan | 100 | ext" + ), "Label extraction did not work" + assert ( + hist_from_path.kwargs["label"] == custom_label + ), "Keyword override did not work" def test_inverted_histogram_cls(beta_samples): @@ -100,28 +101,32 @@ def test_inverted_histogram_cls(beta_samples): hist_from_path = Histogram.from_hdf5( filename=path_filename, dataname="beta", - scale=-100., - offset=100., + scale=-100.0, + offset=100.0, label=custom_label, ) - assert np.all(np.isclose(100. - hist_from_str.values, hist_from_path.values)), ( - "Scaling and offsetting of data does not work correclty" - ) - assert np.all(np.isclose( - hist_from_str.left_percentile(50.), - hist_from_str.right_percentile(50.), - )), "50% percentiles should be the same from the left and from the right." - assert np.all(np.isclose( - hist_from_path.left_percentile(10.), - hist_from_path.right_percentile(90.), - )), "10% from the left is not the same as 90% from the right" - assert hist_from_str.kwargs["label"] == "beta | mega scan | 100 | ext", ( - "Label extraction did not work" - ) - assert hist_from_path.kwargs["label"] == custom_label, ( - "Keyword override did not work" - ) + assert np.all( + np.isclose(100.0 - hist_from_str.values, hist_from_path.values) + ), "Scaling and offsetting of data does not work correclty" + assert np.all( + np.isclose( + hist_from_str.left_percentile(50.0), + hist_from_str.right_percentile(50.0), + ) + ), "50% percentiles should be the same from the left and from the right." + assert np.all( + np.isclose( + hist_from_path.left_percentile(10.0), + hist_from_path.right_percentile(90.0), + ) + ), "10% from the left is not the same as 90% from the right" + assert ( + hist_from_str.kwargs["label"] == "beta | mega scan | 100 | ext" + ), "Label extraction did not work" + assert ( + hist_from_path.kwargs["label"] == custom_label + ), "Keyword override did not work" def test_posterior_cls(beta_samples): @@ -130,41 +135,49 @@ def test_posterior_cls(beta_samples): path_filename = Path(str_filename) non_existent_filename = "non_existent.hdf5" custom_label = "Lorem ipsum" - x_10 = np.linspace(0., 10., 100) - x_100 = np.linspace(0., 100., 100) + x_10 = np.linspace(0.0, 10.0, 100) + x_100 = np.linspace(0.0, 100.0, 100) post_from_str = BetaPosterior.from_hdf5(filename=str_filename, dataname="beta") post_from_path = BetaPosterior.from_hdf5( filename=path_filename, dataname="beta", - scale=10., + scale=10.0, label=custom_label, ) with pytest.raises(FileNotFoundError): - BetaPosterior.from_hdf5(filename=non_existent_filename, dataname="does_not_matter") - - assert post_from_str.num_success == post_from_path.num_success == 20, ( - "Number of successes not correctly extracted" - ) - assert post_from_str.num_total == post_from_path.num_total == 40, ( - "Total number of trials not correctly extracted" - ) - assert post_from_str.num_fail == post_from_path.num_fail == 20, ( - "Number of failures not correctly computed" - ) - assert np.all(np.isclose( - 10 * post_from_str.pdf(x_100), - post_from_path.pdf(x_10), - )), "PDFs with different scaling do not match" - assert np.all(np.isclose( - post_from_str.left_percentile(50.), - post_from_str.right_percentile(50.), - )), "50% percentiles should be the same from the left and from the right." - assert np.all(np.isclose( - post_from_path.left_percentile(10.), - post_from_path.right_percentile(90.), - )), "10% from the left is not the same as 90% from the right" + BetaPosterior.from_hdf5( + filename=non_existent_filename, dataname="does_not_matter" + ) + + assert ( + post_from_str.num_success == post_from_path.num_success == 20 + ), "Number of successes not correctly extracted" + assert ( + post_from_str.num_total == post_from_path.num_total == 40 + ), "Total number of trials not correctly extracted" + assert ( + post_from_str.num_fail == post_from_path.num_fail == 20 + ), "Number of failures not correctly computed" + assert np.all( + np.isclose( + 10 * post_from_str.pdf(x_100), + post_from_path.pdf(x_10), + ) + ), "PDFs with different scaling do not match" + assert np.all( + np.isclose( + post_from_str.left_percentile(50.0), + post_from_str.right_percentile(50.0), + ) + ), "50% percentiles should be the same from the left and from the right." + assert np.all( + np.isclose( + post_from_path.left_percentile(10.0), + post_from_path.right_percentile(90.0), + ) + ), "10% from the left is not the same as 90% from the right" @pytest.mark.mpl_image_compare @@ -175,7 +188,7 @@ def test_draw(beta_samples): hist = Histogram.from_hdf5(filename, dataname) post = BetaPosterior.from_hdf5(filename, dataname) fig, ax = plt.subplots() - ax = draw(axes=ax, contents=[hist, post], percent_lims=(2.,2.)) + ax = draw(axes=ax, contents=[hist, post], percent_lims=(2.0, 2.0)) return fig @@ -197,7 +210,9 @@ def test_draw_hist_kwargs(beta_samples): global_kwargs_path = "./tests/plot/results/global_kwargs" fig, global_kwargs_ax = plt.subplots() - global_kwargs_ax = draw(global_kwargs_ax, contents=[hist], hist_kwargs={"alpha": 0.3}) + global_kwargs_ax = draw( + global_kwargs_ax, contents=[hist], hist_kwargs={"alpha": 0.3} + ) save_figure(global_kwargs_path, fig, ["png"]) hist = Histogram.from_hdf5(filename, dataname, alpha=0.3) @@ -206,45 +221,55 @@ def test_draw_hist_kwargs(beta_samples): local_kwargs_ax = draw(local_kwargs_ax, contents=[hist], hist_kwargs={"alpha": 1.0}) save_figure(local_kwargs_path, fig, ["png"]) - assert mpl_comp.compare_images( - expected=default_kwargs_path + ".png", - actual=bins_kwargs_path + ".png", - tol=0.001, - ) is not None, "Changing bin number did not result in different plot" - - assert mpl_comp.compare_images( - expected=default_kwargs_path + ".png", - actual=global_kwargs_path + ".png", - tol=0.001, - ) is not None, "Changing global kwargs in `draw` did not result in different plot" - - assert mpl_comp.compare_images( - expected=local_kwargs_path + ".png", - actual=global_kwargs_path + ".png", - tol=0.001, - ) is None, "Overriding global with `Histogram` specific kwargs did not work" + assert ( + mpl_comp.compare_images( + expected=default_kwargs_path + ".png", + actual=bins_kwargs_path + ".png", + tol=0.001, + ) + is not None + ), "Changing bin number did not result in different plot" + + assert ( + mpl_comp.compare_images( + expected=default_kwargs_path + ".png", + actual=global_kwargs_path + ".png", + tol=0.001, + ) + is not None + ), "Changing global kwargs in `draw` did not result in different plot" + + assert ( + mpl_comp.compare_images( + expected=local_kwargs_path + ".png", + actual=global_kwargs_path + ".png", + tol=0.001, + ) + is None + ), "Overriding global with `Histogram` specific kwargs did not work" def test_save_figure(capsys): """Check that figures get stored correctly.""" - x = np.linspace(0., 2*np.pi, 200) + x = np.linspace(0.0, 2 * np.pi, 200) y = np.sin(x) fig, ax = plt.subplots(figsize=get_size()) - ax.plot(x,y) + ax.plot(x, y) output_path = "./tests/plot/results/sine" formats = ["png", "svg"] - expected_output = ( - "✓ Saved matplotlib figure.\n" - ) + expected_output = "✓ Saved matplotlib figure.\n" save_figure(output_path, fig, formats) save_figure_capture = capsys.readouterr() - assert mpl_comp.compare_images( - expected="./tests/plot/baseline/sine.png", - actual="./tests/plot/results/sine.png", - tol=0., - ) is None, "PNG of figure was not stored correctly." + assert ( + mpl_comp.compare_images( + expected="./tests/plot/baseline/sine.png", + actual="./tests/plot/results/sine.png", + tol=0.0, + ) + is None + ), "PNG of figure was not stored correctly." # Commented out, because I recently got the following message from matplotlib: # `SKIPPED (Don't know how to convert .svg files to png)` diff --git a/tests/predict/predict_utils_test.py b/tests/predict/predict_utils_test.py index a578f6f..06fb7c7 100644 --- a/tests/predict/predict_utils_test.py +++ b/tests/predict/predict_utils_test.py @@ -1,12 +1,10 @@ -""" -Test utilities of the predict submodule. -""" +"""Test utilities of the predict submodule.""" + from lyscripts.compute.utils import complete_pattern def test_clean_pattern(): - """ - Test the utility function that cleans the involvement patterns from the + """Test the utility function that cleans the involvement patterns from the `params.yaml` file """ empty_pattern = {} @@ -20,13 +18,13 @@ def test_clean_pattern(): assert empty_cleaned == { "ipsi": {"I": None, "II": None, "III": None}, - "contra": {"I": None, "II": None, "III": None} + "contra": {"I": None, "II": None, "III": None}, }, "Empty pattern does not get filled correctly." assert one_pos_cleaned == { "ipsi": {"I": None, "II": True, "III": None}, - "contra": {"I": None, "II": None, "III": None} + "contra": {"I": None, "II": None, "III": None}, }, "Pattern with one positive LNL not cleaned properly." assert nums_cleaned == { "ipsi": {"I": True, "II": None, "III": None}, - "contra": {"I": None, "II": None, "III": False} + "contra": {"I": None, "II": None, "III": False}, }, "Number pattern cleaned wrongly." diff --git a/tests/predict/prevalences_test.py b/tests/predict/prevalences_test.py index 4ae9b63..19ac051 100644 --- a/tests/predict/prevalences_test.py +++ b/tests/predict/prevalences_test.py @@ -1,6 +1,5 @@ -""" -Test the functions of the prevalence prediction submodule. -""" +"""Test the functions of the prevalence prediction submodule.""" + import pandas as pd import pytest @@ -12,16 +11,16 @@ def test_get_match_idx(): """Test if the pattern dictionaries & pandas data are compared correctly.""" oneside_pattern = {"I": False, "II": True, "III": None} ignorant_pattern = {"I": None, "II": None, "III": None} - three_patients = pd.DataFrame.from_dict({ - "I": [False, False, True], - "II": [True , True , True], - "III": [True , False, True], - }) + three_patients = pd.DataFrame.from_dict( + { + "I": [False, False, True], + "II": [True, True, True], + "III": [True, False, True], + } + ) lnls = list(oneside_pattern.keys()) - matching_idxs = get_match_idx( - True, oneside_pattern, three_patients, invert=False - ) + matching_idxs = get_match_idx(True, oneside_pattern, three_patients, invert=False) inverted_matching_idxs = get_match_idx( False, oneside_pattern, three_patients, invert=True ) @@ -41,19 +40,16 @@ def test_get_match_idx(): def test_does_midline_ext_match(): - """ - Test the function that returns indices of a `DataFrame` where the midline + """Test the function that returns indices of a `DataFrame` where the midline extension matches. """ - midline_data = pd.DataFrame({ - ("tumor", "1", "extension"): [True, False, None] - }) + midline_data = pd.DataFrame({("tumor", "1", "extension"): [True, False, None]}) keyerr_data = pd.DataFrame({("way", "too", "many", "levels"): [True, False, None]}) midline_match = does_midext_match(midline_data, midext=False) - assert all(midline_match == pd.Series([False, True, False])), ( - "Matching midline extension with correct data does not work." - ) + assert all( + midline_match == pd.Series([False, True, False]) + ), "Matching midline extension with correct data does not work." with pytest.raises(KeyError): _ = does_midext_match(keyerr_data, midext=False) diff --git a/tests/run_doctests.py b/tests/run_doctests.py index dd603fc..3726fad 100644 --- a/tests/run_doctests.py +++ b/tests/run_doctests.py @@ -1,6 +1,5 @@ -""" -Script to run doctests in the modules of `lyscripts`. -""" +"""Script to run doctests in the modules of `lyscripts`.""" + import doctest from lyscripts import utils diff --git a/tests/sample_test.py b/tests/sample_test.py index f526c65..42573f2 100644 --- a/tests/sample_test.py +++ b/tests/sample_test.py @@ -1,11 +1,11 @@ -""" -Test the sampling command with some example patients. +"""Test the sampling command with some example patients. Originally, I wanted to test that the sampling procedure is reproducible, but the `emcee` package does not seem to work with any kind of seed in a reproducible manner. Maybe I am doing something wrong... """ + # pylint: disable=redefined-outer-name import numpy as np import pytest @@ -66,10 +66,14 @@ def test_burnin(sampler: EnsembleSampler): assert sampler.iteration == 100, "Burnin di not run 100 iterations." assert len(burnin_history.steps) == 10, "Burnin history does not have 10 entries." assert np.all( - np.array([ - 0.7147557514447068, 0.9227188150264771, - 0.2629624184410706, 0.6001184115584288, - ]) + np.array( + [ + 0.7147557514447068, + 0.9227188150264771, + 0.2629624184410706, + 0.6001184115584288, + ] + ) == sampler.get_last_sample().coords[0] ), "Not reproducible." @@ -87,6 +91,5 @@ def test_get_starting_state(sampler: EnsembleSampler): check_interval=2, ) assert np.all( - get_starting_state(sampler).coords - == sampler.get_last_sample().coords + get_starting_state(sampler).coords == sampler.get_last_sample().coords ), "State is not the same." diff --git a/tests/utils_test.py b/tests/utils_test.py index fd6c510..c60bcba 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -1,6 +1,5 @@ -""" -Test the core utility functions of the package. -""" +"""Test the core utility functions of the package.""" + import pytest from lymph import models @@ -66,9 +65,13 @@ def test_create_model(params_v1): assert not model.use_central, "Model should not use central" assert model.use_midext_evo, "Model should use midext evolution" assert "early" in model.get_all_distributions(), "Early distribution is missing" - assert not model.get_distribution("early").is_updateable, "Early distribution should not be updateable" + assert not model.get_distribution( + "early" + ).is_updateable, "Early distribution should not be updateable" assert "late" in model.get_all_distributions(), "Late distribution is missing" - assert model.get_distribution("late").is_updateable, "Late distribution should be updateable" + assert model.get_distribution( + "late" + ).is_updateable, "Late distribution should be updateable" assert "CT" in model.get_all_modalities(), "CT modality is missing" assert model.get_modality("CT").spec == 0.76, "CT modality has wrong specificity" assert model.get_modality("CT").sens == 0.81, "CT modality has wrong sensitivity"