From 4b9fffb95b999e962a493518c6a8efe9edbd34ec Mon Sep 17 00:00:00 2001 From: YoelPH Date: Mon, 30 Sep 2024 11:31:54 +0200 Subject: [PATCH 01/20] first changes not functional --- lyscripts/mixture_fit.py | 290 +++++++++++++++++++++++++++++++++++++++ lyscripts/utils.py | 35 +++++ 2 files changed, 325 insertions(+) create mode 100644 lyscripts/mixture_fit.py diff --git a/lyscripts/mixture_fit.py b/lyscripts/mixture_fit.py new file mode 100644 index 0000000..d4e3f38 --- /dev/null +++ b/lyscripts/mixture_fit.py @@ -0,0 +1,290 @@ +""" +Learn the spread probabilities of the HMM for lymphatic tumor progression using +the preprocessed data as input and the mixture model. +""" +# pylint: disable=logging-fstring-interpolation +import argparse +import logging +import os +from collections import namedtuple + +try: + from multiprocess import Pool +except ModuleNotFoundError: + from multiprocessing import Pool + +from pathlib import Path + +import emcee +import numpy as np +import pandas as pd +from lymph import models +from lymixture import LymphMixture +from rich.progress import Progress, TimeElapsedColumn, track + +from lyscripts.utils import ( + create_mixture, + load_patient_data, + load_yaml_params, + to_numpy, +) + +logger = logging.getLogger(__name__) + + +def _add_parser( + subparsers: argparse._SubParsersAction, + help_formatter, +): + """Add an ``ArgumentParser`` to the subparsers action.""" + parser = subparsers.add_parser( + Path(__file__).name.replace(".py", ""), + description=__doc__, + help=__doc__, + formatter_class=help_formatter, + ) + _add_arguments(parser) + + +def _add_arguments(parser: argparse.ArgumentParser): + """Add arguments to a ``subparsers`` instance and run its main function when chosen. + + This is called by the parent module that is called via the command line. + """ + parser.add_argument( + "-i", "--input", type=Path, required=True, + help="Path to training data files" + ) + parser.add_argument( + "-o", "--output", type=Path, required=True, + help="Path to the HDF5 file to store the results in" + ) + parser.add_argument( + "--history", type=Path, nargs="?", + help="Path to store the burnin history in (as CSV file)." + ) + + parser.add_argument( + "-w", "--walkers-per-dim", type=int, default=10, + help="Number of walkers per dimension", + ) + parser.add_argument( + "-b", "--burnin", type=int, nargs="?", + help="Number of burnin steps. If not provided, sampler runs until convergence." + ) + parser.add_argument( + "--check-interval", type=int, default=100, + help="Check convergence every `check_interval` steps." + ) + parser.add_argument( + "--trust-fac", type=float, default=50., + help="Factor to trust the autocorrelation time for convergence." + ) + parser.add_argument( + "--rel-thresh", type=float, default=0.05, + help="Relative threshold for convergence." + ) + parser.add_argument( + "-n", "--nsteps", type=int, default=100, + help="Number of MCMC samples to draw, irrespective of thinning." + ) + parser.add_argument( + "-t", "--thin", type=int, default=10, + help="Thinning factor for the MCMC chain." + ) + parser.add_argument( + "-p", "--params", default="./params.yaml", type=Path, + help="Path to parameter file." + ) + parser.add_argument( + "-c", "--cores", type=int, nargs="?", + help=( + "Number of parallel workers (CPU cores/threads) to use. If not provided, " + "it will use all cores. If set to zero, multiprocessing will not be used." + ) + ) + parser.add_argument( + "-s", "--seed", type=int, default=42, + help="Seed value to reproduce the same sampling round." + ) + + parser.set_defaults(run_main=main) + + +MIXTURE = None + +def log_prob_fn(theta: np.array) -> float: + """log probability function using global variables because of pickling.""" + return MIXTURE.likelihood(use_complete = True, given_resps = MIXTURE.get_resps(norm = True)) + + +def check_convergence(params_history, likelihood_history, steps_back_list, absolute_tolerance = 0.01): + current_params = params_history[-1] + current_likelihood = likelihood_history[-1] + for steps_back in steps_back_list: + previous_params = params_history[-steps_back - 1] + if np.allclose(to_numpy(current_params), to_numpy(previous_params)): + logger.info(f"Converged after {len(params_history)} steps. due to parameter similarity") + return True # Return True if any of the steps is close + elif np.isclose(current_likelihood, likelihood_history[-steps_back - 1],rtol = 0, atol = absolute_tolerance): + logger.info(f"Converged after {len(params_history)} steps. due to likelihood similarity") + return True + return False + + + +def run_burnin( + sampler: emcee.EnsembleSampler, + burnin: int | None = None, + check_interval: int = 100, + trust_fac: float = 50.0, + rel_thresh: float = 0.05, +) -> BurninHistory: + """Run the burnin phase of the MCMC sampling. + + This will run the sampler for ``burnin`` steps or (if ``burnin`` is `None`) until + convergence is reached. The convergence criterion is based on the autocorrelation + time of the chain, which is computed every `check_interval` steps. The chain is + considered to have converged if the autocorrelation time is smaller than + `trust_fac` times the number of iterations and the relative change in the + autocorrelation time is smaller than `rel_thresh`. + + The samples of the burnin phase will be stored, such that one can resume a + cancelled run. Also, metrics collected during the burnin phase will be returned + in a :py:obj:`.BurninHistory` namedtuple. This may be used for plotting and + diagnostics. + """ + state = get_starting_state(sampler) + history = BurninHistory([], [], [], []) + num_accepted = 0 + + with Progress( + *Progress.get_default_columns(), + TimeElapsedColumn(), + ) as progress: + task = progress.add_task( + description="[blue]INFO [/blue]Burn-in phase ", + total=burnin, + ) + while sampler.iteration < (burnin or np.inf): + for state in sampler.sample(state, iterations=check_interval): + progress.update(task, advance=1) + + new_acor_time = sampler.get_autocorr_time(tol=0).mean() + old_acor_time = history.acor_times[-1] if len(history.acor_times) > 0 else np.inf + + new_accept_frac = ( + (np.sum(sampler.backend.accepted) - num_accepted) + / (sampler.nwalkers * check_interval) + ) + num_accepted = np.sum(sampler.backend.accepted) + + history.steps.append(sampler.iteration) + history.acor_times.append(new_acor_time) + history.accept_fracs.append(new_accept_frac) + history.max_log_probs.append(np.max(state.log_prob)) + + is_converged = burnin is None + is_converged &= new_acor_time * trust_fac < sampler.iteration + is_converged &= np.abs(new_acor_time - old_acor_time) / new_acor_time < rel_thresh + + if is_converged: + break + + if is_converged: + logger.info(f"Converged after {sampler.iteration} steps.") + logger.info(f"Acceptance fraction: {sampler.acceptance_fraction.mean():.2%}") + return history + + +def run_sampling( + sampler: emcee.EnsembleSampler, + nsteps: int, + thin: int, +) -> None: + """Run the MCMC sampling phase to produce `nsteps` samples. + + This sampling will definitely produce `nsteps` samples, irrespective of the `thin` + parameter, which controls how many steps in between two stored samples are skipped. + The samples will be stored in the backend of the `sampler`. + + Note that this will reset the `sampler`'s backend, assuming the stored samples are + from the burnin phase. + """ + state = get_starting_state(sampler) + sampler.backend.reset(sampler.nwalkers, sampler.ndim) + + for _sample in track( + sequence=sampler.sample(state, iterations=nsteps * thin, thin=thin, store=True), + description="[blue]INFO [/blue]Sampling phase", + total=nsteps * thin, + ): + continue + + +def run_EM(): + """Run the EM algorithm to determine the optimal parameters. + """ + is_converged = False + iteration = 0 + params_history = [] + likelihood_history = [] + params_history.append(params.copy()) + likelihood_history.append(MIXTURE.likelihood(use_complete=False)) + # Number of steps to look back for convergence + look_back_steps = 3 + + while not is_converged: + print(iteration) + print(likelihood_history[-1]) + latent = LymphMixture.em.expectation(MIXTURE, params) + params = LymphMixture.em.maximization(MIXTURE, latent) + + # Append current params and likelihood to history + params_history.append(params.copy()) + likelihood_history.append(MIXTURE.likelihood(use_complete=False)) + + # Check if converged + if iteration >= 3: # Ensure enough history is available + is_converged = check_convergence(params_history, likelihood_history,list(range(1,look_back_steps+1))) + + return params_history, likelihood_history + +def main(args: argparse.Namespace) -> None: + """Main function to run the EM algorithm for a mixture model""" + # as recommended in https://emcee.readthedocs.io/en/stable/tutorials/parallel/# + os.environ["OMP_NUM_THREADS"] = "1" + + params = load_yaml_params(args.params) + inference_data = load_patient_data(args.input) + + # ugly, but necessary for pickling + global MIXTURE + MIXTURE = create_mixture(params) + + mapping = params["model"].get("mapping", None) + if isinstance(MIXTURE, models.Unilateral): + side = params["model"].get("side", "ipsi") + MIXTURE.load_patient_data(inference_data, side=side, mapping=mapping) + else: + raise "Only Unilateral has been implemented so far" + + # emcee does not support numpy's new random number generator yet. + rng = np.random.default_rng(params["em"].get("seed", 42)) + starting_values = {k: rng.uniform() for k in MIXTURE.get_params()} + MIXTURE.set_params(**starting_values) + MIXTURE.normalize_mixture_coefs() + params_history, likelihood_history = run_EM() + + if args.history is not None: + logger.info(f"Saving history to {args.history}.") + burnin_history_df = pd.DataFrame(burnin_history._asdict()).set_index("steps") + burnin_history_df.to_csv(args.history, index=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + _add_arguments(parser) + + args = parser.parse_args() + args.run_main(args) diff --git a/lyscripts/utils.py b/lyscripts/utils.py index 57648ec..9ea07ad 100644 --- a/lyscripts/utils.py +++ b/lyscripts/utils.py @@ -17,6 +17,7 @@ from deprecated import deprecated from emcee.backends import HDFBackend from lymph import diagnosis_times, models, types +from lymixture import LymphMixture from rich.console import Console from rich.logging import RichHandler from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn @@ -283,6 +284,36 @@ def create_model(config: dict[str, Any], config_version: int = 0) -> types.Model return model +@log_state() +def create_mixture(config: dict[str, Any], config_version: int = 0) -> types.Model: + """Create a model instance as defined by a ``config`` dictionary.""" + if version := config.get("version", config_version) < 1: + raise LyScriptsWarning(f"{version=} unsupported", level="error") + + if (graph_config := config.get("graph")) is None: + raise LyScriptsWarning("No graph definition found in YAML file", level="error") + + if (model_config := config.get("mixture")) is None: + raise LyScriptsWarning("No mixture definition found in YAML file", level="error") + + graph_dict = graph_from_config(graph_config) + model_cls_name, _, cls_meth_name = model_config["class"].partition(".") + if model_cls_name != 'Unilateral': + raise LyScriptsWarning("The mixture model has only been implemented for Unilateral so far", level = "error") + model_cls = getattr(models, model_cls_name) + model_kwargs = model_config.get("kwargs", {}) + model_num_components = model_config.get('num_components') + model_kwargs['graph'] = graph_dict + mixture = LymphMixture(model_cls = model_cls, model_kwargs = model_kwargs, num_components = model_num_components) + + assign_modalities(model=mixture, config=config.get("modalities", {})) + + for t_stage, dist_config in model_config.get("distributions", {}).items(): + distribution = create_distribution(dist_config) + mixture.set_distribution(t_stage, distribution) + + return mixture + def get_dict_depth(nested: dict) -> int: """Get the depth of a nested dictionary. @@ -482,3 +513,7 @@ def make_pattern( ) -> dict[str, bool | None]: """Create a dictionary from a list of bools and Nones.""" return dict(zip(lnls, from_list or [None] * len(lnls))) + + +def to_numpy(params: dict[str, float]) -> np.ndarray: + return np.array([p for p in params.values()]) \ No newline at end of file From d00a1c1959957c128c0c9c8343bce640bfd65d38 Mon Sep 17 00:00:00 2001 From: YoelPH Date: Fri, 4 Oct 2024 09:12:08 +0200 Subject: [PATCH 02/20] added mixture script --- lyscripts/__init__.py | 4 +- lyscripts/mixture_fit.py | 96 ++-------------------------------------- 2 files changed, 6 insertions(+), 94 deletions(-) diff --git a/lyscripts/__init__.py b/lyscripts/__init__.py index 0fe8e89..6e4aa11 100644 --- a/lyscripts/__init__.py +++ b/lyscripts/__init__.py @@ -13,7 +13,7 @@ import rich from rich_argparse import RichHelpFormatter -from lyscripts import app, compute, data, evaluate, plot, sample, temp_schedule +from lyscripts import app, compute, data, evaluate, plot, mixture_fit, temp_schedule from lyscripts._version import version from lyscripts.utils import CustomRichHandler, console @@ -114,7 +114,7 @@ def main(): data._add_parser(subparsers, help_formatter=parser.formatter_class) evaluate._add_parser(subparsers, help_formatter=parser.formatter_class) plot._add_parser(subparsers, help_formatter=parser.formatter_class) - sample._add_parser(subparsers, help_formatter=parser.formatter_class) + mixture_fit._add_parser(subparsers, help_formatter=parser.formatter_class) temp_schedule._add_parser(subparsers, help_formatter=parser.formatter_class) args = parser.parse_args() diff --git a/lyscripts/mixture_fit.py b/lyscripts/mixture_fit.py index d4e3f38..a981d87 100644 --- a/lyscripts/mixture_fit.py +++ b/lyscripts/mixture_fit.py @@ -132,96 +132,6 @@ def check_convergence(params_history, likelihood_history, steps_back_list, absol return False - -def run_burnin( - sampler: emcee.EnsembleSampler, - burnin: int | None = None, - check_interval: int = 100, - trust_fac: float = 50.0, - rel_thresh: float = 0.05, -) -> BurninHistory: - """Run the burnin phase of the MCMC sampling. - - This will run the sampler for ``burnin`` steps or (if ``burnin`` is `None`) until - convergence is reached. The convergence criterion is based on the autocorrelation - time of the chain, which is computed every `check_interval` steps. The chain is - considered to have converged if the autocorrelation time is smaller than - `trust_fac` times the number of iterations and the relative change in the - autocorrelation time is smaller than `rel_thresh`. - - The samples of the burnin phase will be stored, such that one can resume a - cancelled run. Also, metrics collected during the burnin phase will be returned - in a :py:obj:`.BurninHistory` namedtuple. This may be used for plotting and - diagnostics. - """ - state = get_starting_state(sampler) - history = BurninHistory([], [], [], []) - num_accepted = 0 - - with Progress( - *Progress.get_default_columns(), - TimeElapsedColumn(), - ) as progress: - task = progress.add_task( - description="[blue]INFO [/blue]Burn-in phase ", - total=burnin, - ) - while sampler.iteration < (burnin or np.inf): - for state in sampler.sample(state, iterations=check_interval): - progress.update(task, advance=1) - - new_acor_time = sampler.get_autocorr_time(tol=0).mean() - old_acor_time = history.acor_times[-1] if len(history.acor_times) > 0 else np.inf - - new_accept_frac = ( - (np.sum(sampler.backend.accepted) - num_accepted) - / (sampler.nwalkers * check_interval) - ) - num_accepted = np.sum(sampler.backend.accepted) - - history.steps.append(sampler.iteration) - history.acor_times.append(new_acor_time) - history.accept_fracs.append(new_accept_frac) - history.max_log_probs.append(np.max(state.log_prob)) - - is_converged = burnin is None - is_converged &= new_acor_time * trust_fac < sampler.iteration - is_converged &= np.abs(new_acor_time - old_acor_time) / new_acor_time < rel_thresh - - if is_converged: - break - - if is_converged: - logger.info(f"Converged after {sampler.iteration} steps.") - logger.info(f"Acceptance fraction: {sampler.acceptance_fraction.mean():.2%}") - return history - - -def run_sampling( - sampler: emcee.EnsembleSampler, - nsteps: int, - thin: int, -) -> None: - """Run the MCMC sampling phase to produce `nsteps` samples. - - This sampling will definitely produce `nsteps` samples, irrespective of the `thin` - parameter, which controls how many steps in between two stored samples are skipped. - The samples will be stored in the backend of the `sampler`. - - Note that this will reset the `sampler`'s backend, assuming the stored samples are - from the burnin phase. - """ - state = get_starting_state(sampler) - sampler.backend.reset(sampler.nwalkers, sampler.ndim) - - for _sample in track( - sequence=sampler.sample(state, iterations=nsteps * thin, thin=thin, store=True), - description="[blue]INFO [/blue]Sampling phase", - total=nsteps * thin, - ): - continue - - def run_EM(): """Run the EM algorithm to determine the optimal parameters. """ @@ -278,8 +188,10 @@ def main(args: argparse.Namespace) -> None: if args.history is not None: logger.info(f"Saving history to {args.history}.") - burnin_history_df = pd.DataFrame(burnin_history._asdict()).set_index("steps") - burnin_history_df.to_csv(args.history, index=True) + likelihood_history = pd.DataFrame(likelihood_history).set_index("steps") + likelihood_history.to_csv(args.likelihood_history, index=True) + params_history = pd.DataFrame(params_history).set_index("steps") + params_history.to_csv(args.params_history, index=True) if __name__ == "__main__": From 10b05da43fd39bfa04405753db91289a8d5d8bc2 Mon Sep 17 00:00:00 2001 From: YoelPH Date: Fri, 4 Oct 2024 21:27:16 +0200 Subject: [PATCH 03/20] change! Extended scripts for dvc. Not functional yet --- lyscripts/mixture_fit.py | 14 +++++++------- lyscripts/utils.py | 8 +++++--- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/lyscripts/mixture_fit.py b/lyscripts/mixture_fit.py index a981d87..213a5c2 100644 --- a/lyscripts/mixture_fit.py +++ b/lyscripts/mixture_fit.py @@ -55,10 +55,10 @@ def _add_arguments(parser: argparse.ArgumentParser): "-i", "--input", type=Path, required=True, help="Path to training data files" ) - parser.add_argument( - "-o", "--output", type=Path, required=True, - help="Path to the HDF5 file to store the results in" - ) + # parser.add_argument( + # "-o", "--output", type=Path, required=True, + # help="Path to the HDF5 file to store the results in" + # ) parser.add_argument( "--history", type=Path, nargs="?", help="Path to store the burnin history in (as CSV file)." @@ -113,7 +113,7 @@ def _add_arguments(parser: argparse.ArgumentParser): MIXTURE = None -def log_prob_fn(theta: np.array) -> float: +def log_prob_fn() -> float: """log probability function using global variables because of pickling.""" return MIXTURE.likelihood(use_complete = True, given_resps = MIXTURE.get_resps(norm = True)) @@ -189,9 +189,9 @@ def main(args: argparse.Namespace) -> None: if args.history is not None: logger.info(f"Saving history to {args.history}.") likelihood_history = pd.DataFrame(likelihood_history).set_index("steps") - likelihood_history.to_csv(args.likelihood_history, index=True) + likelihood_history.to_csv(args.history_dir + 'llh', index=True) params_history = pd.DataFrame(params_history).set_index("steps") - params_history.to_csv(args.params_history, index=True) + params_history.to_csv(args.history_dir + 'params', index=True) if __name__ == "__main__": diff --git a/lyscripts/utils.py b/lyscripts/utils.py index 9ea07ad..d51026a 100644 --- a/lyscripts/utils.py +++ b/lyscripts/utils.py @@ -287,13 +287,13 @@ def create_model(config: dict[str, Any], config_version: int = 0) -> types.Model @log_state() def create_mixture(config: dict[str, Any], config_version: int = 0) -> types.Model: """Create a model instance as defined by a ``config`` dictionary.""" - if version := config.get("version", config_version) < 1: + if version := config.get("version", config_version) < 0: raise LyScriptsWarning(f"{version=} unsupported", level="error") if (graph_config := config.get("graph")) is None: raise LyScriptsWarning("No graph definition found in YAML file", level="error") - if (model_config := config.get("mixture")) is None: + if (model_config := config.get("model")) is None: raise LyScriptsWarning("No mixture definition found in YAML file", level="error") graph_dict = graph_from_config(graph_config) @@ -302,8 +302,10 @@ def create_mixture(config: dict[str, Any], config_version: int = 0) -> types.Mod raise LyScriptsWarning("The mixture model has only been implemented for Unilateral so far", level = "error") model_cls = getattr(models, model_cls_name) model_kwargs = model_config.get("kwargs", {}) + if not isinstance(model_kwargs,dict): + model_kwargs = {} model_num_components = model_config.get('num_components') - model_kwargs['graph'] = graph_dict + model_kwargs['graph_dict'] = graph_dict mixture = LymphMixture(model_cls = model_cls, model_kwargs = model_kwargs, num_components = model_num_components) assign_modalities(model=mixture, config=config.get("modalities", {})) From 112140edb84f347538933625c4c442ac53692b9b Mon Sep 17 00:00:00 2001 From: YoelPH Date: Thu, 10 Oct 2024 20:36:54 +0200 Subject: [PATCH 04/20] first kinda functional version --- lyscripts/mixture_fit.py | 20 +++++++++++--------- lyscripts/utils.py | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/lyscripts/mixture_fit.py b/lyscripts/mixture_fit.py index 213a5c2..373f571 100644 --- a/lyscripts/mixture_fit.py +++ b/lyscripts/mixture_fit.py @@ -20,6 +20,7 @@ import pandas as pd from lymph import models from lymixture import LymphMixture +from lymixture.em import expectation, maximization from rich.progress import Progress, TimeElapsedColumn, track from lyscripts.utils import ( @@ -137,27 +138,28 @@ def run_EM(): """ is_converged = False iteration = 0 + params = MIXTURE.get_params() params_history = [] likelihood_history = [] params_history.append(params.copy()) - likelihood_history.append(MIXTURE.likelihood(use_complete=False)) + likelihood_history.append(MIXTURE.likelihood()) # Number of steps to look back for convergence look_back_steps = 3 while not is_converged: - print(iteration) - print(likelihood_history[-1]) - latent = LymphMixture.em.expectation(MIXTURE, params) - params = LymphMixture.em.maximization(MIXTURE, latent) + print('iteration',iteration) + print('likelihood', likelihood_history[-1]) + latent = expectation(MIXTURE, params) + params = maximization(MIXTURE, latent) # Append current params and likelihood to history params_history.append(params.copy()) - likelihood_history.append(MIXTURE.likelihood(use_complete=False)) + likelihood_history.append(MIXTURE.likelihood()) # Check if converged if iteration >= 3: # Ensure enough history is available is_converged = check_convergence(params_history, likelihood_history,list(range(1,look_back_steps+1))) - + iteration += 1 return params_history, likelihood_history def main(args: argparse.Namespace) -> None: @@ -173,9 +175,9 @@ def main(args: argparse.Namespace) -> None: MIXTURE = create_mixture(params) mapping = params["model"].get("mapping", None) - if isinstance(MIXTURE, models.Unilateral): + if isinstance(MIXTURE.components[0], models.Unilateral): side = params["model"].get("side", "ipsi") - MIXTURE.load_patient_data(inference_data, side=side, mapping=mapping) + MIXTURE.load_patient_data(inference_data, split_by=params["model"].get("split_by", ("tumor", "1", "subsite")), mapping=mapping) else: raise "Only Unilateral has been implemented so far" diff --git a/lyscripts/utils.py b/lyscripts/utils.py index d51026a..583335f 100644 --- a/lyscripts/utils.py +++ b/lyscripts/utils.py @@ -308,7 +308,7 @@ def create_mixture(config: dict[str, Any], config_version: int = 0) -> types.Mod model_kwargs['graph_dict'] = graph_dict mixture = LymphMixture(model_cls = model_cls, model_kwargs = model_kwargs, num_components = model_num_components) - assign_modalities(model=mixture, config=config.get("modalities", {})) + # assign_modalities(model=mixture, config=config.get("modalities", {})) for t_stage, dist_config in model_config.get("distributions", {}).items(): distribution = create_distribution(dist_config) From 7ac57e1413347be9b0e4c18160d31b4e0c2bbdfd Mon Sep 17 00:00:00 2001 From: YoelPH Date: Fri, 11 Oct 2024 12:17:51 +0200 Subject: [PATCH 05/20] change: functional script --- lyscripts/mixture_fit.py | 66 ++++++++++------------------------------ lyscripts/utils.py | 6 ++-- 2 files changed, 19 insertions(+), 53 deletions(-) diff --git a/lyscripts/mixture_fit.py b/lyscripts/mixture_fit.py index 373f571..7239f7b 100644 --- a/lyscripts/mixture_fit.py +++ b/lyscripts/mixture_fit.py @@ -28,6 +28,7 @@ load_patient_data, load_yaml_params, to_numpy, + assign_modalities ) logger = logging.getLogger(__name__) @@ -62,48 +63,12 @@ def _add_arguments(parser: argparse.ArgumentParser): # ) parser.add_argument( "--history", type=Path, nargs="?", - help="Path to store the burnin history in (as CSV file)." - ) - - parser.add_argument( - "-w", "--walkers-per-dim", type=int, default=10, - help="Number of walkers per dimension", - ) - parser.add_argument( - "-b", "--burnin", type=int, nargs="?", - help="Number of burnin steps. If not provided, sampler runs until convergence." - ) - parser.add_argument( - "--check-interval", type=int, default=100, - help="Check convergence every `check_interval` steps." - ) - parser.add_argument( - "--trust-fac", type=float, default=50., - help="Factor to trust the autocorrelation time for convergence." - ) - parser.add_argument( - "--rel-thresh", type=float, default=0.05, - help="Relative threshold for convergence." - ) - parser.add_argument( - "-n", "--nsteps", type=int, default=100, - help="Number of MCMC samples to draw, irrespective of thinning." - ) - parser.add_argument( - "-t", "--thin", type=int, default=10, - help="Thinning factor for the MCMC chain." + help="Path to store history in (as CSV file)." ) parser.add_argument( "-p", "--params", default="./params.yaml", type=Path, help="Path to parameter file." ) - parser.add_argument( - "-c", "--cores", type=int, nargs="?", - help=( - "Number of parallel workers (CPU cores/threads) to use. If not provided, " - "it will use all cores. If set to zero, multiprocessing will not be used." - ) - ) parser.add_argument( "-s", "--seed", type=int, default=42, help="Seed value to reproduce the same sampling round." @@ -133,7 +98,7 @@ def check_convergence(params_history, likelihood_history, steps_back_list, absol return False -def run_EM(): +def run_EM(tolerance): """Run the EM algorithm to determine the optimal parameters. """ is_converged = False @@ -142,7 +107,7 @@ def run_EM(): params_history = [] likelihood_history = [] params_history.append(params.copy()) - likelihood_history.append(MIXTURE.likelihood()) + likelihood_history.append(MIXTURE.likelihood(use_complete = False)) # Number of steps to look back for convergence look_back_steps = 3 @@ -154,11 +119,11 @@ def run_EM(): # Append current params and likelihood to history params_history.append(params.copy()) - likelihood_history.append(MIXTURE.likelihood()) + likelihood_history.append(MIXTURE.likelihood(use_complete = False)) # Check if converged if iteration >= 3: # Ensure enough history is available - is_converged = check_convergence(params_history, likelihood_history,list(range(1,look_back_steps+1))) + is_converged = check_convergence(params_history, likelihood_history,list(range(1,look_back_steps+1)),tolerance) iteration += 1 return params_history, likelihood_history @@ -169,7 +134,6 @@ def main(args: argparse.Namespace) -> None: params = load_yaml_params(args.params) inference_data = load_patient_data(args.input) - # ugly, but necessary for pickling global MIXTURE MIXTURE = create_mixture(params) @@ -177,23 +141,25 @@ def main(args: argparse.Namespace) -> None: mapping = params["model"].get("mapping", None) if isinstance(MIXTURE.components[0], models.Unilateral): side = params["model"].get("side", "ipsi") - MIXTURE.load_patient_data(inference_data, split_by=params["model"].get("split_by", ("tumor", "1", "subsite")), mapping=mapping) + MIXTURE.load_patient_data(inference_data, split_by= params["model"].get("split_by", ("tumor", "1", "subsite")), mapping=mapping) + assign_modalities(model=MIXTURE, config=params.get("inference_modalities", {})) + else: raise "Only Unilateral has been implemented so far" - # emcee does not support numpy's new random number generator yet. rng = np.random.default_rng(params["em"].get("seed", 42)) starting_values = {k: rng.uniform() for k in MIXTURE.get_params()} MIXTURE.set_params(**starting_values) MIXTURE.normalize_mixture_coefs() - params_history, likelihood_history = run_EM() + tolerance = params['model'].get('lihelihood_tolerance', 0.01) + params_history, likelihood_history = run_EM(tolerance = tolerance) if args.history is not None: - logger.info(f"Saving history to {args.history}.") - likelihood_history = pd.DataFrame(likelihood_history).set_index("steps") - likelihood_history.to_csv(args.history_dir + 'llh', index=True) - params_history = pd.DataFrame(params_history).set_index("steps") - params_history.to_csv(args.history_dir + 'params', index=True) + logger.info(f"Saving history to {args.history_dir}.") + likelihood_history = pd.DataFrame(likelihood_history) + likelihood_history.to_csv(args.history_dir + '/llh', index=True) + params_history = pd.DataFrame(params_history) + params_history.to_csv(args.history_dir + '/params', index=True) if __name__ == "__main__": diff --git a/lyscripts/utils.py b/lyscripts/utils.py index 583335f..eaeee49 100644 --- a/lyscripts/utils.py +++ b/lyscripts/utils.py @@ -275,7 +275,7 @@ def create_model(config: dict[str, Any], config_version: int = 0) -> types.Model model_kwargs = model_config.get("kwargs", {}) model = model_cls(graph_dict, **model_kwargs) - assign_modalities(model=model, config=config.get("modalities", {})) + assign_modalities(model=model, config=config.get("inference_modalities", {})) for t_stage, dist_config in model_config.get("distributions", {}).items(): distribution = create_distribution(dist_config) @@ -308,8 +308,8 @@ def create_mixture(config: dict[str, Any], config_version: int = 0) -> types.Mod model_kwargs['graph_dict'] = graph_dict mixture = LymphMixture(model_cls = model_cls, model_kwargs = model_kwargs, num_components = model_num_components) - # assign_modalities(model=mixture, config=config.get("modalities", {})) - + #note: modalities can't be set here, as we need to add the data first to define the number of subgroups + for t_stage, dist_config in model_config.get("distributions", {}).items(): distribution = create_distribution(dist_config) mixture.set_distribution(t_stage, distribution) From b72868cc2667ea5013531a3909dd4082a2d260f6 Mon Sep 17 00:00:00 2001 From: YoelPH Date: Wed, 16 Oct 2024 09:27:38 +0200 Subject: [PATCH 06/20] extended to add lymixture scripts. plotting not fully funcional yet --- lyscripts/__init__.py | 2 +- lyscripts/mixture_fit.py | 17 ++-- lyscripts/plot/__init__.py | 4 +- lyscripts/plot/__main__.py | 3 +- lyscripts/plot/mixture_plot.py | 77 +++++++++++++++++ lyscripts/plot/simplex_plot.py | 150 +++++++++++++++++++++++++++++++++ lyscripts/plot/utils.py | 50 +++++++++++ 7 files changed, 292 insertions(+), 11 deletions(-) create mode 100644 lyscripts/plot/mixture_plot.py create mode 100644 lyscripts/plot/simplex_plot.py diff --git a/lyscripts/__init__.py b/lyscripts/__init__.py index 6e4aa11..44f6355 100644 --- a/lyscripts/__init__.py +++ b/lyscripts/__init__.py @@ -13,7 +13,7 @@ import rich from rich_argparse import RichHelpFormatter -from lyscripts import app, compute, data, evaluate, plot, mixture_fit, temp_schedule +from lyscripts import plot, app, compute, data, evaluate, mixture_fit, temp_schedule from lyscripts._version import version from lyscripts.utils import CustomRichHandler, console diff --git a/lyscripts/mixture_fit.py b/lyscripts/mixture_fit.py index 7239f7b..e471cc9 100644 --- a/lyscripts/mixture_fit.py +++ b/lyscripts/mixture_fit.py @@ -151,16 +151,17 @@ def main(args: argparse.Namespace) -> None: starting_values = {k: rng.uniform() for k in MIXTURE.get_params()} MIXTURE.set_params(**starting_values) MIXTURE.normalize_mixture_coefs() - tolerance = params['model'].get('lihelihood_tolerance', 0.01) + tolerance = params['model'].get('likelihood_tolerance', 0.01) params_history, likelihood_history = run_EM(tolerance = tolerance) - if args.history is not None: - logger.info(f"Saving history to {args.history_dir}.") - likelihood_history = pd.DataFrame(likelihood_history) - likelihood_history.to_csv(args.history_dir + '/llh', index=True) - params_history = pd.DataFrame(params_history) - params_history.to_csv(args.history_dir + '/params', index=True) - + history_dir = params['general']['history_dir'] + logger.info(f"Saving history to {history_dir}.") + llh_history = pd.DataFrame(likelihood_history) + llh_history.columns = ['likelihoods'] + llh_history.to_csv(history_dir + '/llh.csv', index=False) + param_history = pd.DataFrame(params_history) + param_history.to_csv(history_dir + '/params.csv', index=False) + MIXTURE.get_mixture_coefs().to_csv(history_dir + '/mixture_coef.csv', index=False) if __name__ == "__main__": parser = argparse.ArgumentParser(description=__doc__) diff --git a/lyscripts/plot/__init__.py b/lyscripts/plot/__init__.py index 6bc0748..e762930 100644 --- a/lyscripts/plot/__init__.py +++ b/lyscripts/plot/__init__.py @@ -6,7 +6,7 @@ import argparse from pathlib import Path -from lyscripts.plot import corner, histograms, thermo_int +from lyscripts.plot import corner, histograms, thermo_int, mixture_plot def _add_parser( @@ -24,3 +24,5 @@ def _add_parser( corner._add_parser(subparsers, help_formatter=parser.formatter_class) histograms._add_parser(subparsers, help_formatter=parser.formatter_class) thermo_int._add_parser(subparsers, help_formatter=parser.formatter_class) + mixture_plot._add_parser(subparsers, help_formatter=parser.formatter_class) + diff --git a/lyscripts/plot/__main__.py b/lyscripts/plot/__main__.py index dc34a9f..8c5cf81 100644 --- a/lyscripts/plot/__main__.py +++ b/lyscripts/plot/__main__.py @@ -1,7 +1,7 @@ import argparse from lyscripts import RichDefaultHelpFormatter, exit_cli -from lyscripts.plot import corner, histograms, thermo_int +from lyscripts.plot import corner, histograms, thermo_int, mixture_plot # I need another __main__ guard here, because otherwise pdoc tries to run this if __name__ == "__main__": @@ -18,6 +18,7 @@ corner._add_parser(subparsers, help_formatter=parser.formatter_class) histograms._add_parser(subparsers, help_formatter=parser.formatter_class) thermo_int._add_parser(subparsers, help_formatter=parser.formatter_class) + mixture_plot._add_parser(subparsers, help_formatter=parser.formatter_class) args = parser.parse_args() args.run_main(args) diff --git a/lyscripts/plot/mixture_plot.py b/lyscripts/plot/mixture_plot.py new file mode 100644 index 0000000..e25098f --- /dev/null +++ b/lyscripts/plot/mixture_plot.py @@ -0,0 +1,77 @@ +import argparse +import logging +from pathlib import Path + +import pandas as pd +import matplotlib.pyplot as plt +from matplotlib.colors import LinearSegmentedColormap + +from lyscripts.plot.utils import COLORS, save_figure + +logger = logging.getLogger(__name__) + +def _add_parser( + subparsers: argparse._SubParsersAction, + help_formatter, +): + """Add an ``ArgumentParser`` to the subparsers action.""" + parser = subparsers.add_parser( + Path(__file__).name.replace(".py", ""), + description=__doc__, + help=__doc__, + formatter_class=help_formatter, + ) + _add_arguments(parser) + + +def _add_arguments(parser: argparse.ArgumentParser): + """Add arguments to the parser.""" + parser.add_argument( + "--input", type=Path, + help="File path of mixture coefficients" + ) + parser.add_argument( + "--output", type=Path, + help="Output path for the plot" + ) + + parser.set_defaults(run_main=main) + +def main(args: argparse.Namespace): + tmp = LinearSegmentedColormap.from_list("tmp", [COLORS['green'], COLORS['red']], N=128) + mixture_df = pd.read_csv(args.input) + + # Transpose the matrix to rotate by 90° + matrix_rotated = mixture_df.T + + # Create the figure and axis + fig, ax = plt.subplots(figsize=(8, 10)) + + # Display the rotated matrix using imshow + cax = ax.imshow(matrix_rotated.values, cmap=tmp, origin='upper') + + # Loop over the data and create text annotations + for i in range(matrix_rotated.shape[0]): # Rows (previously columns) + for j in range(matrix_rotated.shape[1]): # Columns (previously rows) + value = matrix_rotated.iloc[i, j] + ax.text(j, i, f"{value:.2f}", ha="center", va="center", + color="white", fontsize=12) + + + # Optional: Set axis labels and title + ax.set_xticks(range(matrix_rotated.shape[1])) + ax.set_xticklabels(mixture_df.index, fontsize = 12) # Original row labels + ax.set_yticks(range(matrix_rotated.shape[0])) + ax.set_yticklabels(mixture_df.columns, fontsize = 12) # Original column labels + ax.set_title("Mixture Coefficients per subsite", fontsize = 16) + save_figure(args.output, fig, formats=["png", "svg"]) + logger.info(f"Mixture parameter matrix saved") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + _add_arguments(parser) + + args = parser.parse_args() + args.run_main(args) + diff --git a/lyscripts/plot/simplex_plot.py b/lyscripts/plot/simplex_plot.py new file mode 100644 index 0000000..b2a69a1 --- /dev/null +++ b/lyscripts/plot/simplex_plot.py @@ -0,0 +1,150 @@ +import argparse +import logging +import numpy as np +from pathlib import Path + +import pandas as pd +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D + + +from lyscripts.plot.utils import COLORS, save_figure, p_to_xyz, add_perpendicular_crosses_3d +from lyscripts.utils import load_yaml_params + +logger = logging.getLogger(__name__) + +def _add_parser( + subparsers: argparse._SubParsersAction, + help_formatter, +): + """Add an ``ArgumentParser`` to the subparsers action.""" + parser = subparsers.add_parser( + Path(__file__).name.replace(".py", ""), + description=__doc__, + help=__doc__, + formatter_class=help_formatter, + ) + _add_arguments(parser) + + +def _add_arguments(parser: argparse.ArgumentParser): + """Add arguments to the parser.""" + parser.add_argument( + "--input", type=Path, + help="File path of the mixture coefficients" + ) + parser.add_argument( + "--output", type=Path, + help="Output path for the plot" + ) + parser.add_argument( + "-p", "--params", default="./params.yaml", type=Path, + help="Path to parameter file." + ) + + parser.set_defaults(run_main=main) + +def plot_3d_simplex(mixture_df, data): + subsites = list(mixture_df.columns) + simplex_matrix = np.zeros((len(subsites), 3)) + for index, subsite in enumerate(subsites): + simplex_matrix[index] = p_to_xyz(mixture_df[subsite]) + odered_value_counts = data['tumor']['1']['subsite'].value_counts()[subsites] + sizes = odered_value_counts * 3 + + # Sizes for each point + ordered_value_counts = data['tumor']['1']['subsite'].value_counts()[subsites] + sizes = np.array(ordered_value_counts) * 1 # Adjust the scaling if necessary + + # Create a 3D plot + fig = plt.figure(figsize=(10, 10)) + ax = fig.add_subplot(111, projection='3d') + + # Extract coordinates from simplex_matrix + x_coords = simplex_matrix[:, 0] + y_coords = simplex_matrix[:, 1] + z_coords = simplex_matrix[:, 2] + + # Scatter plot with specific colors and sizes + scatter = ax.scatter(x_coords, y_coords, z_coords, c=colors_ordered, s=sizes, marker='o') + + # Add text labels for each subsite + for i, subsite in enumerate(subsites): + ax.text(x_coords[i], y_coords[i], z_coords[i], subsite) + x_extremes = [-1,1,0,0] + y_extremes = [-0.5773502691896258,-0.5773502691896258,1.1547005383792517,0] + z_extremes = [-0.5773502691896258,-0.5773502691896258,-0.5773502691896258,1.2247448713915892] + plotter_x = x_extremes.copy() + plotter_x.append(x_extremes[0]) + plotter_x.append(x_extremes[2]) + plotter_y = y_extremes.copy() + plotter_y.append(y_extremes[0]) + plotter_y.append(y_extremes[2]) + plotter_z = z_extremes.copy() + plotter_z.append(z_extremes[0]) + plotter_z.append(z_extremes[2]) + + ax.plot(plotter_x, plotter_y,plotter_z, c='k') + ax.plot([x_extremes[1],x_extremes[3]], [y_extremes[1],y_extremes[3]],[z_extremes[1],z_extremes[3]], c='k') + extremes = np.array((x_extremes,y_extremes,z_extremes)).T + center0 = (extremes[1]+extremes[2]+extremes[3])/3 + center1 = (extremes[0]+extremes[2]+extremes[3])/3 + center2 = (extremes[0]+extremes[1]+extremes[3])/3 + center3 = (extremes[0]+extremes[1]+extremes[2])/3 + + component_larynx = mixture.get_mixture_coefs()['C32.0'].argmax() + component_oropharynx = mixture.get_mixture_coefs()['C01'].argmax() + component_oral_cavity = mixture.get_mixture_coefs()['C03'].argmax() + component_hypopharynx = mixture.get_mixture_coefs()['C13'].argmax() + + ax.text(extremes[component_larynx,0], extremes[component_larynx,1], extremes[component_larynx,2], "Larynx like", fontsize=10, ha='right', va='top',c = usz_orange) + ax.text(extremes[component_oropharynx,0], extremes[component_oropharynx,1], extremes[component_oropharynx,2], "Oropharynx like", fontsize=10, ha='left', va='top',c = usz_green) + ax.text(extremes[component_oral_cavity,0], extremes[component_oral_cavity,1], extremes[component_oral_cavity,2], "Oral cavity like", fontsize=10, ha='left', va='top',c = usz_blue) + ax.text(extremes[component_hypopharynx,0], extremes[component_hypopharynx,1], extremes[component_hypopharynx,2], "Hypopharynx like", fontsize=10, ha='left', va='top',c = usz_red) + + + + add_perpendicular_crosses_3d(ax, extremes[0, 0], extremes[0, 1], extremes[0, 2], center0[0], center0[1], center0[2]) + plt.plot([extremes[0,0], center0[0]], [extremes[0,1], center0[1]],[extremes[0,2],center0[2]], color='gray', linestyle='--', linewidth=1) + add_perpendicular_crosses_3d(ax, extremes[1, 0], extremes[1, 1], extremes[1, 2], center1[0], center1[1], center1[2]) + plt.plot([extremes[1,0], center1[0]], [extremes[1,1], center1[1]],[extremes[1,2],center1[2]], color='gray', linestyle='--', linewidth=1) + add_perpendicular_crosses_3d(ax, extremes[2, 0], extremes[2, 1], extremes[2, 2], center2[0], center2[1], center2[2]) + plt.plot([extremes[2,0], center2[0]], [extremes[2,1], center2[1]],[extremes[2,2],center2[2]], color='gray', linestyle='--', linewidth=1) + add_perpendicular_crosses_3d(ax, extremes[3, 0], extremes[3, 1], extremes[3, 2], center3[0], center3[1], center3[2]) + plt.plot([extremes[3,0], center3[0]], [extremes[3,1], center3[1]],[extremes[3,2],center3[2]], color='gray', linestyle='--', linewidth=1) + + legend_text = [] + for index in range(len(subsites)): + legend_text.append(subsites[index] + ', ' + str(odered_value_counts[index]) + ' patients') + + legend_elements = [Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=text) + for color, text in zip(colors_ordered, legend_text)] + + # Add a legend with fixed dot sizes + # plt.legend(handles=legend_elements, loc='upper right', title='Subsites', fontsize='small') + + plt.gca().set_axis_off() + plt.show() + + + +def main(args: argparse.Namespace): + mixture_df = pd.read_csv(args.input) + nr_components = len(mixture_df) + params = load_yaml_params(args.params) + data = params['general']['data'] + if nr_components == 2: + plot_2d_simplex() + elif nr_components == 3: + plot_3d_simplex(mixture_df, data) + else: + logger.info(f"Simplex not supported for {nr_components} components") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + _add_arguments(parser) + + args = parser.parse_args() + args.run_main(args) + diff --git a/lyscripts/plot/utils.py b/lyscripts/plot/utils.py index acf78f6..915c21e 100644 --- a/lyscripts/plot/utils.py +++ b/lyscripts/plot/utils.py @@ -13,6 +13,7 @@ import matplotlib.pyplot as plt import numpy as np import scipy as sp +import math from lyscripts.decorators import ( check_input_file_exists, @@ -407,3 +408,52 @@ def save_figure( """Save a ``figure`` to ``output_path`` in every one of the provided ``formats``.""" for frmt in formats: figure.savefig(output_path.with_suffix(f".{frmt}")) + +def p_to_xyz(p): + #Project 4D representation down to 3D representation + s3 = 1/math.sqrt(3.0) + s6 = 1/math.sqrt(6.0) + x = -1*p[0] + 1*p[1] + 0*p[2] + 0*p[3] + y = -s3*p[0] - s3*p[1] + 2*s3*p[2] + 0*p[3] + z = -s3*p[0] - s3*p[1] - s3*p[2] + 3*s6*p[3] + return x, y, z + +def add_perpendicular_crosses_3d(ax, x1, y1, z1, x2, y2, z2, tick_length=0.03): + num_ticks = 6 # Number of ticks + for i in range(num_ticks): + t = i / (num_ticks - 1) + + # Interpolation point (x, y, z) on the line + x_tick = x1 + t * (x2 - x1) + y_tick = y1 + t * (y2 - y1) + z_tick = z1 + t * (z2 - z1) + + # Vector along the line (direction of the line) + line_vec = np.array([x2 - x1, y2 - y1, z2 - z1]) + + # First perpendicular vector (cross product with z-axis) + perp_vec1 = np.cross(line_vec, [0, 0, 1]) + length1 = np.linalg.norm(perp_vec1) + if length1 == 0: # Prevent division by zero (in rare cases when parallel to z-axis) + perp_vec1 = np.cross(line_vec, [1, 0, 0]) # Cross with x-axis instead + length1 = np.linalg.norm(perp_vec1) + perp_vec1 /= length1 # Normalize + + # Second perpendicular vector (cross product with line_vec and perp_vec1) + perp_vec2 = np.cross(line_vec, perp_vec1) + perp_vec2 /= np.linalg.norm(perp_vec2) # Normalize + + # Scale the perpendicular vectors by tick length + perp_vec1 *= tick_length + perp_vec2 *= tick_length + + # Draw the cross (two perpendicular lines) + ax.plot([x_tick - perp_vec1[0], x_tick + perp_vec1[0]], + [y_tick - perp_vec1[1], y_tick + perp_vec1[1]], + [z_tick - perp_vec1[2], z_tick + perp_vec1[2]], color='gray', linewidth=0.8) + + ax.plot([x_tick - perp_vec2[0], x_tick + perp_vec2[0]], + [y_tick - perp_vec2[1], y_tick + perp_vec2[1]], + [z_tick - perp_vec2[2], z_tick + perp_vec2[2]], color='gray', linewidth=0.8) + + ax.text(x_tick, y_tick, z_tick, f'{int(100 - t * 100)}%', fontsize=6, ha='right', va='bottom') \ No newline at end of file From 8cd85e7ad8a4cdae38bb6c4c5f46d78bcb0e4fa0 Mon Sep 17 00:00:00 2001 From: YoelPH Date: Thu, 17 Oct 2024 22:15:11 +0200 Subject: [PATCH 07/20] added plotting --- lyscripts/plot/simplex_plot.py | 275 ++++++++++++++++++++++++--------- lyscripts/plot/utils.py | 26 ++++ 2 files changed, 229 insertions(+), 72 deletions(-) diff --git a/lyscripts/plot/simplex_plot.py b/lyscripts/plot/simplex_plot.py index b2a69a1..1ac8223 100644 --- a/lyscripts/plot/simplex_plot.py +++ b/lyscripts/plot/simplex_plot.py @@ -8,7 +8,7 @@ from matplotlib.lines import Line2D -from lyscripts.plot.utils import COLORS, save_figure, p_to_xyz, add_perpendicular_crosses_3d +from lyscripts.plot.utils import COLORS, SUBSITE_COLORS, save_figure, p_to_xyz, add_perpendicular_crosses_3d from lyscripts.utils import load_yaml_params logger = logging.getLogger(__name__) @@ -41,92 +41,223 @@ def _add_arguments(parser: argparse.ArgumentParser): "-p", "--params", default="./params.yaml", type=Path, help="Path to parameter file." ) - parser.set_defaults(run_main=main) -def plot_3d_simplex(mixture_df, data): + +""" +Visualize the component assignments of the trained mixture model. +""" +import argparse +from pathlib import Path +from matplotlib.ticker import StrMethodFormatter +import numpy as np +import yaml + +import h5py +import matplotlib.pyplot as plt +from tueplots import figsizes, fontsizes +from lyscripts.plot.utils import COLORS as USZ + +from helpers import generate_location_colors + + +def create_parser() -> argparse.ArgumentParser: + """Assemble the parser for the command line arguments.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "-m", "--model", type=Path, default="models/mixture.hdf5", + help=( + "Path to the model HDF5 file. Needs to contain a dataset called " + "``em/cluster_assignments``." + ) + ) + parser.add_argument( + "-o", "--output", type=Path, default="figures/cluster_assignments.png", + help="Path to the output file.", + ) + parser.add_argument( + "-p", "--params", type=Path, default="_variables.yml", + help="Path to the parameter file..", + ) + return parser + + +def plot_2d_simplex(mixture_df, data): + _, bottom_ax = plt.subplots() subsites = list(mixture_df.columns) - simplex_matrix = np.zeros((len(subsites), 3)) - for index, subsite in enumerate(subsites): - simplex_matrix[index] = p_to_xyz(mixture_df[subsite]) + cluster_x = mixture_df.loc[0] + cluster_y = [0. for _ in subsites] + annotations = [f"{label}\n({num})" for label, num in num_patients.items()] + bottom_ax.scatter( + cluster_x, cluster_y, + s=[num for num in num_patients.values()], + c=list(generate_location_colors(subsites)), + alpha=0.7, + linewidths=0., + zorder=10, + ) + + sorted_idx = cluster_components.argsort() + sorted_x = cluster_components[sorted_idx] + sorted_annotations = [annotations[i] for i in sorted_idx] + sorted_num = [list(num_patients.values())[i] for i in sorted_idx] + for i, (x, num, annotation) in enumerate(zip(sorted_x, sorted_num, sorted_annotations)): + bottom_ax.annotate( + annotation, + # sqrt, because marker's area grows linearly with patient num, not radius + xy=(x, np.sqrt(0.0000003 * num) * (- 1)**i), + xytext=(x, 0.025 * (- 1)**i), + ha="center", + va="bottom" if i % 2 == 0 else "top", + fontsize="small", + arrowprops={ + "arrowstyle": "-", + "color": USZ["gray"], + "linewidth": 1., + } + ) + + bottom_ax.set_xlabel("assignment to component A") + bottom_ax.xaxis.set_major_formatter(StrMethodFormatter("{x:.0%}")) + top_ax = bottom_ax.secondary_xaxis( + location="top", + functions=(lambda x: 1. - x, lambda x: 1. - x), + ) + top_ax.set_xlabel("assignment to component B") + top_ax.xaxis.set_major_formatter(StrMethodFormatter("{x:.0%}")) + bottom_ax.set_yticks([]) + bottom_ax.grid(axis="x", alpha=0.5, color=USZ["gray"], linestyle=":") + plt.savefig(args.output, bbox_inches="tight", dpi=300) + +# Function to add perpendicular ticks as short lines +def add_perpendicular_ticks(x1, y1, x2, y2, tick_length=0.01): + num_ticks = 6 # Number of ticks including 0% and 100% + for i in range(num_ticks): + t = i / (num_ticks - 1) + x_tick = x1 + t * (x2 - x1) + y_tick = y1 + t * (y2 - y1) + + # Vector along the line + dx = x2 - x1 + dy = y2 - y1 + + # Perpendicular vector + perp_dx = -dy + perp_dy = dx + + # Normalize the perpendicular vector + length = np.sqrt(perp_dx**2 + perp_dy**2) + perp_dx /= length + perp_dy /= length + + # Draw tick as a short perpendicular line + plt.plot([x_tick - tick_length * perp_dx, x_tick + tick_length * perp_dx], [y_tick - tick_length * perp_dy, y_tick + tick_length * perp_dy], color='gray', linewidth=0.8) + plt.text(x_tick, y_tick, f'{int(100 - t * 100)}%', fontsize=8, ha='right', va='bottom') + +def plot_3d_simplex(mixture_df, data, component_names = False): + subsites = list(mixture_df.columns) + colors_ordered = [SUBSITE_COLORS[subsite] for subsite in subsites] + + # Define the plane's normal vector + normal_vector = np.array([1,1,1])/np.sqrt(3) + + v1 = np.array([1,-1,0])/np.sqrt(2) + + # Calculate the second orthogonal vector using the cross product + v2 = np.cross(normal_vector, v1) *-1 + + # Project the point onto the new coordinate system + origin = np.array([0, 1, 0]) + x_origin = origin @ v1 + y_origin = origin @ v2 + + x_vals = mixture_df.T @ v1 - x_origin + y_vals = mixture_df.T @ v2 - y_origin + + extremes = np.array([[1,0,0], + [0,1,0], + [0,0,1]]) + extremes_x = extremes @ v1 - x_origin + extremes_y = extremes @ v2 - y_origin + + # Plot the point in 2D + import matplotlib.pyplot as plt + odered_value_counts = data['tumor']['1']['subsite'].value_counts()[subsites] sizes = odered_value_counts * 3 - # Sizes for each point - ordered_value_counts = data['tumor']['1']['subsite'].value_counts()[subsites] - sizes = np.array(ordered_value_counts) * 1 # Adjust the scaling if necessary - - # Create a 3D plot - fig = plt.figure(figsize=(10, 10)) - ax = fig.add_subplot(111, projection='3d') - - # Extract coordinates from simplex_matrix - x_coords = simplex_matrix[:, 0] - y_coords = simplex_matrix[:, 1] - z_coords = simplex_matrix[:, 2] - - # Scatter plot with specific colors and sizes - scatter = ax.scatter(x_coords, y_coords, z_coords, c=colors_ordered, s=sizes, marker='o') - - # Add text labels for each subsite - for i, subsite in enumerate(subsites): - ax.text(x_coords[i], y_coords[i], z_coords[i], subsite) - x_extremes = [-1,1,0,0] - y_extremes = [-0.5773502691896258,-0.5773502691896258,1.1547005383792517,0] - z_extremes = [-0.5773502691896258,-0.5773502691896258,-0.5773502691896258,1.2247448713915892] - plotter_x = x_extremes.copy() - plotter_x.append(x_extremes[0]) - plotter_x.append(x_extremes[2]) - plotter_y = y_extremes.copy() - plotter_y.append(y_extremes[0]) - plotter_y.append(y_extremes[2]) - plotter_z = z_extremes.copy() - plotter_z.append(z_extremes[0]) - plotter_z.append(z_extremes[2]) - - ax.plot(plotter_x, plotter_y,plotter_z, c='k') - ax.plot([x_extremes[1],x_extremes[3]], [y_extremes[1],y_extremes[3]],[z_extremes[1],z_extremes[3]], c='k') - extremes = np.array((x_extremes,y_extremes,z_extremes)).T - center0 = (extremes[1]+extremes[2]+extremes[3])/3 - center1 = (extremes[0]+extremes[2]+extremes[3])/3 - center2 = (extremes[0]+extremes[1]+extremes[3])/3 - center3 = (extremes[0]+extremes[1]+extremes[2])/3 - - component_larynx = mixture.get_mixture_coefs()['C32.0'].argmax() - component_oropharynx = mixture.get_mixture_coefs()['C01'].argmax() - component_oral_cavity = mixture.get_mixture_coefs()['C03'].argmax() - component_hypopharynx = mixture.get_mixture_coefs()['C13'].argmax() - - ax.text(extremes[component_larynx,0], extremes[component_larynx,1], extremes[component_larynx,2], "Larynx like", fontsize=10, ha='right', va='top',c = usz_orange) - ax.text(extremes[component_oropharynx,0], extremes[component_oropharynx,1], extremes[component_oropharynx,2], "Oropharynx like", fontsize=10, ha='left', va='top',c = usz_green) - ax.text(extremes[component_oral_cavity,0], extremes[component_oral_cavity,1], extremes[component_oral_cavity,2], "Oral cavity like", fontsize=10, ha='left', va='top',c = usz_blue) - ax.text(extremes[component_hypopharynx,0], extremes[component_hypopharynx,1], extremes[component_hypopharynx,2], "Hypopharynx like", fontsize=10, ha='left', va='top',c = usz_red) - - - - add_perpendicular_crosses_3d(ax, extremes[0, 0], extremes[0, 1], extremes[0, 2], center0[0], center0[1], center0[2]) - plt.plot([extremes[0,0], center0[0]], [extremes[0,1], center0[1]],[extremes[0,2],center0[2]], color='gray', linestyle='--', linewidth=1) - add_perpendicular_crosses_3d(ax, extremes[1, 0], extremes[1, 1], extremes[1, 2], center1[0], center1[1], center1[2]) - plt.plot([extremes[1,0], center1[0]], [extremes[1,1], center1[1]],[extremes[1,2],center1[2]], color='gray', linestyle='--', linewidth=1) - add_perpendicular_crosses_3d(ax, extremes[2, 0], extremes[2, 1], extremes[2, 2], center2[0], center2[1], center2[2]) - plt.plot([extremes[2,0], center2[0]], [extremes[2,1], center2[1]],[extremes[2,2],center2[2]], color='gray', linestyle='--', linewidth=1) - add_perpendicular_crosses_3d(ax, extremes[3, 0], extremes[3, 1], extremes[3, 2], center3[0], center3[1], center3[2]) - plt.plot([extremes[3,0], center3[0]], [extremes[3,1], center3[1]],[extremes[3,2],center3[2]], color='gray', linestyle='--', linewidth=1) + fig, ax = plt.subplots(figsize=(8, 6.8)) + + # Plot the points with varying sizes and colors + for i in range(len(x_vals)): + ax.scatter(x_vals[i], y_vals[i], s=sizes[i], color=colors_ordered[i], label=subsites[i]) + ax.text(x_vals[i], y_vals[i], subsites[i], fontsize=9, ha='center', va='center') legend_text = [] for index in range(len(subsites)): legend_text.append(subsites[index] + ', ' + str(odered_value_counts[index]) + ' patients') - legend_elements = [Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=text) - for color, text in zip(colors_ordered, legend_text)] + legend_elements = [Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=subsite) + for color, subsite in zip(colors_ordered, legend_text)] # Add a legend with fixed dot sizes - # plt.legend(handles=legend_elements, loc='upper right', title='Subsites', fontsize='small') + ax.legend(handles=legend_elements, loc='upper right', title='Subsites', fontsize='small') + + # Connect the points + ax.plot(extremes_x, extremes_y, color='black', alpha=0.5) + + # Close the triangle by connecting the last point to the first + ax.plot([extremes_x[-1], extremes_x[0]], [extremes_y[-1], extremes_y[0]], color='black', alpha=0.5) + + # Calculate midpoints of each side of the triangle + midpoints_x = (extremes_x[0] + extremes_x[1]) / 2, (extremes_x[1] + extremes_x[2]) / 2, (extremes_x[2] + extremes_x[0]) / 2 + midpoints_y = (extremes_y[0] + extremes_y[1]) / 2, (extremes_y[1] + extremes_y[2]) / 2, (extremes_y[2] + extremes_y[0]) / 2 + + # Draw lines from each vertex to the midpoint of the opposite side + ax.plot([extremes_x[0], midpoints_x[1]], [extremes_y[0], midpoints_y[1]], color='gray', linestyle='--', linewidth=1) + ax.plot([extremes_x[1], midpoints_x[2]], [extremes_y[1], midpoints_y[2]], color='gray', linestyle='--', linewidth=1) + ax.plot([extremes_x[2], midpoints_x[0]], [extremes_y[2], midpoints_y[0]], color='gray', linestyle='--', linewidth=1) + + # Add perpendicular ticks to each line with adjusted length + add_perpendicular_ticks(extremes_x[0], extremes_y[0], midpoints_x[1], midpoints_y[1], ax=ax) + add_perpendicular_ticks(extremes_x[1], extremes_y[1], midpoints_x[2], midpoints_y[2], ax=ax) + add_perpendicular_ticks(extremes_x[2], extremes_y[2], midpoints_x[0], midpoints_y[0], ax=ax) + + # Scaling factor to move the text farther from the vertices + scaling_factor = 1.1 + + # Calculate the centroid of the triangle + centroid_x = np.mean(extremes_x) + centroid_y = np.mean(extremes_y) + + # Scale the label positions away from the centroid + scaled_extremes_x = centroid_x + scaling_factor * (extremes_x - centroid_x) + scaled_extremes_y = centroid_y + scaling_factor * (extremes_y - centroid_y) + + if component_names: + # Plot the text labels farther away from the triangle + if 'C32.0' in subsites: + component_larynx = mixture_df['C32.0'].argmax() + ax.text(scaled_extremes_x[component_larynx], scaled_extremes_y[component_larynx], "Larynx like", + fontsize=10, ha='right', va='top', c=COLORS['orange']) + + if 'C01' in subsites: + component_oropharynx = mixture_df['C01'].argmax() + ax.text(scaled_extremes_x[component_oropharynx], scaled_extremes_y[component_oropharynx], "Oropharynx like", + fontsize=10, ha='left', va='top', c=COLORS['green']) - plt.gca().set_axis_off() - plt.show() + if 'C03' in subsites: + component_oral_cavity = mixture_df['C03'].argmax() + ax.text(scaled_extremes_x[component_oral_cavity], scaled_extremes_y[component_oral_cavity], "Oral cavity like", + fontsize=10, ha='left', va='top', c=COLORS['blue']) + if 'C13' in subsites: + component_hypopharynx = mixture_df['C13'].argmax() + ax.text(scaled_extremes_x[component_hypopharynx], scaled_extremes_y[component_hypopharynx], "Hypopharynx like", + fontsize=10, ha='left', va='top', c=COLORS['red']) + save_figure(args.output, fig, formats=["png", "svg"]) + logger.info(f"Simplex plot saved") def main(args: argparse.Namespace): mixture_df = pd.read_csv(args.input) @@ -134,7 +265,7 @@ def main(args: argparse.Namespace): params = load_yaml_params(args.params) data = params['general']['data'] if nr_components == 2: - plot_2d_simplex() + plot_2d_simplex(mixture_df, data) elif nr_components == 3: plot_3d_simplex(mixture_df, data) else: diff --git a/lyscripts/plot/utils.py b/lyscripts/plot/utils.py index 915c21e..3d6cee5 100644 --- a/lyscripts/plot/utils.py +++ b/lyscripts/plot/utils.py @@ -11,6 +11,7 @@ import h5py import matplotlib.pyplot as plt +from matplotlib.colors import LinearSegmentedColormap import numpy as np import scipy as sp import math @@ -35,6 +36,31 @@ } COLOR_CYCLE = cycle(COLORS.values()) CM_PER_INCH = 2.54 +blue_to_white = LinearSegmentedColormap.from_list("blue to white", + [COLORS['blue'], "#ffffff"], + N=256) +green_to_white = LinearSegmentedColormap.from_list("green_to_white", + [COLORS['green'], "#ffffff"], + N=256) +red_to_white = LinearSegmentedColormap.from_list("red_to_white", + [COLORS['red'], "#ffffff"], + N=256) +orange_to_white = LinearSegmentedColormap.from_list("orange_to_white", + [COLORS['orange'], "#ffffff"], + N=256) +SUBSITE_COLORS = {'C03': blue_to_white(0), + 'C04':blue_to_white(0.15), + 'C06':blue_to_white(0.3), + 'C02':blue_to_white(0.45), + 'C05':blue_to_white(0.6), + 'C10':green_to_white(0), + 'C09':green_to_white(0.3), + 'C01':green_to_white(0.6), + 'C12':red_to_white(0), + 'C13':red_to_white(0.5), + 'C32.1':orange_to_white(0), + 'C32.2':orange_to_white(0.3), + 'C32.0':orange_to_white(0.6)} def floor_at_decimal(value: float, decimal: int) -> float: From c5d4aaeb0ea95e245e12a6d43853a5014e928d2d Mon Sep 17 00:00:00 2001 From: YoelPH Date: Fri, 25 Oct 2024 09:15:21 +0200 Subject: [PATCH 08/20] change: Updated storing of results and plotting --- lyscripts/mixture_fit.py | 20 ++++++++++++++++---- lyscripts/plot/mixture_plot.py | 4 +++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/lyscripts/mixture_fit.py b/lyscripts/mixture_fit.py index e471cc9..1c2362a 100644 --- a/lyscripts/mixture_fit.py +++ b/lyscripts/mixture_fit.py @@ -73,6 +73,11 @@ def _add_arguments(parser: argparse.ArgumentParser): "-s", "--seed", type=int, default=42, help="Seed value to reproduce the same sampling round." ) + parser.add_argument( + "-sp", "--starting_point", type=bool, default=None, + help="Starting point for optimization if we do not want to start from a random point" + ) + parser.set_defaults(run_main=main) @@ -98,7 +103,7 @@ def check_convergence(params_history, likelihood_history, steps_back_list, absol return False -def run_EM(tolerance): +def run_EM(tolerance, history_dir = None): """Run the EM algorithm to determine the optimal parameters. """ is_converged = False @@ -120,7 +125,13 @@ def run_EM(tolerance): # Append current params and likelihood to history params_history.append(params.copy()) likelihood_history.append(MIXTURE.likelihood(use_complete = False)) - + if history_dir != None: + llh_history = pd.DataFrame(likelihood_history) + llh_history.columns = ['likelihoods'] + llh_history.to_csv(history_dir + '/llh.csv', index=False) + param_history = pd.DataFrame(params_history) + param_history.to_csv(history_dir + '/params.csv', index=False) + MIXTURE.get_mixture_coefs().to_csv(history_dir + '/mixture_coef.csv', index=False) # Check if converged if iteration >= 3: # Ensure enough history is available is_converged = check_convergence(params_history, likelihood_history,list(range(1,look_back_steps+1)),tolerance) @@ -134,6 +145,7 @@ def main(args: argparse.Namespace) -> None: params = load_yaml_params(args.params) inference_data = load_patient_data(args.input) + # ugly, but necessary for pickling global MIXTURE MIXTURE = create_mixture(params) @@ -152,10 +164,10 @@ def main(args: argparse.Namespace) -> None: MIXTURE.set_params(**starting_values) MIXTURE.normalize_mixture_coefs() tolerance = params['model'].get('likelihood_tolerance', 0.01) - params_history, likelihood_history = run_EM(tolerance = tolerance) - history_dir = params['general']['history_dir'] logger.info(f"Saving history to {history_dir}.") + params_history, likelihood_history = run_EM(tolerance = tolerance, history_dir = history_dir) + llh_history = pd.DataFrame(likelihood_history) llh_history.columns = ['likelihoods'] llh_history.to_csv(history_dir + '/llh.csv', index=False) diff --git a/lyscripts/plot/mixture_plot.py b/lyscripts/plot/mixture_plot.py index e25098f..1e92456 100644 --- a/lyscripts/plot/mixture_plot.py +++ b/lyscripts/plot/mixture_plot.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt from matplotlib.colors import LinearSegmentedColormap -from lyscripts.plot.utils import COLORS, save_figure +from lyscripts.plot.utils import COLORS, save_figure, SUBSITE_COLORS logger = logging.getLogger(__name__) @@ -40,6 +40,8 @@ def _add_arguments(parser: argparse.ArgumentParser): def main(args: argparse.Namespace): tmp = LinearSegmentedColormap.from_list("tmp", [COLORS['green'], COLORS['red']], N=128) mixture_df = pd.read_csv(args.input) + filtered_keys = [key for key in SUBSITE_COLORS if key in mixture_df.columns] + mixture_df = mixture_df[filtered_keys] # Transpose the matrix to rotate by 90° matrix_rotated = mixture_df.T From f108a09c64dc0b1cf8132c9466ade244ddc1f493 Mon Sep 17 00:00:00 2001 From: YoelPH Date: Thu, 7 Nov 2024 09:24:26 +0100 Subject: [PATCH 09/20] made simplex plotting functional for 3 components --- lyscripts/plot/__init__.py | 3 ++- lyscripts/plot/__main__.py | 3 ++- lyscripts/plot/simplex_plot.py | 4 ---- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/lyscripts/plot/__init__.py b/lyscripts/plot/__init__.py index e762930..510c287 100644 --- a/lyscripts/plot/__init__.py +++ b/lyscripts/plot/__init__.py @@ -6,7 +6,7 @@ import argparse from pathlib import Path -from lyscripts.plot import corner, histograms, thermo_int, mixture_plot +from lyscripts.plot import corner, histograms, thermo_int, mixture_plot, simplex_plot def _add_parser( @@ -25,4 +25,5 @@ def _add_parser( histograms._add_parser(subparsers, help_formatter=parser.formatter_class) thermo_int._add_parser(subparsers, help_formatter=parser.formatter_class) mixture_plot._add_parser(subparsers, help_formatter=parser.formatter_class) + simplex_plot._add_parser(subparsers, help_formatter=parser.formatter_class) diff --git a/lyscripts/plot/__main__.py b/lyscripts/plot/__main__.py index 8c5cf81..a5c802b 100644 --- a/lyscripts/plot/__main__.py +++ b/lyscripts/plot/__main__.py @@ -1,7 +1,7 @@ import argparse from lyscripts import RichDefaultHelpFormatter, exit_cli -from lyscripts.plot import corner, histograms, thermo_int, mixture_plot +from lyscripts.plot import corner, histograms, thermo_int, mixture_plot, simplex_plot # I need another __main__ guard here, because otherwise pdoc tries to run this if __name__ == "__main__": @@ -19,6 +19,7 @@ histograms._add_parser(subparsers, help_formatter=parser.formatter_class) thermo_int._add_parser(subparsers, help_formatter=parser.formatter_class) mixture_plot._add_parser(subparsers, help_formatter=parser.formatter_class) + simplex_plot._add_parser(subparsers, help_formatter=parser.formatter_class) args = parser.parse_args() args.run_main(args) diff --git a/lyscripts/plot/simplex_plot.py b/lyscripts/plot/simplex_plot.py index 1ac8223..db9f895 100644 --- a/lyscripts/plot/simplex_plot.py +++ b/lyscripts/plot/simplex_plot.py @@ -55,12 +55,8 @@ def _add_arguments(parser: argparse.ArgumentParser): import h5py import matplotlib.pyplot as plt -from tueplots import figsizes, fontsizes from lyscripts.plot.utils import COLORS as USZ -from helpers import generate_location_colors - - def create_parser() -> argparse.ArgumentParser: """Assemble the parser for the command line arguments.""" parser = argparse.ArgumentParser(description=__doc__) From d5f99509e156a08ddf714bda784e3863b822fdf9 Mon Sep 17 00:00:00 2001 From: YoelPH Date: Thu, 7 Nov 2024 09:28:04 +0100 Subject: [PATCH 10/20] small update --- lyscripts/plot/simplex_plot.py | 92 +++++++++++++++++----------------- 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/lyscripts/plot/simplex_plot.py b/lyscripts/plot/simplex_plot.py index db9f895..1d0dc30 100644 --- a/lyscripts/plot/simplex_plot.py +++ b/lyscripts/plot/simplex_plot.py @@ -78,52 +78,52 @@ def create_parser() -> argparse.ArgumentParser: return parser -def plot_2d_simplex(mixture_df, data): - _, bottom_ax = plt.subplots() - subsites = list(mixture_df.columns) - cluster_x = mixture_df.loc[0] - cluster_y = [0. for _ in subsites] - annotations = [f"{label}\n({num})" for label, num in num_patients.items()] - bottom_ax.scatter( - cluster_x, cluster_y, - s=[num for num in num_patients.values()], - c=list(generate_location_colors(subsites)), - alpha=0.7, - linewidths=0., - zorder=10, - ) - - sorted_idx = cluster_components.argsort() - sorted_x = cluster_components[sorted_idx] - sorted_annotations = [annotations[i] for i in sorted_idx] - sorted_num = [list(num_patients.values())[i] for i in sorted_idx] - for i, (x, num, annotation) in enumerate(zip(sorted_x, sorted_num, sorted_annotations)): - bottom_ax.annotate( - annotation, - # sqrt, because marker's area grows linearly with patient num, not radius - xy=(x, np.sqrt(0.0000003 * num) * (- 1)**i), - xytext=(x, 0.025 * (- 1)**i), - ha="center", - va="bottom" if i % 2 == 0 else "top", - fontsize="small", - arrowprops={ - "arrowstyle": "-", - "color": USZ["gray"], - "linewidth": 1., - } - ) - - bottom_ax.set_xlabel("assignment to component A") - bottom_ax.xaxis.set_major_formatter(StrMethodFormatter("{x:.0%}")) - top_ax = bottom_ax.secondary_xaxis( - location="top", - functions=(lambda x: 1. - x, lambda x: 1. - x), - ) - top_ax.set_xlabel("assignment to component B") - top_ax.xaxis.set_major_formatter(StrMethodFormatter("{x:.0%}")) - bottom_ax.set_yticks([]) - bottom_ax.grid(axis="x", alpha=0.5, color=USZ["gray"], linestyle=":") - plt.savefig(args.output, bbox_inches="tight", dpi=300) +# def plot_2d_simplex(mixture_df, data): +# _, bottom_ax = plt.subplots() +# subsites = list(mixture_df.columns) +# cluster_x = mixture_df.loc[0] +# cluster_y = [0. for _ in subsites] +# annotations = [f"{label}\n({num})" for label, num in num_patients.items()] +# bottom_ax.scatter( +# cluster_x, cluster_y, +# s=[num for num in num_patients.values()], +# c=list(generate_location_colors(subsites)), +# alpha=0.7, +# linewidths=0., +# zorder=10, +# ) + +# sorted_idx = cluster_components.argsort() +# sorted_x = cluster_components[sorted_idx] +# sorted_annotations = [annotations[i] for i in sorted_idx] +# sorted_num = [list(num_patients.values())[i] for i in sorted_idx] +# for i, (x, num, annotation) in enumerate(zip(sorted_x, sorted_num, sorted_annotations)): +# bottom_ax.annotate( +# annotation, +# # sqrt, because marker's area grows linearly with patient num, not radius +# xy=(x, np.sqrt(0.0000003 * num) * (- 1)**i), +# xytext=(x, 0.025 * (- 1)**i), +# ha="center", +# va="bottom" if i % 2 == 0 else "top", +# fontsize="small", +# arrowprops={ +# "arrowstyle": "-", +# "color": USZ["gray"], +# "linewidth": 1., +# } +# ) + +# bottom_ax.set_xlabel("assignment to component A") +# bottom_ax.xaxis.set_major_formatter(StrMethodFormatter("{x:.0%}")) +# top_ax = bottom_ax.secondary_xaxis( +# location="top", +# functions=(lambda x: 1. - x, lambda x: 1. - x), +# ) +# top_ax.set_xlabel("assignment to component B") +# top_ax.xaxis.set_major_formatter(StrMethodFormatter("{x:.0%}")) +# bottom_ax.set_yticks([]) +# bottom_ax.grid(axis="x", alpha=0.5, color=USZ["gray"], linestyle=":") +# plt.savefig(args.output, bbox_inches="tight", dpi=300) # Function to add perpendicular ticks as short lines def add_perpendicular_ticks(x1, y1, x2, y2, tick_length=0.01): From 305d6ebace87edd92556d539eb4b6d738b05833d Mon Sep 17 00:00:00 2001 From: YoelPH Date: Tue, 10 Dec 2024 12:03:38 +0100 Subject: [PATCH 11/20] added sampling functions --- lyscripts/mixture_fit.py | 2 - lyscripts/mixture_sample.py | 123 ++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 2 deletions(-) create mode 100644 lyscripts/mixture_sample.py diff --git a/lyscripts/mixture_fit.py b/lyscripts/mixture_fit.py index 1c2362a..55950c5 100644 --- a/lyscripts/mixture_fit.py +++ b/lyscripts/mixture_fit.py @@ -140,8 +140,6 @@ def run_EM(tolerance, history_dir = None): def main(args: argparse.Namespace) -> None: """Main function to run the EM algorithm for a mixture model""" - # as recommended in https://emcee.readthedocs.io/en/stable/tutorials/parallel/# - os.environ["OMP_NUM_THREADS"] = "1" params = load_yaml_params(args.params) inference_data = load_patient_data(args.input) diff --git a/lyscripts/mixture_sample.py b/lyscripts/mixture_sample.py new file mode 100644 index 0000000..778717e --- /dev/null +++ b/lyscripts/mixture_sample.py @@ -0,0 +1,123 @@ +""" +Learn the spread probabilities of the HMM for lymphatic tumor progression using +the preprocessed data as input and the mixture model. +""" +# pylint: disable=logging-fstring-interpolation +import argparse +import logging +import os +from collections import namedtuple + +try: + from multiprocess import Pool +except ModuleNotFoundError: + from multiprocessing import Pool + +from pathlib import Path + +import emcee +import numpy as np +import pandas as pd +from lymph import models +from lymixture import LymphMixture +from lymixture.em import sample_fixed_mixture, sample_model_params, _set_params, _get_params +from rich.progress import Progress, TimeElapsedColumn, track + + +from lyscripts.utils import ( + create_mixture, + load_patient_data, + load_yaml_params, + to_numpy, + assign_modalities +) + +logger = logging.getLogger(__name__) + + +def _add_parser( + subparsers: argparse._SubParsersAction, + help_formatter, +): + """Add an ``ArgumentParser`` to the subparsers action.""" + parser = subparsers.add_parser( + Path(__file__).name.replace(".py", ""), + description=__doc__, + help=__doc__, + formatter_class=help_formatter, + ) + _add_arguments(parser) + + +def _add_arguments(parser: argparse.ArgumentParser): + """Add arguments to a ``subparsers`` instance and run its main function when chosen. + + This is called by the parent module that is called via the command line. + """ + parser.add_argument( + "-m", "--mixture_coefs", type=Path, required=True, + help="File path of mixture coefficients" + ) + parser.add_argument( + "-p", "--params", default="./params.yaml", type=Path, + help="Path to parameter file." + ) + parser.add_argument( + "-mp", "--model_params", type=Path, required = True, + help="File path of mixture coefficients" + ) + parser.add_argument( + "--mode", type=str, default = "fixed_mixture", + help = "Mode of sampling. Use either 'fixed_mixture' or 'fixed_latent'" + ) + parser.add_argument( + "-o", "--output", type=Path, + help="Output path for samples" + ) + + + parser.set_defaults(run_main=main) + + +MIXTURE = None + +def log_prob_fn() -> float: + """log probability function using global variables because of pickling.""" + return MIXTURE.likelihood(use_complete = True, given_resps = MIXTURE.get_resps(norm = True)) + +def main(args: argparse.Namespace) -> None: + """Main function to sample parameters for a mixture model""" + + params = load_yaml_params(args.params) + model_params = pd.read_csv(args.model_params,header = [0]) + mixture_df = pd.read_csv(args.mixtures_coefs) + inference_data = load_patient_data(args.input) + param_dict = dict(model_params.iloc[-1]) + # ugly, but necessary for pickling + global MIXTURE + MIXTURE = create_mixture(params) + + mapping = params["model"].get("mapping", None) + if isinstance(MIXTURE.components[0], models.Unilateral): + side = params["model"].get("side", "ipsi") + MIXTURE.load_patient_data(inference_data, split_by= params["model"].get("split_by", ("tumor", "1", "subsite")), mapping=mapping) + assign_modalities(model=MIXTURE, config=params.get("inference_modalities", {})) + + else: + raise "Only Unilateral has been implemented so far" + + + MIXTURE.set_params(**param_dict) + MIXTURE.set_resps(mixture_df) + if args.mode == "fixed_mixture": + backend, samples = sample_fixed_mixture(MIXTURE, steps = params['sampling'].get('steps'),filename = args.output+"fixed_mixture") + elif args.mode == "fixed_latent": + backend, samples = sample_model_params(MIXTURE, steps = params['sampling'].get('steps'),filename = args.output+"fixed_latent") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + _add_arguments(parser) + + args = parser.parse_args() + args.run_main(args) From e909467e149d8ec61dfe3415166db71a1fa0da47 Mon Sep 17 00:00:00 2001 From: YoelPH Date: Thu, 12 Dec 2024 09:00:07 +0100 Subject: [PATCH 12/20] fix: fixed sampling scripts --- lyscripts/__init__.py | 3 ++- lyscripts/mixture_sample.py | 21 ++++++++++++++------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/lyscripts/__init__.py b/lyscripts/__init__.py index 44f6355..0d2dd98 100644 --- a/lyscripts/__init__.py +++ b/lyscripts/__init__.py @@ -13,7 +13,7 @@ import rich from rich_argparse import RichHelpFormatter -from lyscripts import plot, app, compute, data, evaluate, mixture_fit, temp_schedule +from lyscripts import plot, app, compute, data, evaluate, mixture_fit, temp_schedule, mixture_sample from lyscripts._version import version from lyscripts.utils import CustomRichHandler, console @@ -116,6 +116,7 @@ def main(): plot._add_parser(subparsers, help_formatter=parser.formatter_class) mixture_fit._add_parser(subparsers, help_formatter=parser.formatter_class) temp_schedule._add_parser(subparsers, help_formatter=parser.formatter_class) + mixture_sample._add_parser(subparsers, help_formatter=parser.formatter_class) args = parser.parse_args() diff --git a/lyscripts/mixture_sample.py b/lyscripts/mixture_sample.py index 778717e..c4dbef1 100644 --- a/lyscripts/mixture_sample.py +++ b/lyscripts/mixture_sample.py @@ -20,7 +20,7 @@ import pandas as pd from lymph import models from lymixture import LymphMixture -from lymixture.em import sample_fixed_mixture, sample_model_params, _set_params, _get_params +from lymixture.em import sample_fixed_mixture, sample_model_params, expectation from rich.progress import Progress, TimeElapsedColumn, track @@ -74,6 +74,14 @@ def _add_arguments(parser: argparse.ArgumentParser): "-o", "--output", type=Path, help="Output path for samples" ) + parser.add_argument( + "-d", "--data", type=Path, required=True, + help="Path to the data file." + ) + parser.add_argument( + "-c", "--continue_sampling", type=bool, default = False, + help="Continue sampling from previous run stored in the output backend" + ) parser.set_defaults(run_main=main) @@ -90,8 +98,8 @@ def main(args: argparse.Namespace) -> None: params = load_yaml_params(args.params) model_params = pd.read_csv(args.model_params,header = [0]) - mixture_df = pd.read_csv(args.mixtures_coefs) - inference_data = load_patient_data(args.input) + mixture_df = pd.read_csv(args.mixture_coefs) + inference_data = load_patient_data(args.data) param_dict = dict(model_params.iloc[-1]) # ugly, but necessary for pickling global MIXTURE @@ -106,13 +114,12 @@ def main(args: argparse.Namespace) -> None: else: raise "Only Unilateral has been implemented so far" - MIXTURE.set_params(**param_dict) - MIXTURE.set_resps(mixture_df) + MIXTURE.set_resps(expectation(MIXTURE, param_dict)) if args.mode == "fixed_mixture": - backend, samples = sample_fixed_mixture(MIXTURE, steps = params['sampling'].get('steps'),filename = args.output+"fixed_mixture") + backend, samples = sample_fixed_mixture(MIXTURE, steps = params["sampling"].get("steps",),filename = str(args.output)+"/fixed_mixture.hdf5", continue_sampling = args.continue_sampling) elif args.mode == "fixed_latent": - backend, samples = sample_model_params(MIXTURE, steps = params['sampling'].get('steps'),filename = args.output+"fixed_latent") + backend, samples = sample_model_params(MIXTURE, steps = params["sampling"].get("steps"),filename = str(args.output)+"/fixed_latent.hdf5", continue_sampling = args.continue_sampling) if __name__ == "__main__": From 770be2f1128a13f9dc7ad3943ee4871cfc6e3091 Mon Sep 17 00:00:00 2001 From: YoelPH Date: Tue, 17 Dec 2024 12:06:52 +0100 Subject: [PATCH 13/20] change: added functional sampling and plotting --- lyscripts/mixture_sample.py | 5 - lyscripts/plot/__init__.py | 5 +- lyscripts/plot/__main__.py | 3 +- lyscripts/plot/mixture_sampling_plotter.py | 257 +++++++++++++++++++++ lyscripts/plot/utils.py | 2 + pyproject.toml | 2 +- 6 files changed, 264 insertions(+), 10 deletions(-) create mode 100644 lyscripts/plot/mixture_sampling_plotter.py diff --git a/lyscripts/mixture_sample.py b/lyscripts/mixture_sample.py index c4dbef1..f30041c 100644 --- a/lyscripts/mixture_sample.py +++ b/lyscripts/mixture_sample.py @@ -15,13 +15,9 @@ from pathlib import Path -import emcee -import numpy as np import pandas as pd from lymph import models -from lymixture import LymphMixture from lymixture.em import sample_fixed_mixture, sample_model_params, expectation -from rich.progress import Progress, TimeElapsedColumn, track from lyscripts.utils import ( @@ -98,7 +94,6 @@ def main(args: argparse.Namespace) -> None: params = load_yaml_params(args.params) model_params = pd.read_csv(args.model_params,header = [0]) - mixture_df = pd.read_csv(args.mixture_coefs) inference_data = load_patient_data(args.data) param_dict = dict(model_params.iloc[-1]) # ugly, but necessary for pickling diff --git a/lyscripts/plot/__init__.py b/lyscripts/plot/__init__.py index 510c287..2f9a238 100644 --- a/lyscripts/plot/__init__.py +++ b/lyscripts/plot/__init__.py @@ -6,8 +6,7 @@ import argparse from pathlib import Path -from lyscripts.plot import corner, histograms, thermo_int, mixture_plot, simplex_plot - +from lyscripts.plot import corner, histograms, thermo_int, mixture_plot, simplex_plot, mixture_sampling_plotter def _add_parser( subparsers: argparse._SubParsersAction, @@ -26,4 +25,4 @@ def _add_parser( thermo_int._add_parser(subparsers, help_formatter=parser.formatter_class) mixture_plot._add_parser(subparsers, help_formatter=parser.formatter_class) simplex_plot._add_parser(subparsers, help_formatter=parser.formatter_class) - + mixture_sampling_plotter._add_parser(subparsers, help_formatter=parser.formatter_class) diff --git a/lyscripts/plot/__main__.py b/lyscripts/plot/__main__.py index a5c802b..81c8e26 100644 --- a/lyscripts/plot/__main__.py +++ b/lyscripts/plot/__main__.py @@ -1,7 +1,7 @@ import argparse from lyscripts import RichDefaultHelpFormatter, exit_cli -from lyscripts.plot import corner, histograms, thermo_int, mixture_plot, simplex_plot +from lyscripts.plot import corner, histograms, thermo_int, mixture_plot, simplex_plot, mixture_sampling_plotter # I need another __main__ guard here, because otherwise pdoc tries to run this if __name__ == "__main__": @@ -20,6 +20,7 @@ thermo_int._add_parser(subparsers, help_formatter=parser.formatter_class) mixture_plot._add_parser(subparsers, help_formatter=parser.formatter_class) simplex_plot._add_parser(subparsers, help_formatter=parser.formatter_class) + mixture_sampling_plotter._add_parser(subparsers, help_formatter=parser.formatter_class) args = parser.parse_args() args.run_main(args) diff --git a/lyscripts/plot/mixture_sampling_plotter.py b/lyscripts/plot/mixture_sampling_plotter.py new file mode 100644 index 0000000..aedb5aa --- /dev/null +++ b/lyscripts/plot/mixture_sampling_plotter.py @@ -0,0 +1,257 @@ +import argparse +import logging +from pathlib import Path + +from cycler import cycler +import scipy as sp +import numpy as np +import emcee +from lymph import models +import matplotlib.pyplot as plt +import pandas as pd + +from lyscripts.plot.utils import COLORS, save_figure +from lymixture.em import _set_params, expectation + +from lyscripts.utils import ( + create_mixture, + load_patient_data, + load_yaml_params, + assign_modalities +) + +logger = logging.getLogger(__name__) + +def _add_parser( + subparsers: argparse._SubParsersAction, + help_formatter, +): + """Add an ``ArgumentParser`` to the subparsers action.""" + parser = subparsers.add_parser( + Path(__file__).name.replace(".py", ""), + description=__doc__, + help=__doc__, + formatter_class=help_formatter, + ) + _add_arguments(parser) + + +def _add_arguments(parser: argparse.ArgumentParser): + """Add arguments to the parser.""" + parser.add_argument( + "--input", type=Path, + help="File path with emcee backend of samples" + ) + parser.add_argument( + "--output", type=Path, + help="Output path for the plot" + ) + parser.add_argument( + "-m", "--mode", type=str, default = "fixed_mixture", + help = "Mode of sampling. Use either 'fixed_mixture' or 'fixed_latent'" + ) + parser.add_argument( + "-s", "--size", type=int, default = 200, + help = "Number of samples to be used for plotting" + ) + parser.add_argument( + "-p", "--params", default="./params.yaml", type=Path, + help="Path to parameter file." + ) + parser.add_argument( + "-d", "--data", type=Path, required=True, + help="Path to patient data." + ) + parser.add_argument( + "-mp", "--model_params", type=Path, required = True, + help="File path of mixture coefficients" + ) + + parser.set_defaults(run_main=main) + +def multiple_plotter(dataset, risk_dictionary_extended, subsite, stage = ''): + hist_cycl = ( + cycler(histtype=["stepfilled", "step"]) + * cycler(color=list(COLORS.values())) + ) + line_cycl = ( + cycler(linestyle=["-", "--"]) + * cycler(color=list(COLORS.values())) + ) + dataset_staging = dataset.copy() + dataset_staging['tumor','1','t_stage'] = dataset_staging['tumor','1','t_stage'].replace([0,1,2], 'early') + dataset_staging['tumor','1','t_stage'] = dataset_staging['tumor','1','t_stage'].replace([3,4], 'late') + if stage == 'early' or stage == 'late': + data_selected = dataset_staging.loc[(dataset_staging['tumor']['1']['subsite'] == subsite) & (dataset_staging['tumor']['1']['t_stage'] == stage)] + else: + data_selected = dataset_staging.loc[dataset_staging['tumor']['1']['subsite'] == subsite] + min_value = 0 + max_value = 100 + prevalence = {} + number_of_patients = {} + risks = {} + risk_dictionary = risk_dictionary_extended[subsite] + for key in risk_dictionary.keys(): + prevalence[key] = (data_selected['max_llh']['ipsi'][key] == True).sum() + number_of_patients[key] = len(data_selected) + risks[key] = np.array(risk_dictionary[key])*100 + + num_matches = [prevalence[key] for key in risk_dictionary.keys()] + num_totals = [number_of_patients[key] for key in risk_dictionary.keys()] + values = [risks[key] for key in risk_dictionary.keys()] + hist_kwargs = { + "bins": np.linspace(min_value, max_value, 80), + "density": True, + "alpha": 0.6, + "linewidth": 2., + } + fig, ax = plt.subplots(figsize=(12,4)) + + x = np.linspace(min_value, max_value, 200) + zipper = zip(values, risk_dictionary.keys(), num_matches, num_totals, hist_cycl, line_cycl) + for vals, label, a, n, hstyle, lstyle in zipper: + ax.hist( + vals, + label=label, + **hist_kwargs, + **hstyle + ) + if not np.isnan(a): + post = sp.stats.beta.pdf(x / 100., a+1, n-a+1) / 100. + ax.plot(x, post, label=f"{int(a)}/{int(n)}", **lstyle) + ax.legend() + ax.set_xlabel("probability [%]") + fig.suptitle(f"Risk distributions {stage} for subsite {subsite}",fontsize = 16) + return fig + +def multiple_plotter_component(risk_dictionary_extended, component, stage = ''): + hist_cycl = ( + cycler(histtype=["stepfilled", "step"]) + * cycler(color=list(COLORS.values())) + ) + line_cycl = ( + cycler(linestyle=["-", "--"]) + * cycler(color=list(COLORS.values())) + ) + min_value = 0 + max_value = 100 + prevalence = {} + number_of_patients = {} + risks = {} + risk_dictionary = risk_dictionary_extended[component] + for key in risk_dictionary.keys(): + risks[key] = np.array(risk_dictionary[key])*100 + + values = [risks[key] for key in risk_dictionary.keys()] + hist_kwargs = { + "bins": np.linspace(min_value, max_value, 80), + "density": True, + "alpha": 0.6, + "linewidth": 2., + } + fig, ax = plt.subplots(figsize=(12,4)) + + x = np.linspace(min_value, max_value, 200) + zipper = zip(values, risk_dictionary.keys(), hist_cycl) + for vals, label, hstyle in zipper: + ax.hist( + vals, + label=label, + **hist_kwargs, + **hstyle + ) + ax.legend() + ax.set_xlabel("probability [%]") + fig.suptitle(f"Risk distributions {stage} for component {component}",fontsize = 16) + return fig + + +def main(args: argparse.Namespace): + params = load_yaml_params(args.params) + inference_data = load_patient_data(args.data) + backend = emcee.backends.HDFBackend(args.input) + samples = backend.get_chain(flat = True) + model_params = pd.read_csv(args.model_params,header = [0]) + + # ugly, but necessary for pickling + global MIXTURE + MIXTURE = create_mixture(params) + + mapping = params["model"].get("mapping", None) + if isinstance(MIXTURE.components[0], models.Unilateral): + side = params["model"].get("side", "ipsi") + MIXTURE.load_patient_data(inference_data, split_by= params["model"].get("split_by", ("tumor", "1", "subsite")), mapping=mapping) + assign_modalities(model=MIXTURE, config=params.get("inference_modalities", {})) + + else: + raise "Only Unilateral has been implemented so far" + + param_dict = dict(model_params.iloc[-1]) + MIXTURE.set_params(**param_dict) + MIXTURE.set_resps(expectation(MIXTURE, param_dict)) + + print(MIXTURE.get_params()) + lnls = list(MIXTURE.components[0].graph.lnls.keys()) + component_list = list(range(len(MIXTURE.components))) + component_dictionary_early_full_sampling = {} + component_dictionary_late_full_sampling = {} + + for component in component_list: + component_dictionary_early_full_sampling[str(component)] = { + lnl: [] for lnl in lnls + } + component_dictionary_late_full_sampling[str(component)] = { + lnl: [] for lnl in lnls + } + + subsite_dictionary_early_full_sampling = {} + subsite_dictionary_late_full_sampling = {} + subsite_list = list(MIXTURE.subgroups.keys()) + + for subsite in subsite_list: + subsite_dictionary_early_full_sampling[subsite] = { + lnl: [] for lnl in lnls # Dynamically generate keys from the lnls list + } + subsite_dictionary_late_full_sampling[subsite] = { + lnl: [] for lnl in lnls + } + + involvement_dict = {lnl: {lnl: True} for lnl in lnls} + samples_thinned = samples[::int(np.round(len(samples)/args.size,0))] + for round, sample in enumerate(samples_thinned): + if args.mode == "fixed_latent": + MIXTURE.set_params(*sample) + elif args.mode == "fixed_mixture": + _set_params(MIXTURE, sample) + else: + raise ValueError("Invalid mode") + component_dictionary_early_full_sampling['0']['II'].append(MIXTURE.components[0].risk(involvement = involvement_dict['II'],t_stage = 'early')) + for component in component_list: + for lnl in lnls: + component_dictionary_early_full_sampling[str(component)][lnl].append(MIXTURE.components[component].risk(involvement = involvement_dict[lnl],t_stage = 'early')) + component_dictionary_late_full_sampling[str(component)][lnl].append(MIXTURE.components[component].risk(involvement = involvement_dict[lnl],t_stage = 'late')) + + for subsite in subsite_list: + for lnl in lnls: + subsite_dictionary_early_full_sampling[subsite][lnl].append(MIXTURE.risk(subgroup = subsite, involvement = involvement_dict[lnl],t_stage = 'early')) + subsite_dictionary_late_full_sampling[subsite][lnl].append(MIXTURE.risk(subgroup = subsite, involvement = involvement_dict[lnl],t_stage = 'late')) + print(round, ' done') + print(inference_data) + for component in component_list: + fig = multiple_plotter_component(component_dictionary_early_full_sampling, str(component), stage = 'early') + save_figure(args.output/f"component_{component}_early", fig, formats = ['png','svg']) + fig = multiple_plotter_component(component_dictionary_late_full_sampling, str(component), stage = 'late') + save_figure(args.output/f"component_{component}_late", fig, formats = ['png','svg']) + for subsite in subsite_list: + fig = multiple_plotter(inference_data, subsite_dictionary_early_full_sampling, subsite, stage = 'early') + save_figure(args.output/f"{subsite}_early", fig, formats = ['png','svg']) + fig = multiple_plotter(inference_data,subsite_dictionary_late_full_sampling, subsite, stage = 'late') + save_figure(args.output/f"{subsite}_late", fig, formats = ['png','svg']) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + _add_arguments(parser) + + args = parser.parse_args() + args.run_main(args) + diff --git a/lyscripts/plot/utils.py b/lyscripts/plot/utils.py index 3d6cee5..117ee21 100644 --- a/lyscripts/plot/utils.py +++ b/lyscripts/plot/utils.py @@ -33,6 +33,8 @@ "green": "#00afa5", "red": "#ae0060", "gray": "#c5d5db", + "light_blue": "#00A8D8", + "dark_grey_experimental": "#404756", } COLOR_CYCLE = cycle(COLORS.values()) CM_PER_INCH = 2.54 diff --git a/pyproject.toml b/pyproject.toml index 4b6ad6a..2667905 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ app = [ lyscripts = "lyscripts:main" [tool.setuptools.packages.find] -include = ["lyscripts"] +include = ["lyscripts","lyscripts.*"] [tool.setuptools_scm] write_to = "lyscripts/_version.py" From 6c5c6117b89003dd42f01e2d028bdd864a304412 Mon Sep 17 00:00:00 2001 From: YoelPH Date: Mon, 23 Dec 2024 09:45:36 +0100 Subject: [PATCH 14/20] modifications to match lymixference --- lyscripts/plot/simplex_plot.py | 53 +++++++++------------------------- lyscripts/plot/utils.py | 26 ++++++++++++++++- 2 files changed, 38 insertions(+), 41 deletions(-) diff --git a/lyscripts/plot/simplex_plot.py b/lyscripts/plot/simplex_plot.py index 1d0dc30..61d3a50 100644 --- a/lyscripts/plot/simplex_plot.py +++ b/lyscripts/plot/simplex_plot.py @@ -1,3 +1,8 @@ +""" +Visualize the component assignments of the trained mixture model. +""" + + import argparse import logging import numpy as np @@ -10,6 +15,7 @@ from lyscripts.plot.utils import COLORS, SUBSITE_COLORS, save_figure, p_to_xyz, add_perpendicular_crosses_3d from lyscripts.utils import load_yaml_params +from matplotlib.ticker import StrMethodFormatter logger = logging.getLogger(__name__) @@ -44,40 +50,6 @@ def _add_arguments(parser: argparse.ArgumentParser): parser.set_defaults(run_main=main) -""" -Visualize the component assignments of the trained mixture model. -""" -import argparse -from pathlib import Path -from matplotlib.ticker import StrMethodFormatter -import numpy as np -import yaml - -import h5py -import matplotlib.pyplot as plt -from lyscripts.plot.utils import COLORS as USZ - -def create_parser() -> argparse.ArgumentParser: - """Assemble the parser for the command line arguments.""" - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "-m", "--model", type=Path, default="models/mixture.hdf5", - help=( - "Path to the model HDF5 file. Needs to contain a dataset called " - "``em/cluster_assignments``." - ) - ) - parser.add_argument( - "-o", "--output", type=Path, default="figures/cluster_assignments.png", - help="Path to the output file.", - ) - parser.add_argument( - "-p", "--params", type=Path, default="_variables.yml", - help="Path to the parameter file..", - ) - return parser - - # def plot_2d_simplex(mixture_df, data): # _, bottom_ax = plt.subplots() # subsites = list(mixture_df.columns) @@ -150,7 +122,8 @@ def add_perpendicular_ticks(x1, y1, x2, y2, tick_length=0.01): plt.plot([x_tick - tick_length * perp_dx, x_tick + tick_length * perp_dx], [y_tick - tick_length * perp_dy, y_tick + tick_length * perp_dy], color='gray', linewidth=0.8) plt.text(x_tick, y_tick, f'{int(100 - t * 100)}%', fontsize=8, ha='right', va='bottom') -def plot_3d_simplex(mixture_df, data, component_names = False): +def plot_3d_simplex(mixture_df, data, output, component_names = False): + data = pd.read_csv(data, header=[0, 1, 2]) subsites = list(mixture_df.columns) colors_ordered = [SUBSITE_COLORS[subsite] for subsite in subsites] @@ -215,9 +188,9 @@ def plot_3d_simplex(mixture_df, data, component_names = False): ax.plot([extremes_x[2], midpoints_x[0]], [extremes_y[2], midpoints_y[0]], color='gray', linestyle='--', linewidth=1) # Add perpendicular ticks to each line with adjusted length - add_perpendicular_ticks(extremes_x[0], extremes_y[0], midpoints_x[1], midpoints_y[1], ax=ax) - add_perpendicular_ticks(extremes_x[1], extremes_y[1], midpoints_x[2], midpoints_y[2], ax=ax) - add_perpendicular_ticks(extremes_x[2], extremes_y[2], midpoints_x[0], midpoints_y[0], ax=ax) + add_perpendicular_ticks(extremes_x[0], extremes_y[0], midpoints_x[1], midpoints_y[1]) + add_perpendicular_ticks(extremes_x[1], extremes_y[1], midpoints_x[2], midpoints_y[2]) + add_perpendicular_ticks(extremes_x[2], extremes_y[2], midpoints_x[0], midpoints_y[0]) # Scaling factor to move the text farther from the vertices scaling_factor = 1.1 @@ -252,7 +225,7 @@ def plot_3d_simplex(mixture_df, data, component_names = False): ax.text(scaled_extremes_x[component_hypopharynx], scaled_extremes_y[component_hypopharynx], "Hypopharynx like", fontsize=10, ha='left', va='top', c=COLORS['red']) - save_figure(args.output, fig, formats=["png", "svg"]) + save_figure(output, fig, formats=["png", "svg"]) logger.info(f"Simplex plot saved") def main(args: argparse.Namespace): @@ -263,7 +236,7 @@ def main(args: argparse.Namespace): if nr_components == 2: plot_2d_simplex(mixture_df, data) elif nr_components == 3: - plot_3d_simplex(mixture_df, data) + plot_3d_simplex(mixture_df, data, output = args.output) else: logger.info(f"Simplex not supported for {nr_components} components") diff --git a/lyscripts/plot/utils.py b/lyscripts/plot/utils.py index 117ee21..29e72a2 100644 --- a/lyscripts/plot/utils.py +++ b/lyscripts/plot/utils.py @@ -484,4 +484,28 @@ def add_perpendicular_crosses_3d(ax, x1, y1, z1, x2, y2, z2, tick_length=0.03): [y_tick - perp_vec2[1], y_tick + perp_vec2[1]], [z_tick - perp_vec2[2], z_tick + perp_vec2[2]], color='gray', linewidth=0.8) - ax.text(x_tick, y_tick, z_tick, f'{int(100 - t * 100)}%', fontsize=6, ha='right', va='bottom') \ No newline at end of file + ax.text(x_tick, y_tick, z_tick, f'{int(100 - t * 100)}%', fontsize=6, ha='right', va='bottom') + +def add_perpendicular_ticks(x1, y1, x2, y2, tick_length=0.01): + num_ticks = 6 # Number of ticks including 0% and 100% + for i in range(num_ticks): + t = i / (num_ticks - 1) + x_tick = x1 + t * (x2 - x1) + y_tick = y1 + t * (y2 - y1) + + # Vector along the line + dx = x2 - x1 + dy = y2 - y1 + + # Perpendicular vector + perp_dx = -dy + perp_dy = dx + + # Normalize the perpendicular vector + length = np.sqrt(perp_dx**2 + perp_dy**2) + perp_dx /= length + perp_dy /= length + + # Draw tick as a short perpendicular line + plt.plot([x_tick - tick_length * perp_dx, x_tick + tick_length * perp_dx], [y_tick - tick_length * perp_dy, y_tick + tick_length * perp_dy], color='gray', linewidth=0.8) + plt.text(x_tick, y_tick, f'{int(100 - t * 100)}%', fontsize=8, ha='right', va='bottom') \ No newline at end of file From 13d87c1539e55d9edf7b0d2740fdd11da2fdb35e Mon Sep 17 00:00:00 2001 From: YoelPH Date: Mon, 23 Dec 2024 09:49:06 +0100 Subject: [PATCH 15/20] change! simplex plotter. 2d not functional yet --- lyscripts/plot/simplex_plot.py | 27 +-------------------------- 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/lyscripts/plot/simplex_plot.py b/lyscripts/plot/simplex_plot.py index 61d3a50..6ec977a 100644 --- a/lyscripts/plot/simplex_plot.py +++ b/lyscripts/plot/simplex_plot.py @@ -13,7 +13,7 @@ from matplotlib.lines import Line2D -from lyscripts.plot.utils import COLORS, SUBSITE_COLORS, save_figure, p_to_xyz, add_perpendicular_crosses_3d +from lyscripts.plot.utils import COLORS, SUBSITE_COLORS, save_figure, add_perpendicular_ticks, p_to_xyz, add_perpendicular_crosses_3d from lyscripts.utils import load_yaml_params from matplotlib.ticker import StrMethodFormatter @@ -97,31 +97,6 @@ def _add_arguments(parser: argparse.ArgumentParser): # bottom_ax.grid(axis="x", alpha=0.5, color=USZ["gray"], linestyle=":") # plt.savefig(args.output, bbox_inches="tight", dpi=300) -# Function to add perpendicular ticks as short lines -def add_perpendicular_ticks(x1, y1, x2, y2, tick_length=0.01): - num_ticks = 6 # Number of ticks including 0% and 100% - for i in range(num_ticks): - t = i / (num_ticks - 1) - x_tick = x1 + t * (x2 - x1) - y_tick = y1 + t * (y2 - y1) - - # Vector along the line - dx = x2 - x1 - dy = y2 - y1 - - # Perpendicular vector - perp_dx = -dy - perp_dy = dx - - # Normalize the perpendicular vector - length = np.sqrt(perp_dx**2 + perp_dy**2) - perp_dx /= length - perp_dy /= length - - # Draw tick as a short perpendicular line - plt.plot([x_tick - tick_length * perp_dx, x_tick + tick_length * perp_dx], [y_tick - tick_length * perp_dy, y_tick + tick_length * perp_dy], color='gray', linewidth=0.8) - plt.text(x_tick, y_tick, f'{int(100 - t * 100)}%', fontsize=8, ha='right', va='bottom') - def plot_3d_simplex(mixture_df, data, output, component_names = False): data = pd.read_csv(data, header=[0, 1, 2]) subsites = list(mixture_df.columns) From 5158681b026310bab221ef429744851b3dafee85 Mon Sep 17 00:00:00 2001 From: YoelPH Date: Wed, 5 Mar 2025 13:54:07 +0100 Subject: [PATCH 16/20] change: opens sampling files --- lyscripts/plot/mixture_sampling_plotter.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/lyscripts/plot/mixture_sampling_plotter.py b/lyscripts/plot/mixture_sampling_plotter.py index aedb5aa..5fd7156 100644 --- a/lyscripts/plot/mixture_sampling_plotter.py +++ b/lyscripts/plot/mixture_sampling_plotter.py @@ -9,7 +9,7 @@ from lymph import models import matplotlib.pyplot as plt import pandas as pd - +import json from lyscripts.plot.utils import COLORS, save_figure from lymixture.em import _set_params, expectation @@ -93,7 +93,7 @@ def multiple_plotter(dataset, risk_dictionary_extended, subsite, stage = ''): risk_dictionary = risk_dictionary_extended[subsite] for key in risk_dictionary.keys(): prevalence[key] = (data_selected['max_llh']['ipsi'][key] == True).sum() - number_of_patients[key] = len(data_selected) + number_of_patients[key] = data_selected['max_llh']['ipsi'][key].notna().sum() risks[key] = np.array(risk_dictionary[key])*100 num_matches = [prevalence[key] for key in risk_dictionary.keys()] @@ -236,7 +236,7 @@ def main(args: argparse.Namespace): subsite_dictionary_early_full_sampling[subsite][lnl].append(MIXTURE.risk(subgroup = subsite, involvement = involvement_dict[lnl],t_stage = 'early')) subsite_dictionary_late_full_sampling[subsite][lnl].append(MIXTURE.risk(subgroup = subsite, involvement = involvement_dict[lnl],t_stage = 'late')) print(round, ' done') - print(inference_data) + for component in component_list: fig = multiple_plotter_component(component_dictionary_early_full_sampling, str(component), stage = 'early') save_figure(args.output/f"component_{component}_early", fig, formats = ['png','svg']) @@ -248,6 +248,15 @@ def main(args: argparse.Namespace): fig = multiple_plotter(inference_data,subsite_dictionary_late_full_sampling, subsite, stage = 'late') save_figure(args.output/f"{subsite}_late", fig, formats = ['png','svg']) + # Save dictionary to a JSON file + with open("subsite_early.json", "w") as file: + json.dump(subsite_dictionary_early_full_sampling, file) + with open("subsite_late.json", "w") as file: + json.dump(subsite_dictionary_late_full_sampling, file) + with open("component_early.json", "w") as file: + json.dump(component_dictionary_early_full_sampling, file) + with open("component_late.json", "w") as file: + json.dump(component_dictionary_late_full_sampling, file) if __name__ == "__main__": parser = argparse.ArgumentParser(description=__doc__) _add_arguments(parser) From db47d27f3d20f4c3ef4314b9171b83fcaaeafc01 Mon Sep 17 00:00:00 2001 From: YoelPH Date: Tue, 8 Apr 2025 16:06:57 +0200 Subject: [PATCH 17/20] newest version for mixture sampling --- lyscripts/mixture_fit.py | 4 ++-- lyscripts/mixture_sample.py | 4 ---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/lyscripts/mixture_fit.py b/lyscripts/mixture_fit.py index 55950c5..d82bec6 100644 --- a/lyscripts/mixture_fit.py +++ b/lyscripts/mixture_fit.py @@ -117,8 +117,8 @@ def run_EM(tolerance, history_dir = None): look_back_steps = 3 while not is_converged: - print('iteration',iteration) - print('likelihood', likelihood_history[-1]) + logger.info(f"Iteration: {iteration}") + logger.info(f"Likelihood: {likelihood_history[-1]}") latent = expectation(MIXTURE, params) params = maximization(MIXTURE, latent) diff --git a/lyscripts/mixture_sample.py b/lyscripts/mixture_sample.py index f30041c..11bb9cd 100644 --- a/lyscripts/mixture_sample.py +++ b/lyscripts/mixture_sample.py @@ -85,10 +85,6 @@ def _add_arguments(parser: argparse.ArgumentParser): MIXTURE = None -def log_prob_fn() -> float: - """log probability function using global variables because of pickling.""" - return MIXTURE.likelihood(use_complete = True, given_resps = MIXTURE.get_resps(norm = True)) - def main(args: argparse.Namespace) -> None: """Main function to sample parameters for a mixture model""" From f91a955b0c0520ebf79c3108ed57969b018e23a7 Mon Sep 17 00:00:00 2001 From: YoelPH Date: Thu, 28 Aug 2025 11:03:14 +0200 Subject: [PATCH 18/20] update to work from home --- lyscripts/mixture_fit.py | 108 ++++++++++++++++++++++++++++++--------- 1 file changed, 83 insertions(+), 25 deletions(-) diff --git a/lyscripts/mixture_fit.py b/lyscripts/mixture_fit.py index d82bec6..afa0463 100644 --- a/lyscripts/mixture_fit.py +++ b/lyscripts/mixture_fit.py @@ -6,22 +6,16 @@ import argparse import logging import os -from collections import namedtuple - -try: - from multiprocess import Pool -except ModuleNotFoundError: - from multiprocessing import Pool +import pickle +from concurrent.futures import ProcessPoolExecutor from pathlib import Path -import emcee import numpy as np import pandas as pd from lymph import models -from lymixture import LymphMixture from lymixture.em import expectation, maximization -from rich.progress import Progress, TimeElapsedColumn, track + from lyscripts.utils import ( create_mixture, @@ -57,10 +51,6 @@ def _add_arguments(parser: argparse.ArgumentParser): "-i", "--input", type=Path, required=True, help="Path to training data files" ) - # parser.add_argument( - # "-o", "--output", type=Path, required=True, - # help="Path to the HDF5 file to store the results in" - # ) parser.add_argument( "--history", type=Path, nargs="?", help="Path to store history in (as CSV file)." @@ -73,6 +63,10 @@ def _add_arguments(parser: argparse.ArgumentParser): "-s", "--seed", type=int, default=42, help="Seed value to reproduce the same sampling round." ) + parser.add_argument( + "-m", "--multi_fit", type=bool, default=False, + help="Whether to fit multiple models for later uncertainty evaluation" + ) parser.add_argument( "-sp", "--starting_point", type=bool, default=None, help="Starting point for optimization if we do not want to start from a random point" @@ -138,11 +132,68 @@ def run_EM(tolerance, history_dir = None): iteration += 1 return params_history, likelihood_history + +def process_dataset(dataset, initial_params, folder_path, index = 0, look_back_steps=3): + os.makedirs(folder_path, exist_ok=True) + subpath_optimal_params = 'optimal_params' + os.makedirs(os.path.join(folder_path, subpath_optimal_params), exist_ok=True) + subpath_params_history = 'params_history' + os.makedirs(os.path.join(folder_path, subpath_params_history), exist_ok=True) + subpath_likelihood_history = 'likelihood_history' + os.makedirs(os.path.join(folder_path, subpath_likelihood_history), exist_ok=True) + + logger.info(f"Starting dataset {index}") + mixture = create_mixture(params) + + mixture.load_patient_data( + dataset, + split_by=("tumor", "1", "subsite"), + mapping=lambda x: x, + ) + assign_modalities(model=MIXTURE, config=params.get("inference_modalities", {})) + mixture.set_params(**initial_params[index]) + params = initial_params[index].copy() + + mixture.normalize_mixture_coefs() + params_history = [params.copy()] + likelihood_history = [mixture.likelihood(use_complete=False)] + + is_converged = False + count = 0 + logger.info(f"[Dataset {index}] started") + while not is_converged: + + latent = expectation(mixture, params, log = True) + mixture.set_resps(np.exp(latent)) + params = maximization(mixture, latent) + + params_history.append(params.copy()) + likelihood_history.append(mixture.likelihood(use_complete=False)) + + llh_history = pd.DataFrame(likelihood_history) + llh_history.columns = ['likelihoods'] + llh_history.to_csv(os.path.join(folder_path, subpath_likelihood_history, f"{file_prefix}_likelihood_history.pkl"), index=False) + param_history = pd.DataFrame(params_history) + param_history.to_csv(os.path.join(folder_path, subpath_params_history, f"{file_prefix}_param_history.pkl"), index=False) + if count >= look_back_steps: + is_converged = check_convergence(params_history, likelihood_history, list(range(1, look_back_steps + 1))) + + count += 1 + + logger.info(f"[Dataset {index}] Converged after {count} steps") + file_prefix = f"dataset_{index}" + + with open(os.path.join(folder_path, subpath_optimal_params,f"{file_prefix}_best_params.pkl"), 'wb') as f: + pickle.dump(params_history[-1], f) + + + def main(args: argparse.Namespace) -> None: """Main function to run the EM algorithm for a mixture model""" params = load_yaml_params(args.params) inference_data = load_patient_data(args.input) + multiple_fit = args.multi_fit # ugly, but necessary for pickling global MIXTURE @@ -156,22 +207,29 @@ def main(args: argparse.Namespace) -> None: else: raise "Only Unilateral has been implemented so far" - # emcee does not support numpy's new random number generator yet. - rng = np.random.default_rng(params["em"].get("seed", 42)) - starting_values = {k: rng.uniform() for k in MIXTURE.get_params()} - MIXTURE.set_params(**starting_values) - MIXTURE.normalize_mixture_coefs() - tolerance = params['model'].get('likelihood_tolerance', 0.01) - history_dir = params['general']['history_dir'] - logger.info(f"Saving history to {history_dir}.") - params_history, likelihood_history = run_EM(tolerance = tolerance, history_dir = history_dir) + rng = np.random.default_rng(params["em"].get("seed", 42)) + if args.multi_fit: + with ProcessPoolExecutor(max_workers = 10) as executor: + futures = [ + executor.submit(process_dataset, i, dataset, initial_params, history_dir) + for i, dataset in enumerate(datasets) + ] + else: + starting_values = {k: rng.uniform() for k in MIXTURE.get_params()} + MIXTURE.set_params(**starting_values) + MIXTURE.normalize_mixture_coefs() + tolerance = params['model'].get('likelihood_tolerance', 0.01) + history_dir = params['general']['history_dir'] + logger.info(f"Saving history to {history_dir}.") + params_history, likelihood_history = run_EM(tolerance = tolerance, history_dir = history_dir) + llh_history = pd.DataFrame(likelihood_history) llh_history.columns = ['likelihoods'] - llh_history.to_csv(history_dir + '/llh.csv', index=False) + llh_history.to_csv(history_dir + '/original' + '/llh.csv', index=False) param_history = pd.DataFrame(params_history) - param_history.to_csv(history_dir + '/params.csv', index=False) - MIXTURE.get_mixture_coefs().to_csv(history_dir + '/mixture_coef.csv', index=False) + param_history.to_csv(history_dir + '/original' + '/params.csv', index=False) + MIXTURE.get_mixture_coefs().to_csv(history_dir + '/original' + '/mixture_coef.csv', index=False) if __name__ == "__main__": parser = argparse.ArgumentParser(description=__doc__) From e5bebbb445d3668cf5ffe284bb303d9d2ef9f526 Mon Sep 17 00:00:00 2001 From: YoelPH Date: Fri, 5 Sep 2025 14:07:22 +0200 Subject: [PATCH 19/20] change! updated some scripts for mixture not fully functional yet --- lyscripts/data/__init__.py | 3 +- lyscripts/data/__main__.py | 3 +- lyscripts/data/bootstrap.py | 101 ++++++++++++++++++ lyscripts/mixture_fit.py | 99 +++++++++-------- ...ixture_sample.py => mixture_sample_old.py} | 0 lyscripts/plot/simplex_plot.py | 2 + 6 files changed, 162 insertions(+), 46 deletions(-) create mode 100644 lyscripts/data/bootstrap.py rename lyscripts/{mixture_sample.py => mixture_sample_old.py} (100%) diff --git a/lyscripts/data/__init__.py b/lyscripts/data/__init__.py index 00754db..142105f 100644 --- a/lyscripts/data/__init__.py +++ b/lyscripts/data/__init__.py @@ -10,7 +10,7 @@ import argparse from pathlib import Path -from lyscripts.data import enhance, filter, generate, join, lyproxify, split +from lyscripts.data import enhance, filter, generate, join, lyproxify, split, bootstrap def _add_parser( @@ -31,3 +31,4 @@ def _add_parser( lyproxify._add_parser(subparsers, help_formatter=parser.formatter_class) split._add_parser(subparsers, help_formatter=parser.formatter_class) filter._add_parser(subparsers, help_formatter=parser.formatter_class) + bootstrap._add_parser(subparsers, help_formatter=parser.formatter_class) diff --git a/lyscripts/data/__main__.py b/lyscripts/data/__main__.py index 0c6f67c..7707540 100644 --- a/lyscripts/data/__main__.py +++ b/lyscripts/data/__main__.py @@ -1,7 +1,7 @@ import argparse from lyscripts import RichDefaultHelpFormatter, exit_cli -from lyscripts.data import enhance, filter, generate, join, split +from lyscripts.data import enhance, filter, generate, join, split, bootstrap # I need another __main__ guard here, because otherwise pdoc tries to run this if __name__ == "__main__": @@ -20,6 +20,7 @@ join._add_parser(subparsers, help_formatter=parser.formatter_class) split._add_parser(subparsers, help_formatter=parser.formatter_class) filter._add_parser(subparsers, help_formatter=parser.formatter_class) + bootstrap._add_parser(subparsers, help_formatter=parser.formatter_class) args = parser.parse_args() args.run_main(args) diff --git a/lyscripts/data/bootstrap.py b/lyscripts/data/bootstrap.py new file mode 100644 index 0000000..17871a8 --- /dev/null +++ b/lyscripts/data/bootstrap.py @@ -0,0 +1,101 @@ +""" +Learn the spread probabilities of the HMM for lymphatic tumor progression using +the preprocessed data as input and the mixture model. +""" +# pylint: disable=logging-fstring-interpolation +import argparse +import logging +import os +import numpy as np +from sklearn.utils import resample +from pathlib import Path +import pandas as pd +from lyscripts.utils import ( + load_patient_data, + load_yaml_params, +) + +logger = logging.getLogger(__name__) + + +def _add_parser( + subparsers: argparse._SubParsersAction, + help_formatter, +): + """Add a parser to the ``subparsers`` action.""" + parser = subparsers.add_parser( + Path(__file__).name.replace(".py", ""), + description=__doc__, + help=__doc__, + formatter_class=help_formatter, + ) + _add_arguments(parser) + + +def _add_arguments(parser: argparse.ArgumentParser): + """Add arguments to the parser.""" + parser.add_argument( + "--input", type=Path, + help="Path to a LyProX-style CSV file" + ) + parser.add_argument( + "--output", type=Path, + help="Folder destination to save LyProX-style CSV files" + ) + parser.add_argument( + "-p", "--params", default="params.yaml", type=Path, + help="Path to parameter file" + ) + + parser.set_defaults(run_main=main) + +def proportional_bootstrap(df, group_col, n_bootstraps, folder_path = None): + """Produce n_bootstraps bootstrapped datasets from the original DataFrame. + this keeps the number of patients per subsite constant. + + Args: + df (pd.DataFrame): The original DataFrame to bootstrap from. + group_col (str): The name of the column to group by. + n_bootstraps (int): The number of bootstrapped datasets to create. + + Returns: + list[pd.DataFrame]: A list of bootstrapped DataFrames. + """ + datasets = [] + group_sizes = df[group_col].value_counts(normalize=True) + total_n = len(df) + os.makedirs(folder_path, exist_ok=True) + for _ in range(n_bootstraps): + samples = [] + for group, proportion in group_sizes.items(): + n_samples = int(np.round(proportion * total_n)) + group_df = df[df[group_col] == group] + boot_group = resample(group_df, replace=True, n_samples=n_samples) + samples.append(boot_group) + boot_df = pd.concat(samples).sample(frac=1).reset_index(drop=True) # optional shuffle + datasets.append(boot_df) + if folder_path is not None: + for i, dataset in enumerate(datasets): + file_path = os.path.join(folder_path, f'dataset_resample_{i}.csv') + dataset.to_csv(file_path, index=False) + else: + return datasets + +def main(args: argparse.Namespace) -> None: + """Main function to sample parameters for a mixture model""" + input_table = load_patient_data(args.input) + params = load_yaml_params(args.params) + if ('tumor', 'core', 'subsite') in input_table.columns: + group_col = ('tumor', 'core', 'subsite') + elif ('tumor', '1', 'subsite') in input_table.columns: + group_col = ('tumor', '1', 'subsite') + else: + logger.error("No 'subsite' column found in the input data.") + proportional_bootstrap(input_table, group_col=group_col, n_bootstraps=params["sampling"]["n_bootstraps"], folder_path=args.output) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + _add_arguments(parser) + args = parser.parse_args() + args.run_main(args) diff --git a/lyscripts/mixture_fit.py b/lyscripts/mixture_fit.py index afa0463..01922c3 100644 --- a/lyscripts/mixture_fit.py +++ b/lyscripts/mixture_fit.py @@ -68,7 +68,7 @@ def _add_arguments(parser: argparse.ArgumentParser): help="Whether to fit multiple models for later uncertainty evaluation" ) parser.add_argument( - "-sp", "--starting_point", type=bool, default=None, + "-sp", "--starting_point", type=Path, default=None, help="Starting point for optimization if we do not want to start from a random point" ) @@ -100,6 +100,7 @@ def check_convergence(params_history, likelihood_history, steps_back_list, absol def run_EM(tolerance, history_dir = None): """Run the EM algorithm to determine the optimal parameters. """ + os.makedirs(history_dir, exist_ok=True) is_converged = False iteration = 0 params = MIXTURE.get_params() @@ -113,7 +114,8 @@ def run_EM(tolerance, history_dir = None): while not is_converged: logger.info(f"Iteration: {iteration}") logger.info(f"Likelihood: {likelihood_history[-1]}") - latent = expectation(MIXTURE, params) + latent = expectation(MIXTURE, params, log = True) + MIXTURE.set_resps(np.exp(latent)) params = maximization(MIXTURE, latent) # Append current params and likelihood to history @@ -130,10 +132,12 @@ def run_EM(tolerance, history_dir = None): if iteration >= 3: # Ensure enough history is available is_converged = check_convergence(params_history, likelihood_history,list(range(1,look_back_steps+1)),tolerance) iteration += 1 + df = pd.DataFrame.from_dict(MIXTURE.get_params(), orient='index', columns=['value']) + df.to_csv(history_dir + '/optimal_params.csv') return params_history, likelihood_history -def process_dataset(dataset, initial_params, folder_path, index = 0, look_back_steps=3): +def process_dataset(dataset, folder_path, initial_params, model_build_params, index, look_back_steps=3): os.makedirs(folder_path, exist_ok=True) subpath_optimal_params = 'optimal_params' os.makedirs(os.path.join(folder_path, subpath_optimal_params), exist_ok=True) @@ -142,17 +146,21 @@ def process_dataset(dataset, initial_params, folder_path, index = 0, look_back_s subpath_likelihood_history = 'likelihood_history' os.makedirs(os.path.join(folder_path, subpath_likelihood_history), exist_ok=True) + logger.info(f"Starting dataset {index}") - mixture = create_mixture(params) + mixture = create_mixture(model_build_params) + mapping = model_build_params["model"].get("mapping", None) + if isinstance(mixture.components[0], models.Unilateral): + mixture.load_patient_data(dataset, split_by= model_build_params["model"].get("split_by", ("tumor", "1", "subsite")), mapping=mapping) + assign_modalities(model=mixture, config=model_build_params.get("inference_modalities", {})) + else: + raise ValueError("Only Unilateral has been implemented so far") - mixture.load_patient_data( - dataset, - split_by=("tumor", "1", "subsite"), - mapping=lambda x: x, - ) - assign_modalities(model=MIXTURE, config=params.get("inference_modalities", {})) - mixture.set_params(**initial_params[index]) - params = initial_params[index].copy() + mixture.set_params(**initial_params) + mixture.normalize_mixture_coefs() + tolerance = model_build_params['model'].get('likelihood_tolerance', 0.01) + mixture.set_params(**initial_params) + params = initial_params.copy() mixture.normalize_mixture_coefs() params_history = [params.copy()] @@ -161,6 +169,8 @@ def process_dataset(dataset, initial_params, folder_path, index = 0, look_back_s is_converged = False count = 0 logger.info(f"[Dataset {index}] started") + file_prefix = f"dataset_{index}" + while not is_converged: latent = expectation(mixture, params, log = True) @@ -172,64 +182,65 @@ def process_dataset(dataset, initial_params, folder_path, index = 0, look_back_s llh_history = pd.DataFrame(likelihood_history) llh_history.columns = ['likelihoods'] - llh_history.to_csv(os.path.join(folder_path, subpath_likelihood_history, f"{file_prefix}_likelihood_history.pkl"), index=False) + llh_history.to_csv(os.path.join(folder_path, subpath_likelihood_history, f"{file_prefix}_likelihood_history.csv"), index=False) param_history = pd.DataFrame(params_history) - param_history.to_csv(os.path.join(folder_path, subpath_params_history, f"{file_prefix}_param_history.pkl"), index=False) + param_history.to_csv(os.path.join(folder_path, subpath_params_history, f"{file_prefix}_param_history.csv"), index=False) if count >= look_back_steps: - is_converged = check_convergence(params_history, likelihood_history, list(range(1, look_back_steps + 1))) + is_converged = check_convergence(params_history, likelihood_history, list(range(1, look_back_steps + 1)), tolerance) count += 1 logger.info(f"[Dataset {index}] Converged after {count} steps") - file_prefix = f"dataset_{index}" - - with open(os.path.join(folder_path, subpath_optimal_params,f"{file_prefix}_best_params.pkl"), 'wb') as f: - pickle.dump(params_history[-1], f) - + df = pd.DataFrame.from_dict(mixture.get_params(), orient='index', columns=['value']) + df.to_csv(os.path.join(folder_path, subpath_optimal_params, f"{file_prefix}_optimal_params.csv")) def main(args: argparse.Namespace) -> None: """Main function to run the EM algorithm for a mixture model""" params = load_yaml_params(args.params) - inference_data = load_patient_data(args.input) - multiple_fit = args.multi_fit - - # ugly, but necessary for pickling global MIXTURE MIXTURE = create_mixture(params) - mapping = params["model"].get("mapping", None) - if isinstance(MIXTURE.components[0], models.Unilateral): - side = params["model"].get("side", "ipsi") - MIXTURE.load_patient_data(inference_data, split_by= params["model"].get("split_by", ("tumor", "1", "subsite")), mapping=mapping) - assign_modalities(model=MIXTURE, config=params.get("inference_modalities", {})) - + if args.starting_point is None: + rng = np.random.default_rng(params["em"].get("seed", 42)) + starting_values = {k: rng.uniform() for k in MIXTURE.get_params()} else: - raise "Only Unilateral has been implemented so far" - - rng = np.random.default_rng(params["em"].get("seed", 42)) + logger.info(f"Using starting point from {args.starting_point}") + starting_df = pd.read_csv(args.starting_point, index_col=0) # Use first column as index + starting_values = starting_df['value'].to_dict() + if args.multi_fit: - with ProcessPoolExecutor(max_workers = 10) as executor: + datasets = [] + history_dir = params['sampling']['output_path'] + os.makedirs(history_dir, exist_ok=True) + for i in range(params['sampling']['n_bootstraps']): + file_path = os.path.join(args.input, f"dataset_resample_{i}.csv") + if os.path.exists(file_path): + loaded_dataset = pd.read_csv(file_path, header=[0, 1, 2]) + datasets.append(loaded_dataset) + with ProcessPoolExecutor(max(1, os.cpu_count() - 2)) as executor: futures = [ - executor.submit(process_dataset, i, dataset, initial_params, history_dir) + executor.submit(process_dataset, dataset, history_dir, starting_values, params, i) for i, dataset in enumerate(datasets) ] else: - starting_values = {k: rng.uniform() for k in MIXTURE.get_params()} + inference_data = load_patient_data(args.input) + + mapping = params["model"].get("mapping", None) + if isinstance(MIXTURE.components[0], models.Unilateral): + MIXTURE.load_patient_data(inference_data, split_by= params["model"].get("split_by", ("tumor", "1", "subsite")), mapping=mapping) + assign_modalities(model=MIXTURE, config=params.get("inference_modalities", {})) + + else: + raise ValueError("Only Unilateral has been implemented so far") + MIXTURE.set_params(**starting_values) MIXTURE.normalize_mixture_coefs() tolerance = params['model'].get('likelihood_tolerance', 0.01) - history_dir = params['general']['history_dir'] + history_dir = params['fitting']['folder_path'] logger.info(f"Saving history to {history_dir}.") params_history, likelihood_history = run_EM(tolerance = tolerance, history_dir = history_dir) - - llh_history = pd.DataFrame(likelihood_history) - llh_history.columns = ['likelihoods'] - llh_history.to_csv(history_dir + '/original' + '/llh.csv', index=False) - param_history = pd.DataFrame(params_history) - param_history.to_csv(history_dir + '/original' + '/params.csv', index=False) - MIXTURE.get_mixture_coefs().to_csv(history_dir + '/original' + '/mixture_coef.csv', index=False) if __name__ == "__main__": parser = argparse.ArgumentParser(description=__doc__) diff --git a/lyscripts/mixture_sample.py b/lyscripts/mixture_sample_old.py similarity index 100% rename from lyscripts/mixture_sample.py rename to lyscripts/mixture_sample_old.py diff --git a/lyscripts/plot/simplex_plot.py b/lyscripts/plot/simplex_plot.py index 6ec977a..3db6a91 100644 --- a/lyscripts/plot/simplex_plot.py +++ b/lyscripts/plot/simplex_plot.py @@ -97,6 +97,8 @@ def _add_arguments(parser: argparse.ArgumentParser): # bottom_ax.grid(axis="x", alpha=0.5, color=USZ["gray"], linestyle=":") # plt.savefig(args.output, bbox_inches="tight", dpi=300) +def plot_2d_simplex(mixture_df, data): + def plot_3d_simplex(mixture_df, data, output, component_names = False): data = pd.read_csv(data, header=[0, 1, 2]) subsites = list(mixture_df.columns) From 60a20a4d3bad97f685263b76742c70d70e149188 Mon Sep 17 00:00:00 2001 From: YoelPH Date: Mon, 8 Sep 2025 11:46:07 +0200 Subject: [PATCH 20/20] fix: scripts are now funcitonal + ran pre commit --- docs/source/conf.py | 44 +-- lyscripts/__init__.py | 34 ++- lyscripts/__main__.py | 4 +- lyscripts/app/__init__.py | 4 +- lyscripts/app/prevalence.py | 31 +- lyscripts/compute/__init__.py | 4 +- lyscripts/compute/prevalences.py | 54 ++-- lyscripts/compute/priors.py | 40 ++- lyscripts/compute/risks.py | 54 ++-- lyscripts/compute/utils.py | 11 +- lyscripts/data/__init__.py | 6 +- lyscripts/data/__main__.py | 2 +- lyscripts/data/accessor.py | 5 +- lyscripts/data/bootstrap.py | 53 ++-- lyscripts/data/enhance.py | 134 ++++----- lyscripts/data/filter.py | 17 +- lyscripts/data/generate.py | 40 +-- lyscripts/data/join.py | 24 +- lyscripts/data/lyproxify.py | 65 ++-- lyscripts/data/split.py | 34 +-- lyscripts/data/utils.py | 5 +- lyscripts/decorators.py | 8 +- lyscripts/evaluate.py | 56 ++-- lyscripts/mixture_fit.py | 234 +++++++++------ lyscripts/mixture_sample_old.py | 85 ++++-- lyscripts/plot/__init__.py | 22 +- lyscripts/plot/__main__.py | 17 +- lyscripts/plot/corner.py | 21 +- lyscripts/plot/histograms.py | 59 ++-- lyscripts/plot/mixture_comp_uncertainty.py | 177 +++++++++++ lyscripts/plot/mixture_plot.py | 43 +-- lyscripts/plot/mixture_sampling_plotter.py | 270 ++++++++++------- lyscripts/plot/simplex_plot.py | 328 +++++++++++++-------- lyscripts/plot/thermo_int.py | 55 ++-- lyscripts/plot/utils.py | 148 ++++++---- lyscripts/sample.py | 102 ++++--- lyscripts/scenario.py | 97 +++--- lyscripts/temp_schedule.py | 42 +-- lyscripts/utils.py | 70 +++-- tests/data/join_test.py | 16 +- tests/plot/plot_utils_test.py | 237 ++++++++------- tests/predict/predict_utils_test.py | 14 +- tests/predict/prevalences_test.py | 34 +-- tests/run_doctests.py | 5 +- tests/sample_test.py | 19 +- tests/utils_test.py | 13 +- 46 files changed, 1737 insertions(+), 1100 deletions(-) create mode 100644 lyscripts/plot/mixture_comp_uncertainty.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 2486799..ded3bcd 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -7,10 +7,10 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = 'lyscripts' -copyright = '2024, Roman Ludwig' -author = 'Roman Ludwig' -gh_username = 'rmnldwg' +project = "lyscripts" +copyright = "2024, Roman Ludwig" +author = "Roman Ludwig" +gh_username = "rmnldwg" version = lyscripts.__version__ release = lyscripts.__version__ @@ -18,36 +18,36 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration extensions = [ - 'sphinx.ext.intersphinx', - 'sphinx.ext.autodoc', - 'sphinx.ext.mathjax', - 'sphinx.ext.viewcode', - 'sphinxcontrib.programoutput', - 'sphinx.ext.napoleon', - 'm2r', + "sphinx.ext.intersphinx", + "sphinx.ext.autodoc", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", + "sphinxcontrib.programoutput", + "sphinx.ext.napoleon", + "m2r", ] # markdown to reST -source_suffix = ['.rst', '.md'] +source_suffix = [".rst", ".md"] -templates_path = ['_templates'] +templates_path = ["_templates"] exclude_patterns = [] # document classes and their constructors -autoclass_content = 'class' +autoclass_content = "class" # sort members by source -autodoc_member_order = 'bysource' +autodoc_member_order = "bysource" # show type hints -autodoc_typehints = 'signature' +autodoc_typehints = "signature" # create links to other projects intersphinx_mapping = { - 'python': ('https://docs.python.org/3.10', None), - 'lymph': ('https://lymph-model.readthedocs.io/en/latest/', None), - 'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None), - 'numpy': ('https://numpy.org/doc/stable/', None), + "python": ("https://docs.python.org/3.10", None), + "lymph": ("https://lymph-model.readthedocs.io/en/latest/", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), + "numpy": ("https://numpy.org/doc/stable/", None), } @@ -57,7 +57,7 @@ # a list of builtin themes. # -html_theme = 'sphinx_book_theme' +html_theme = "sphinx_book_theme" html_theme_options = { "repository_url": f"https://github.com/{gh_username}/{project}", "repository_branch": "main", @@ -76,7 +76,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['./_static'] +html_static_path = ["./_static"] html_css_files = [ "css/custom.css", ] diff --git a/lyscripts/__init__.py b/lyscripts/__init__.py index 0d2dd98..53e9276 100644 --- a/lyscripts/__init__.py +++ b/lyscripts/__init__.py @@ -1,10 +1,10 @@ -""" -This is the top-level module of the `lyscripts` package. It contains the +"""This is the top-level module of the `lyscripts` package. It contains the :py:func:`.main` function that is used to start the command line interface (CLI) for the package. Also, it configures the logging system and sets the metadate of the package. """ + import argparse import logging import re @@ -13,7 +13,16 @@ import rich from rich_argparse import RichHelpFormatter -from lyscripts import plot, app, compute, data, evaluate, mixture_fit, temp_schedule, mixture_sample +from lyscripts import ( + app, + compute, + data, + evaluate, + mixture_fit, + mixture_sample, + plot, + temp_schedule, +) from lyscripts._version import version from lyscripts.utils import CustomRichHandler, console @@ -41,6 +50,7 @@ class RichDefaultHelpFormatter( .. _rich: https://rich.readthedocs.io/en/stable/introduction.html """ + def _rich_fill_text( self, text: rich.text.Text, @@ -67,17 +77,11 @@ def _rich_fill_text( RichDefaultHelpFormatter.styles["argparse.syntax"] = "red" RichDefaultHelpFormatter.styles["argparse.formula"] = "green" -RichDefaultHelpFormatter.highlights.append( - r"\$(?P[^$]*)\$" -) +RichDefaultHelpFormatter.highlights.append(r"\$(?P[^$]*)\$") RichDefaultHelpFormatter.styles["argparse.bold"] = "bold" -RichDefaultHelpFormatter.highlights.append( - r"\*(?P[^*]*)\*" -) +RichDefaultHelpFormatter.highlights.append(r"\*(?P[^*]*)\*") RichDefaultHelpFormatter.styles["argparse.italic"] = "italic" -RichDefaultHelpFormatter.highlights.append( - r"_(?P[^_]*)_" -) +RichDefaultHelpFormatter.highlights.append(r"_(?P[^_]*)_") def exit_cli(args: argparse.Namespace): @@ -97,11 +101,11 @@ def main(): ) parser.set_defaults(run_main=exit_cli) parser.add_argument( - "-v", "--version", action="store_true", - help="Display the version of lyscripts" + "-v", "--version", action="store_true", help="Display the version of lyscripts" ) parser.add_argument( - "--log-level", default="INFO", + "--log-level", + default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], ) diff --git a/lyscripts/__main__.py b/lyscripts/__main__.py index d9ba03c..1bc2e10 100644 --- a/lyscripts/__main__.py +++ b/lyscripts/__main__.py @@ -1,7 +1,7 @@ -""" -Utility for performing common tasks w.r.t. the inference and prediction tasks one +"""Utility for performing common tasks w.r.t. the inference and prediction tasks one can use the `lymph` package for. """ + from lyscripts import main # I need another __main__ guard here, because otherwise pdoc tries to run this diff --git a/lyscripts/app/__init__.py b/lyscripts/app/__init__.py index 7c3db75..3ed505e 100644 --- a/lyscripts/app/__init__.py +++ b/lyscripts/app/__init__.py @@ -1,8 +1,8 @@ -""" -Module containing scripts to run different `streamlit`_ applications. +"""Module containing scripts to run different `streamlit`_ applications. .. _streamlit: https://streamlit.io/ """ + import argparse from pathlib import Path diff --git a/lyscripts/app/prevalence.py b/lyscripts/app/prevalence.py index 2ee0c4d..83a9b0c 100644 --- a/lyscripts/app/prevalence.py +++ b/lyscripts/app/prevalence.py @@ -1,5 +1,4 @@ -""" -A `streamlit`_ app for computing, displaying, and reproducing prevalence estimates. +"""A `streamlit`_ app for computing, displaying, and reproducing prevalence estimates. The primary goal with this little GUI is that one can quickly draft some data & prediction comparisons visually and then copy & paste the configuration in YAML format @@ -7,6 +6,7 @@ .. _streamlit: https://streamlit.io/ """ + import argparse import sys from pathlib import Path @@ -51,10 +51,7 @@ def _add_arguments(parser: argparse.ArgumentParser): .. _streamlit: https://streamlit.io/ """ - parser.add_argument( - "--message", type=str, - help="Print our this little message." - ) + parser.add_argument("--message", type=str, help="Print our this little message.") parser.set_defaults(run_main=launch_streamlit) @@ -122,7 +119,7 @@ def interactive_load(streamlit): type=["csv"], help="CSV spreadsheet containing lymphatic patterns of progression", ) - header_rows = [0,1] if is_unilateral else [0,1,2] + header_rows = [0, 1] if is_unilateral else [0, 1, 2] patient_data = load_patient_data(data_file, header=header_rows) streamlit.write("---") @@ -130,7 +127,7 @@ def interactive_load(streamlit): samples_file = streamlit.file_uploader( label="HDF5 sample file", type=["hdf5", "hdf", "h5"], - help="HDF5 file containing the samples." + help="HDF5 file containing the samples.", ) samples = load_model_samples(samples_file) @@ -138,10 +135,7 @@ def interactive_load(streamlit): def interactive_pattern( - streamlit, - is_unilateral: bool, - lnls: list[str], - side: str + streamlit, is_unilateral: bool, lnls: list[str], side: str ) -> dict[str, bool]: """Create a `streamlit`_ panel for specifying an involvement pattern. @@ -179,7 +173,7 @@ def interactive_additional_params( The respective controls are presented next to each other in three dedicated columns. """ - control_cols = streamlit.columns([1,2,1,1,1]) + control_cols = streamlit.columns([1, 2, 1, 1, 1]) t_stage = control_cols[0].selectbox( label="T-category", options=model.diag_time_dists.keys(), @@ -270,15 +264,16 @@ def add_current_scenario( session_state["contents"].append(beta_posterior) session_state["contents"].append(histogram) - session_state["scenarios"].append({ - "pattern": reduce_pattern(pattern), **prevs_kwargs - }) + session_state["scenarios"].append( + {"pattern": reduce_pattern(pattern), **prevs_kwargs} + ) def main(args: argparse.Namespace): """The main function that contains the `streamlit`_ code and functionality. - .. _streamlit: https://streamlit.io/""" + .. _streamlit: https://streamlit.io/ + """ import streamlit as st st.title("Prevalence") @@ -335,7 +330,7 @@ def main(args: argparse.Namespace): ) fig, ax = plt.subplots() - draw(axes=ax, contents=st.session_state.get("contents", []), xlims=(0., 100.)) + draw(axes=ax, contents=st.session_state.get("contents", []), xlims=(0.0, 100.0)) ax.legend() st.pyplot(fig) diff --git a/lyscripts/compute/__init__.py b/lyscripts/compute/__init__.py index ccc12e9..29a678f 100644 --- a/lyscripts/compute/__init__.py +++ b/lyscripts/compute/__init__.py @@ -1,8 +1,8 @@ -""" -With the commands of this module, a user may compute prior and posterior state +"""With the commands of this module, a user may compute prior and posterior state distributions from drawn samples of a model. This can in turn speed up the computation of risks and prevalences. """ + import argparse from pathlib import Path diff --git a/lyscripts/compute/prevalences.py b/lyscripts/compute/prevalences.py index 913b26c..c6179ac 100644 --- a/lyscripts/compute/prevalences.py +++ b/lyscripts/compute/prevalences.py @@ -15,10 +15,13 @@ observed involvement pattern. Warning: +------- The command skips the computation of the priors if it finds them in the cache. But this cache only accounts for the scenario, *NOT* the samples. So, if the samples change, you need to force a recomputation of the priors (e.g., by deleting them). + """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -35,7 +38,6 @@ from lyscripts import utils from lyscripts.compute.priors import compute_priors_using_cache from lyscripts.compute.utils import HDF5FileCache, get_modality_subset -from lyscripts.data import accessor # nopycln: import from lyscripts.scenario import Scenario, add_scenario_arguments logger = logging.getLogger(__name__) @@ -58,41 +60,45 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments needed to run this script to a ``subparsers`` instance.""" parser.add_argument( - "--priors", type=Path, required=True, + "--priors", + type=Path, + required=True, help=( "Path to the prior state distributions (HDF5 file). If samples are " "provided, this will be used as output to store the computed posteriors. " "If no samples are provided, this will be used as input to load the priors." - ) + ), ) parser.add_argument( - "--prevalences", type=Path, required=True, - help="Path to the HDF5 file for storing the computed prevalences." + "--prevalences", + type=Path, + required=True, + help="Path to the HDF5 file for storing the computed prevalences.", ) parser.add_argument( - "--data", type=Path, required=False, - help="Path to the patient data (CSV file)." + "--data", type=Path, required=False, help="Path to the patient data (CSV file)." ) parser.add_argument( - "--params", default="./params.yaml", type=Path, - help="Path to parameter file defining the model (YAML)." + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file defining the model (YAML).", ) parser.add_argument( - "--scenarios", type=Path, required=False, + "--scenarios", + type=Path, + required=False, help=( "Path to a YAML file containing a `scenarios` key with a list of " "involvement scenarios to compute the posteriors for." - ) + ), ) add_scenario_arguments(parser, for_comp="prevalences") parser.set_defaults(run_main=main) -def does_midext_match( - data: pd.DataFrame, - midext: bool | None = None -) -> pd.Index: +def does_midext_match(data: pd.DataFrame, midext: bool | None = None) -> pd.Index: """Return indices of ``data`` where ``midline_ext`` of the patients matches.""" midext_col = data["tumor", "1", "extension"] if midext is None: @@ -112,8 +118,10 @@ def compute_observed_prevalence( T-stages defined in the ``scenario``. Warning: + ------- When computing prevalences for unilateral models, the contralateral diagnosis will still be considered for computing the prevalence in the *data*. + """ # when looking at the data, we always consider both sides is_uni = scenario.is_uni @@ -147,7 +155,7 @@ def observe_prevalence_using_cache( data: pd.DataFrame, scenario: Scenario, cache: HDF5FileCache, - mapping: dict[int, Any] | Callable[[int], Any] = None + mapping: dict[int, Any] | Callable[[int], Any] = None, ): """Compute and cache the observed prevalence for a given ``scenario``.""" num_match, num_total = compute_observed_prevalence( @@ -239,12 +247,14 @@ def main(args: argparse.Namespace): if args.scenarios is None: # create a single scenario from the stdin arguments... - scenarios = [Scenario.from_namespace( - namespace=args, - lnls=lnls, - is_uni=isinstance(model, models.Unilateral), - side=params["model"].get("side", "ipsi"), - )] + scenarios = [ + Scenario.from_namespace( + namespace=args, + lnls=lnls, + is_uni=isinstance(model, models.Unilateral), + side=params["model"].get("side", "ipsi"), + ) + ] num_scens = len(scenarios) else: # ...or load the scenarios from a YAML file diff --git a/lyscripts/compute/priors.py b/lyscripts/compute/priors.py index 4f8283e..59e98be 100644 --- a/lyscripts/compute/priors.py +++ b/lyscripts/compute/priors.py @@ -1,5 +1,4 @@ -""" -Given samples drawn during an MCMC round, compute the (prior) state distribution for +"""Given samples drawn during an MCMC round, compute the (prior) state distribution for each sample. This may then later on be used to compute risks and prevalences more quickly. @@ -8,6 +7,7 @@ distribution that was used to marginalize over them, as well as the model's computation mode (hidden Markov model or Bayesian network). """ + import argparse import logging from pathlib import Path @@ -40,23 +40,31 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments needed to run this script to a ``subparsers`` instance.""" parser.add_argument( - "--samples", type=Path, required=True, - help="Path to the drawn samples (HDF5 file)." + "--samples", + type=Path, + required=True, + help="Path to the drawn samples (HDF5 file).", ) parser.add_argument( - "--priors", type=Path, required=True, - help="Path to file for storing the computed prior distributions." + "--priors", + type=Path, + required=True, + help="Path to file for storing the computed prior distributions.", ) parser.add_argument( - "--params", type=Path, required=True, - help="Path to parameter file defining the model (YAML)." + "--params", + type=Path, + required=True, + help="Path to parameter file defining the model (YAML).", ) parser.add_argument( - "--scenarios", type=Path, required=False, + "--scenarios", + type=Path, + required=False, help=( "Path to a YAML file containing a `scenarios` key with a list of " "diagnosis scenarios to compute the posteriors for." - ) + ), ) add_scenario_arguments(parser, for_comp="priors") @@ -98,10 +106,12 @@ def compute_priors_using_cache( total=len(samples), ): model.set_params(*sample) - priors.append(sum( - model.state_dist(t_stage=t, mode=scenario.mode) * p - for t, p in zip(scenario.t_stages, scenario.t_stages_dist) - )) + priors.append( + sum( + model.state_dist(t_stage=t, mode=scenario.mode) * p + for t, p in zip(scenario.t_stages, scenario.t_stages_dist, strict=False) + ) + ) priors = np.stack(priors) cache[priors_hash] = (priors, scenario.as_dict("priors")) @@ -109,7 +119,7 @@ def compute_priors_using_cache( def main(args: argparse.Namespace): - """compute the prior state distribution for each sample.""" + """Compute the prior state distribution for each sample.""" params = utils.load_yaml_params(args.params) if args.scenarios is None: diff --git a/lyscripts/compute/risks.py b/lyscripts/compute/risks.py index 4e6cf54..d5b4cce 100644 --- a/lyscripts/compute/risks.py +++ b/lyscripts/compute/risks.py @@ -1,10 +1,10 @@ -""" -Predict risks of involvements using the posteriors that were computed using the +"""Predict risks of involvements using the posteriors that were computed using the :py:mod:`.compute.posteriors` command. The structure of these scenarios is similar to how scenarios are defined for the :py:mod:`.compute.prevalences` script. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -39,24 +39,32 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments needed to run this script to a `subparsers` instance.""" parser.add_argument( - "--posteriors", type=Path, required=True, - help="Path to the computed posteriors (HDF5 file)." + "--posteriors", + type=Path, + required=True, + help="Path to the computed posteriors (HDF5 file).", ) parser.add_argument( - "--risks", type=Path, required=True, - help="Path to file for storing the computed risks." + "--risks", + type=Path, + required=True, + help="Path to file for storing the computed risks.", ) parser.add_argument( - "--params", type=Path, required=True, - help="Path to parameter file defining the model (YAML)." + "--params", + type=Path, + required=True, + help="Path to parameter file defining the model (YAML).", ) parser.add_argument( - "--scenarios", type=Path, required=False, + "--scenarios", + type=Path, + required=False, help=( "Path to a YAML file containing a `scenarios` key with a list of " "diagnosis scenarios and involvement patterns of interest to compute the " "risks for." - ) + ), ) add_scenario_arguments(parser, for_comp="risks") @@ -100,11 +108,13 @@ def compute_risks_using_cache( description="[blue]INFO [/blue]" + progress_desc, total=len(posteriors), ): - risks.append(model.marginalize( - involvement=scenario.involvement, - given_state_dist=posterior, - **kwargs, - )) + risks.append( + model.marginalize( + involvement=scenario.involvement, + given_state_dist=posterior, + **kwargs, + ) + ) risks = np.stack(risks) risks_cache[risks_hash] = (risks, scenario.as_dict("risks")) @@ -119,12 +129,14 @@ def main(args: argparse.Namespace): if args.scenarios is None: # create a single scenario from the stdin arguments... - scenarios = [Scenario.from_namespace( - namespace=args, - lnls=lnls, - is_uni=isinstance(model, models.Unilateral), - side=params["model"].get("side", "ipsi"), - )] + scenarios = [ + Scenario.from_namespace( + namespace=args, + lnls=lnls, + is_uni=isinstance(model, models.Unilateral), + side=params["model"].get("side", "ipsi"), + ) + ] num_scens = len(scenarios) else: # ...or load the scenarios from a YAML file diff --git a/lyscripts/compute/utils.py b/lyscripts/compute/utils.py index c50e1c8..09c5e00 100644 --- a/lyscripts/compute/utils.py +++ b/lyscripts/compute/utils.py @@ -1,6 +1,5 @@ -""" -Utilities for precomputing the priors and posteriors. -""" +"""Utilities for precomputing the priors and posteriors.""" + from pathlib import Path from typing import Any @@ -51,6 +50,7 @@ def get_modality_subset(diagnosis: dict[str, Any]) -> set[str]: class HDF5FileCache: """HDF5 file acting as a cache for expensive arrays.""" + def __init__(self, file_path: Path) -> None: """Initialize the cache with the given ``file_path``.""" file_path.parent.mkdir(parents=True, exist_ok=True) @@ -86,13 +86,14 @@ def reduce_pattern(pattern: dict[str, dict[str, bool]]) -> dict[str, dict[str, b but be shorter to store. Example: - + ------- >>> full = { ... "ipsi": {"I": None, "II": True, "III": None}, ... "contra": {"I": None, "II": None, "III": None}, ... } >>> reduce_pattern(full) {'ipsi': {'II': True}} + """ tmp_pattern = pattern.copy() reduced_pattern = {} @@ -116,10 +117,12 @@ def complete_pattern( contain ``True``, ``False`` or ``None``. Example: + ------- >>> pattern = {"ipsi": {"II": True}} >>> lnls = ["II", "III"] >>> complete_pattern(pattern, lnls) {'ipsi': {'II': True, 'III': None}, 'contra': {'II': None, 'III': None}} + """ if pattern is None: pattern = {} diff --git a/lyscripts/data/__init__.py b/lyscripts/data/__init__.py index 142105f..bf0c240 100644 --- a/lyscripts/data/__init__.py +++ b/lyscripts/data/__init__.py @@ -1,5 +1,4 @@ -""" -Provide a range of commands and functions related to managing CSV datasets on patterns +"""Provide a range of commands and functions related to managing CSV datasets on patterns of lymphatic progression. It helps transform raw CSV data of any form to be converted into our `LyProX`_ format, @@ -7,10 +6,11 @@ .. _LyProX: https://lyprox.org """ + import argparse from pathlib import Path -from lyscripts.data import enhance, filter, generate, join, lyproxify, split, bootstrap +from lyscripts.data import bootstrap, enhance, filter, generate, join, lyproxify, split def _add_parser( diff --git a/lyscripts/data/__main__.py b/lyscripts/data/__main__.py index 7707540..a540dda 100644 --- a/lyscripts/data/__main__.py +++ b/lyscripts/data/__main__.py @@ -1,7 +1,7 @@ import argparse from lyscripts import RichDefaultHelpFormatter, exit_cli -from lyscripts.data import enhance, filter, generate, join, split, bootstrap +from lyscripts.data import bootstrap, enhance, filter, generate, join, split # I need another __main__ guard here, because otherwise pdoc tries to run this if __name__ == "__main__": diff --git a/lyscripts/data/accessor.py b/lyscripts/data/accessor.py index 40eeb56..1a04cbb 100644 --- a/lyscripts/data/accessor.py +++ b/lyscripts/data/accessor.py @@ -1,8 +1,8 @@ -""" -Create a custom pandas accessor to handle `LyProX`_ style data. +"""Create a custom pandas accessor to handle `LyProX`_ style data. .. _LyProX: https://lyprox.org """ + from collections.abc import Callable from typing import Any @@ -50,6 +50,7 @@ class LyProXAccessor: .. _LyProX: https://lyprox.org """ + def __init__(self, obj: pd.DataFrame) -> None: self._validate(obj) self._obj = obj diff --git a/lyscripts/data/bootstrap.py b/lyscripts/data/bootstrap.py index 17871a8..c4c6907 100644 --- a/lyscripts/data/bootstrap.py +++ b/lyscripts/data/bootstrap.py @@ -1,15 +1,17 @@ -""" -Learn the spread probabilities of the HMM for lymphatic tumor progression using +"""Learn the spread probabilities of the HMM for lymphatic tumor progression using the preprocessed data as input and the mixture model. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging import os -import numpy as np -from sklearn.utils import resample from pathlib import Path + +import numpy as np import pandas as pd +from sklearn.utils import resample + from lyscripts.utils import ( load_patient_data, load_yaml_params, @@ -34,32 +36,35 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" + parser.add_argument("--input", type=Path, help="Path to a LyProX-style CSV file") parser.add_argument( - "--input", type=Path, - help="Path to a LyProX-style CSV file" - ) - parser.add_argument( - "--output", type=Path, - help="Folder destination to save LyProX-style CSV files" + "--output", type=Path, help="Folder destination to save LyProX-style CSV files" ) parser.add_argument( - "-p", "--params", default="params.yaml", type=Path, - help="Path to parameter file" + "-p", + "--params", + default="params.yaml", + type=Path, + help="Path to parameter file", ) parser.set_defaults(run_main=main) -def proportional_bootstrap(df, group_col, n_bootstraps, folder_path = None): + +def proportional_bootstrap(df, group_col, n_bootstraps, folder_path=None): """Produce n_bootstraps bootstrapped datasets from the original DataFrame. this keeps the number of patients per subsite constant. Args: + ---- df (pd.DataFrame): The original DataFrame to bootstrap from. group_col (str): The name of the column to group by. n_bootstraps (int): The number of bootstrapped datasets to create. Returns: + ------- list[pd.DataFrame]: A list of bootstrapped DataFrames. + """ datasets = [] group_sizes = df[group_col].value_counts(normalize=True) @@ -72,26 +77,34 @@ def proportional_bootstrap(df, group_col, n_bootstraps, folder_path = None): group_df = df[df[group_col] == group] boot_group = resample(group_df, replace=True, n_samples=n_samples) samples.append(boot_group) - boot_df = pd.concat(samples).sample(frac=1).reset_index(drop=True) # optional shuffle + boot_df = ( + pd.concat(samples).sample(frac=1).reset_index(drop=True) + ) # optional shuffle datasets.append(boot_df) if folder_path is not None: for i, dataset in enumerate(datasets): - file_path = os.path.join(folder_path, f'dataset_resample_{i}.csv') + file_path = os.path.join(folder_path, f"dataset_resample_{i}.csv") dataset.to_csv(file_path, index=False) else: return datasets + def main(args: argparse.Namespace) -> None: """Main function to sample parameters for a mixture model""" input_table = load_patient_data(args.input) params = load_yaml_params(args.params) - if ('tumor', 'core', 'subsite') in input_table.columns: - group_col = ('tumor', 'core', 'subsite') - elif ('tumor', '1', 'subsite') in input_table.columns: - group_col = ('tumor', '1', 'subsite') + if ("tumor", "core", "subsite") in input_table.columns: + group_col = ("tumor", "core", "subsite") + elif ("tumor", "1", "subsite") in input_table.columns: + group_col = ("tumor", "1", "subsite") else: logger.error("No 'subsite' column found in the input data.") - proportional_bootstrap(input_table, group_col=group_col, n_bootstraps=params["sampling"]["n_bootstraps"], folder_path=args.output) + proportional_bootstrap( + input_table, + group_col=group_col, + n_bootstraps=params["sampling"]["n_bootstraps"], + folder_path=args.output, + ) if __name__ == "__main__": diff --git a/lyscripts/data/enhance.py b/lyscripts/data/enhance.py index 0abd858..03dbfd4 100644 --- a/lyscripts/data/enhance.py +++ b/lyscripts/data/enhance.py @@ -1,5 +1,4 @@ -""" -Enhance a LyProX-style CSV dataset in two ways: +"""Enhance a LyProX-style CSV dataset in two ways: 1. Add consensus diagnosis based on all available modalities using on of two methods: ``max_llh`` infers the most likely true state of involvement given only the available @@ -11,6 +10,7 @@ correct values. Conversely, if e.g. LNL II is reported to be healthy, we can assume the sublevels IIa and IIb would have been reported as healthy, too. """ + # pylint: disable=singleton-comparison,logging-fstring-interpolation import argparse import logging @@ -51,35 +51,44 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" + parser.add_argument("input", type=Path, help="Path to a LyProX-style CSV file") parser.add_argument( - "input", type=Path, - help="Path to a LyProX-style CSV file" - ) - parser.add_argument( - "output", type=Path, - help="Destination for LyProX-style output file including the consensus" + "output", + type=Path, + help="Destination for LyProX-style output file including the consensus", ) parser.add_argument( - "-c", "--consensus", nargs="+", default=["max_llh"], + "-c", + "--consensus", + nargs="+", + default=["max_llh"], choices=CONSENSUS_FUNCS.keys(), - help="Choose consensus method(s)" + help="Choose consensus method(s)", ) parser.add_argument( - "-p", "--params", default="params.yaml", type=Path, - help="Path to parameter file" + "-p", + "--params", + default="params.yaml", + type=Path, + help="Path to parameter file", ) parser.add_argument( - "--modalities", nargs="+", + "--modalities", + nargs="+", default=["CT", "MRI", "PET", "FNA", "diagnostic_consensus", "pathology", "pCT"], - help="List of modalities for enhancement. Must be defined in `params.yaml`" + help="List of modalities for enhancement. Must be defined in `params.yaml`", ) parser.add_argument( - "--sublvls", nargs="+", default=["a", "b"], - help="Indicate what kinds of sublevels exist" + "--sublvls", + nargs="+", + default=["a", "b"], + help="Indicate what kinds of sublevels exist", ) parser.add_argument( - "--lnls-with-sub", nargs="+", default=["I", "II", "V"], - help="List of LNLs where sublevel reporting has been performed or is common" + "--lnls-with-sub", + nargs="+", + default=["I", "II", "V"], + help="List of LNLs where sublevel reporting has been performed or is common", ) parser.set_defaults(run_main=main) @@ -90,13 +99,11 @@ def get_sublvl_values( lnl: str, sub_ids: list[str], ): - """ - Get values of sublevels (e.g. 'IIa' and 'IIb') for an ``lnl`` and ``data_frame``. - """ - has_sublvls = all(lnl+sub in data_frame for sub in sub_ids) + """Get values of sublevels (e.g. 'IIa' and 'IIb') for an ``lnl`` and ``data_frame``.""" + has_sublvls = all(lnl + sub in data_frame for sub in sub_ids) if not has_sublvls: return None - return data_frame[[lnl+sub for sub in sub_ids]].values + return data_frame[[lnl + sub for sub in sub_ids]].values @log_state() @@ -124,25 +131,25 @@ def infer_superlvl_from_sublvls( for mod in modalities: for side in ["ipsi", "contra"]: for lnl in lnls_with_sub: - sublvl_values = get_sublvl_values( - table[mod,side], lnl, sublvls - ) + sublvl_values = get_sublvl_values(table[mod, side], lnl, sublvls) if sublvl_values is None: continue # sometimes, the sublevels both report `False` (healthy) but the # superlevel is involved. In this case, we want to keep the superlevel # as involved. - if lnl in table[mod,side]: - is_superlvl_involved = table[mod,side,lnl] == True + if lnl in table[mod, side]: + is_superlvl_involved = table[mod, side, lnl] == True else: is_superlvl_involved = False - has_sublvl_involved = np.any(sublvl_values==True , axis=1) - all_sublvls_healthy = np.all(sublvl_values==False, axis=1) + has_sublvl_involved = np.any(sublvl_values == True, axis=1) + all_sublvls_healthy = np.all(sublvl_values == False, axis=1) - fixed_table.loc[has_sublvl_involved, (mod,side,lnl)] = True - fixed_table.loc[all_sublvls_healthy & ~is_superlvl_involved, (mod,side,lnl)] = False + fixed_table.loc[has_sublvl_involved, (mod, side, lnl)] = True + fixed_table.loc[ + all_sublvls_healthy & ~is_superlvl_involved, (mod, side, lnl) + ] = False return fixed_table @@ -161,7 +168,7 @@ def get_lnl_observations( for mod in modalities.keys(): try: - add_obs = patient[mod,side,lnl] + add_obs = patient[mod, side, lnl] add_obs = None if pd.isna(add_obs) else add_obs except KeyError: add_obs = None @@ -191,40 +198,32 @@ def and_consensus(obs_tuple: tuple[np.ndarray]): if has_all_none(obs_tuple): return None - return not( - any(not(obs) if obs is not None else None for obs in obs_tuple) - ) + return not (any(not (obs) if obs is not None else None for obs in obs_tuple)) @lru_cache -def maxllh_consensus( - obs_tuple: tuple[np.ndarray], - modalities_spsn: tuple[list[float]] -): +def maxllh_consensus(obs_tuple: tuple[np.ndarray], modalities_spsn: tuple[list[float]]): """Compute the maximum likelihood consensus of different diagnostic modalities.""" if has_all_none(obs_tuple): return None - healthy_llh = 1. - involved_llh = 1. - for obs, spsn in zip(obs_tuple, modalities_spsn): + healthy_llh = 1.0 + involved_llh = 1.0 + for obs, spsn in zip(obs_tuple, modalities_spsn, strict=False): if obs is None: continue spsn = np.array(spsn) obs = int(obs) - spsn2x2 = np.diag(spsn) + np.diag(1. - spsn)[::-1] - healthy_llh *= spsn2x2[obs,0] - involved_llh *= spsn2x2[obs,1] + spsn2x2 = np.diag(spsn) + np.diag(1.0 - spsn)[::-1] + healthy_llh *= spsn2x2[obs, 0] + involved_llh *= spsn2x2[obs, 1] healthy_vs_involved = np.array([healthy_llh, involved_llh]) return bool(np.argmax(healthy_vs_involved)) @lru_cache -def rank_consensus( - obs_tuple: tuple[np.ndarray], - modalities_spsn: tuple[list[float]] -): +def rank_consensus(obs_tuple: tuple[np.ndarray], modalities_spsn: tuple[list[float]]): """Compute the ranked consensus of different diagnostic modalities.""" if has_all_none(obs_tuple): return None @@ -232,12 +231,12 @@ def rank_consensus( modalities_spsn = list(modalities_spsn) healthy_sens = [ - modalities_spsn[i][1] for i,obs in enumerate(obs_tuple) if obs == False + modalities_spsn[i][1] for i, obs in enumerate(obs_tuple) if obs == False ] involved_spec = [ - modalities_spsn[i][0] for i,obs in enumerate(obs_tuple) if obs == True + modalities_spsn[i][0] for i, obs in enumerate(obs_tuple) if obs == True ] - if np.max([*healthy_sens, 0.]) > np.max([*involved_spec, 0.]): + if np.max([*healthy_sens, 0.0]) > np.max([*involved_spec, 0.0]): return False return True @@ -261,20 +260,18 @@ def main(args: argparse.Namespace): selection=args.modalities, ) - available_mod_keys = sorted(set( - input_table.columns.get_level_values(0) - ).intersection( - modalities.keys() - )) + available_mod_keys = sorted( + set(input_table.columns.get_level_values(0)).intersection(modalities.keys()) + ) available_mods = {key: modalities[key] for key in available_mod_keys} - lnl_union = sorted(set().union( - *[input_table[mod,"ipsi"].columns for mod in available_mod_keys] - )) + lnl_union = sorted( + set().union(*[input_table[mod, "ipsi"].columns for mod in available_mod_keys]) + ) consensus = pd.DataFrame( index=input_table.index, columns=pd.MultiIndex.from_product( [args.consensus, ["ipsi", "contra"], lnl_union] - ) + ), ) with CustomProgress(console=console) as report_progress: @@ -284,7 +281,7 @@ def main(args: argparse.Namespace): ) for side in ["ipsi", "contra"]: # go through patients and LNLs and compute consensus for each - for p,patient in input_table.iterrows(): + for p, patient in input_table.iterrows(): for lnl in lnl_union: observations = get_lnl_observations( patient, side, lnl, available_mods @@ -296,12 +293,11 @@ def main(args: argparse.Namespace): report_progress.update(enhance_task, advance=1) table_with_consensus = input_table.join(consensus) - - data_modalities = sorted(set( - table_with_consensus.columns.get_level_values(0) - ).intersection( - [*modalities.keys(), *args.consensus] - )) + data_modalities = sorted( + set(table_with_consensus.columns.get_level_values(0)).intersection( + [*modalities.keys(), *args.consensus] + ) + ) consensus_and_fixed_sublvlvs = infer_superlvl_from_sublvls( table_with_consensus, data_modalities, diff --git a/lyscripts/data/filter.py b/lyscripts/data/filter.py index 7eb7cb0..df5b6a7 100644 --- a/lyscripts/data/filter.py +++ b/lyscripts/data/filter.py @@ -1,7 +1,7 @@ -""" -Filter a datset according to some common criteria, like tumor location, subsite, +"""Filter a datset according to some common criteria, like tumor location, subsite, T-category, etc. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -44,19 +44,20 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "input", type=Path, - help="The path to the full dataset to split." + "input", type=Path, help="The path to the full dataset to split." ) parser.add_argument( - "output", type=Path, - help="Folder to store the split CSV files in." + "output", type=Path, help="Folder to store the split CSV files in." ) for prefix in ["include", "exclude"]: for filter_by in ["locations", "subsites", "t_categories"]: parser.add_argument( - f"--{prefix}-{filter_by}", default=None, type=str, nargs="+", - help=f"If provided, {prefix} patients with the given tumor {filter_by}." + f"--{prefix}-{filter_by}", + default=None, + type=str, + nargs="+", + help=f"If provided, {prefix} patients with the given tumor {filter_by}.", ) parser.set_defaults(run_main=main) diff --git a/lyscripts/data/generate.py b/lyscripts/data/generate.py index 36c5054..2ce6109 100644 --- a/lyscripts/data/generate.py +++ b/lyscripts/data/generate.py @@ -1,8 +1,8 @@ -""" -Calls the synthetic data generating methods of the `lymph`_ package models. +"""Calls the synthetic data generating methods of the `lymph`_ package models. .. _lymph: https://lymph-model.readthedocs.io """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -34,32 +34,42 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "num", type=int, + "num", + type=int, help="Number of synthetic patient records to generate", ) parser.add_argument( - "output", type=Path, + "output", + type=Path, help="Path where to store the generated synthetic data", ) parser.add_argument( - "--params", default="./params.yaml", type=Path, - help="Parameter file containing model specifications" + "--params", + default="./params.yaml", + type=Path, + help="Parameter file containing model specifications", ) group = parser.add_mutually_exclusive_group() group.add_argument( - "--set-theta", nargs="+", type=float, - help="Set the spread probs and parameters for time marginalization by hand" + "--set-theta", + nargs="+", + type=float, + help="Set the spread probs and parameters for time marginalization by hand", ) group.add_argument( - "--load-theta", choices=["mean", "max_llh"], default="mean", - help="Use either the mean or the maximum likelihood estimate from drawn samples" + "--load-theta", + choices=["mean", "max_llh"], + default="mean", + help="Use either the mean or the maximum likelihood estimate from drawn samples", ) parser.add_argument( - "--samples", default="./models/samples.hdf5", type=Path, - help="Path to the samples if a method to load them was chosen" + "--samples", + default="./models/samples.hdf5", + type=Path, + help="Path to the samples if a method to load them was chosen", ) parser.set_defaults(run_main=main) @@ -82,11 +92,7 @@ def main(args: argparse.Namespace): logger.info("Assigned given parameters to model") else: - backend = emcee.backends.HDFBackend( - args.samples, - read_only=True, - name="mcmc" - ) + backend = emcee.backends.HDFBackend(args.samples, read_only=True, name="mcmc") chain = backend.get_chain(flat=True) log_probs = backend.get_blobs(flat=True) diff --git a/lyscripts/data/join.py b/lyscripts/data/join.py index ef91e81..7c40d0d 100644 --- a/lyscripts/data/join.py +++ b/lyscripts/data/join.py @@ -1,6 +1,5 @@ -""" -Join datasets from different sources (but of the same format) into one. -""" +"""Join datasets from different sources (but of the same format) into one.""" + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -35,12 +34,19 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "-i", "--inputs", nargs='+', type=Path, required=True, - help="List of paths to inference-ready CSV datasets to concatenate." + "-i", + "--inputs", + nargs="+", + type=Path, + required=True, + help="List of paths to inference-ready CSV datasets to concatenate.", ) parser.add_argument( - "-o", "--output", type=Path, required=True, - help="Location to store the concatenated CSV file." + "-o", + "--output", + type=Path, + required=True, + help="Location to store the concatenated CSV file.", ) parser.set_defaults(run_main=main) @@ -52,9 +58,7 @@ def load_and_join_tables(input_paths: list[Path]): for path in input_paths: input_table = load_patient_data(path).convert_dtypes() concatenated_table = pd.concat( - [concatenated_table, input_table], - axis="index", - ignore_index=True + [concatenated_table, input_table], axis="index", ignore_index=True ) logger.info(f"+ concatenated data from {path}") return concatenated_table diff --git a/lyscripts/data/lyproxify.py b/lyscripts/data/lyproxify.py index ad11b46..8d0a17e 100644 --- a/lyscripts/data/lyproxify.py +++ b/lyscripts/data/lyproxify.py @@ -1,5 +1,4 @@ -""" -Consumes raw data and transforms it into a CSV of the format that `LyProX`_ understands. +"""Consumes raw data and transforms it into a CSV of the format that `LyProX`_ understands. To do so, it needs a dictionary that defines a mapping from raw columns to the LyProX style data format. See the documentation of the :py:func:`.transform_to_lyprox` function @@ -7,6 +6,7 @@ .. _LyProX: https://lyprox.org """ + # pylint: disable=logging-fstring-interpolation import argparse import importlib.util @@ -44,38 +44,54 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "-i", "--input", type=Path, required=True, - help="Location of raw CSV data." + "-i", "--input", type=Path, required=True, help="Location of raw CSV data." ) parser.add_argument( - "-r", "--header-rows", nargs="+", default=[0], type=int, - help="List with header row indices of raw file." + "-r", + "--header-rows", + nargs="+", + default=[0], + type=int, + help="List with header row indices of raw file.", ) parser.add_argument( - "-o", "--output", type=Path, required=True, - help="Location to store the lyproxified CSV file." + "-o", + "--output", + type=Path, + required=True, + help="Location to store the lyproxified CSV file.", ) parser.add_argument( - "-m", "--mapping", type=Path, required=True, + "-m", + "--mapping", + type=Path, + required=True, help=( "Location of the Python file that contains column mapping instructions. " "This must contain a dictionary with the name 'column_map'." - ) + ), ) parser.add_argument( - "--drop-rows", nargs="+", type=int, default=[], + "--drop-rows", + nargs="+", + type=int, + default=[], help=( "Delete rows of specified indices. Counting of rows start at 0 _after_ " "the `header-rows`." - ) + ), ) parser.add_argument( - "--drop-cols", nargs="+", type=int, default=[], + "--drop-cols", + nargs="+", + type=int, + default=[], help="Delete columns of specified indices.", ) parser.add_argument( - "--add-index", action="store_true", - help="If the data doesn't contain an index, add one by enumerating the patients" + "--add-index", + action="store_true", + help="If the data doesn't contain an index, add one by enumerating the patients", ) parser.set_defaults(run_main=main) @@ -170,8 +186,7 @@ def generate_markdown_docs( @log_state() def transform_to_lyprox( - raw: pd.DataFrame, - column_map: dict[tuple, dict[str, Any]] + raw: pd.DataFrame, column_map: dict[tuple, dict[str, Any]] ) -> pd.DataFrame: """Transform ``raw`` data frame into table that can be uploaded directly to `LyProX`_. @@ -249,21 +264,15 @@ def leftright_to_ipsicontra(data: pd.DataFrame): involvement. """ len_before = len(data) - left_data = data.loc[ - data["tumor", "1", "side"] != "right" - ] - right_data = data.loc[ - data["tumor", "1", "side"] == "right" - ] + left_data = data.loc[data["tumor", "1", "side"] != "right"] + right_data = data.loc[data["tumor", "1", "side"] == "right"] left_data = left_data.rename(columns={"left": "ipsi"}, level=1) left_data = left_data.rename(columns={"right": "contra"}, level=1) right_data = right_data.rename(columns={"left": "contra"}, level=1) right_data = right_data.rename(columns={"right": "ipsi"}, level=1) - data = pd.concat( - [left_data, right_data], ignore_index=True - ) + data = pd.concat([left_data, right_data], ignore_index=True) assert len_before == len(data), "Number of patients changed" return data @@ -295,7 +304,9 @@ def exclude_patients(raw: pd.DataFrame, exclude: list[tuple[str, Any]]): def main(args: argparse.Namespace): """Run the lyproxify main function.""" raw: pd.DataFrame = load_patient_data(args.input, header=args.header_rows) - raw = clean_header(raw, num_cols=raw.shape[1], num_header_rows=len(args.header_rows)) + raw = clean_header( + raw, num_cols=raw.shape[1], num_header_rows=len(args.header_rows) + ) cols_to_drop = raw.columns[args.drop_cols] trimmed = raw.drop(cols_to_drop, axis="columns") diff --git a/lyscripts/data/split.py b/lyscripts/data/split.py index 0f224ea..a7206d1 100644 --- a/lyscripts/data/split.py +++ b/lyscripts/data/split.py @@ -1,7 +1,7 @@ -""" -Split the full dataset into cross-validation folds according to the +"""Split the full dataset into cross-validation folds according to the content of the params.yaml file. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -34,17 +34,18 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "input", type=Path, - help="The path to the full dataset to split." + "input", type=Path, help="The path to the full dataset to split." ) parser.add_argument( - "output", type=Path, - help="Folder to store the split CSV files in." + "output", type=Path, help="Folder to store the split CSV files in." ) parser.add_argument( - "-p", "--params", default="./params.yaml", type=Path, - help="Path to parameter YAML file." + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter YAML file.", ) parser.set_defaults(run_main=main) @@ -58,25 +59,20 @@ def main(args: argparse.Namespace): args.output.mkdir(exist_ok=True) shuffled_df = concatenated_df.sample( - frac=1., - replace=False, - random_state=params["cross-validation"]["seed"] + frac=1.0, replace=False, random_state=params["cross-validation"]["seed"] ).reset_index(drop=True) - split_dfs = np.array_split( - shuffled_df, - len(params["cross-validation"]["folds"]) - ) + split_dfs = np.array_split(shuffled_df, len(params["cross-validation"]["folds"])) for fold_id, split_pattern in params["cross-validation"]["folds"].items(): # Concatenate training and evaluation DataFrames from the split DataFrames # using the split pattern defined in the params file. eval_df = pd.concat( - [split_dfs[k] for k,sym in enumerate(split_pattern) if sym == "e"], - ignore_index=True + [split_dfs[k] for k, sym in enumerate(split_pattern) if sym == "e"], + ignore_index=True, ) train_df = pd.concat( - [split_dfs[k] for k,sym in enumerate(split_pattern) if sym == "t"], - ignore_index=True + [split_dfs[k] for k, sym in enumerate(split_pattern) if sym == "t"], + ignore_index=True, ) eval_df.to_csv(args.output / f"{fold_id}_eval.csv", index=None) diff --git a/lyscripts/data/utils.py b/lyscripts/data/utils.py index a1e724c..32803b7 100644 --- a/lyscripts/data/utils.py +++ b/lyscripts/data/utils.py @@ -1,6 +1,5 @@ -""" -Utilities related to the commands for data cleaning and processing. -""" +"""Utilities related to the commands for data cleaning and processing.""" + from pathlib import Path import pandas as pd diff --git a/lyscripts/decorators.py b/lyscripts/decorators.py index bf6f292..11ed218 100644 --- a/lyscripts/decorators.py +++ b/lyscripts/decorators.py @@ -1,9 +1,9 @@ -""" -This module provides decorators that can be used to avoid repetitive snippets of code, +"""This module provides decorators that can be used to avoid repetitive snippets of code, e.g. safely opening files or logging the state of a function call. This is *not* a command line tool. """ + import functools import logging from collections.abc import Callable @@ -66,10 +66,12 @@ def log_state(log_level: int = logging.INFO) -> Callable: The log message will simply be the function name where underscores are replaced with spaces. The `log_level` can be set in the decorator call. """ + # pylint: disable=logging-fstring-interpolation # pylint: disable=logging-not-lazy def log_decorator(func: Callable): """The decorator wrapping the decorated function.""" + @functools.wraps(func) def wrapper(*args, **kwargs): """The wrapper around the decorated function.""" @@ -101,6 +103,7 @@ def wrapper(*args, **kwargs): def check_input_file_exists(loading_func: Callable) -> Callable: """Check if the file path provided to the `loading_func` exists.""" + @wraps(loading_func) def inner(file_path: str, *args, **kwargs) -> Any: """Wrapped loading function.""" @@ -115,6 +118,7 @@ def inner(file_path: str, *args, **kwargs) -> Any: def check_output_dir_exists(saving_func: Callable) -> Callable: """Make sure the parent directory of the saved file exists.""" + @wraps(saving_func) def inner(file_path: str, *args, **kwargs) -> Any: """Wrapped saving function.""" diff --git a/lyscripts/evaluate.py b/lyscripts/evaluate.py index bd1fbec..1fb8522 100644 --- a/lyscripts/evaluate.py +++ b/lyscripts/evaluate.py @@ -1,8 +1,8 @@ -""" -Evaluate the performance of the trained model by computing quantities like the +"""Evaluate the performance of the trained model by computing quantities like the Bayesian information criterion (BIC) or (if thermodynamic integration was performed) the actual evidence (with error) of the model. """ + # pylint: disable=logging-fstring-interpolation import argparse import json @@ -40,37 +40,29 @@ def _add_arguments(parser: argparse.ArgumentParser): This is called by the parent module that is called via the command line. """ parser.add_argument( - "data", type=Path, - help="Path to the tables of patient data (CSV)." - ) - parser.add_argument( - "model", type=Path, - help="Path to model output files (HDF5)." + "data", type=Path, help="Path to the tables of patient data (CSV)." ) + parser.add_argument("model", type=Path, help="Path to model output files (HDF5).") parser.add_argument( - "-p", "--params", default="./params.yaml", type=Path, - help="Path to parameter file" + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file", ) parser.add_argument( - "--plots", default="./plots", type=Path, - help="Directory for storing plots" + "--plots", default="./plots", type=Path, help="Directory for storing plots" ) parser.add_argument( - "--metrics", default="./metrics.json", type=Path, - help="Path to metrics file" + "--metrics", default="./metrics.json", type=Path, help="Path to metrics file" ) parser.set_defaults(run_main=main) -def comp_bic( - log_probs: np.ndarray, - num_params: int, - num_data: int -) -> float: - """ - Compute the negative one half of the Bayesian Information Criterion (BIC). +def comp_bic(log_probs: np.ndarray, num_params: int, num_data: int) -> float: + """Compute the negative one half of the Bayesian Information Criterion (BIC). The BIC is defined as [^1] $$ BIC = k \\ln{n} - 2 \\ln{\\hat{L}} $$ @@ -83,7 +75,8 @@ def comp_bic( [^1]: https://en.wikipedia.org/wiki/Bayesian_information_criterion """ - return np.max(log_probs) - num_params * np.log(num_data) / 2. + return np.max(log_probs) - num_params * np.log(num_data) / 2.0 + def compute_evidence( temp_schedule: np.ndarray, @@ -102,7 +95,7 @@ def compute_evidence( integrals = np.zeros(shape=num) for i in range(num): rand_idx = np.random.choice(log_probs.shape[1], size=log_probs.shape[0]) - drawn_accuracy = log_probs[np.arange(log_probs.shape[0]),rand_idx].copy() + drawn_accuracy = log_probs[np.arange(log_probs.shape[0]), rand_idx].copy() integrals[i] = trapezoid(y=drawn_accuracy, x=temp_schedule) return np.mean(integrals), np.std(integrals) @@ -167,17 +160,18 @@ def main(args: argparse.Namespace): args.plots.parent.mkdir(exist_ok=True) beta_vs_accuracy = pd.DataFrame( - np.array([ - temp_schedule, - np.mean(ti_log_probs, axis=1), - np.std(ti_log_probs, axis=1) - ]).T, + np.array( + [ + temp_schedule, + np.mean(ti_log_probs, axis=1), + np.std(ti_log_probs, axis=1), + ] + ).T, columns=["β", "accuracy", "std"], ) beta_vs_accuracy.to_csv(args.plots, index=False) logger.info(f"Plotted β vs accuracy at {args.plots}") - # use blobs, because also for TI, this is the unscaled log-prob backend = emcee.backends.HDFBackend(args.model, read_only=True, name="mcmc") final_log_probs = backend.get_blobs() @@ -188,7 +182,9 @@ def main(args: argparse.Namespace): args.metrics.touch(exist_ok=True) metrics["BIC"] = comp_bic( - final_log_probs, ndim, len(data), + final_log_probs, + ndim, + len(data), ) metrics["max_llh"] = np.max(final_log_probs) metrics["mean_llh"] = np.mean(final_log_probs) diff --git a/lyscripts/mixture_fit.py b/lyscripts/mixture_fit.py index 01922c3..7dd30c3 100644 --- a/lyscripts/mixture_fit.py +++ b/lyscripts/mixture_fit.py @@ -1,28 +1,25 @@ -""" -Learn the spread probabilities of the HMM for lymphatic tumor progression using +"""Learn the spread probabilities of the HMM for lymphatic tumor progression using the preprocessed data as input and the mixture model. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging import os -import pickle from concurrent.futures import ProcessPoolExecutor - from pathlib import Path import numpy as np import pandas as pd -from lymph import models from lymixture.em import expectation, maximization - +from lymph import models from lyscripts.utils import ( + assign_modalities, create_mixture, load_patient_data, load_yaml_params, to_numpy, - assign_modalities ) logger = logging.getLogger(__name__) @@ -47,59 +44,82 @@ def _add_arguments(parser: argparse.ArgumentParser): This is called by the parent module that is called via the command line. """ + parser.add_argument("-i", "--input", type=Path, help="Path to training data files") parser.add_argument( - "-i", "--input", type=Path, required=True, - help="Path to training data files" - ) - parser.add_argument( - "--history", type=Path, nargs="?", - help="Path to store history in (as CSV file)." + "--history", + type=Path, + nargs="?", + help="Path to store history in (as CSV file).", ) parser.add_argument( - "-p", "--params", default="./params.yaml", type=Path, - help="Path to parameter file." + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file.", ) parser.add_argument( - "-s", "--seed", type=int, default=42, - help="Seed value to reproduce the same sampling round." + "-s", + "--seed", + type=int, + default=42, + help="Seed value to reproduce the same sampling round.", ) parser.add_argument( - "-m", "--multi_fit", type=bool, default=False, - help="Whether to fit multiple models for later uncertainty evaluation" + "-m", + "--multi_fit", + type=bool, + default=False, + help="Whether to fit multiple models for later uncertainty evaluation", ) parser.add_argument( - "-sp", "--starting_point", type=Path, default=None, - help="Starting point for optimization if we do not want to start from a random point" + "-sp", + "--starting_point", + type=Path, + default=None, + help="Starting point for optimization if we do not want to start from a random point", ) - parser.set_defaults(run_main=main) MIXTURE = None + def log_prob_fn() -> float: - """log probability function using global variables because of pickling.""" - return MIXTURE.likelihood(use_complete = True, given_resps = MIXTURE.get_resps(norm = True)) + """Log probability function using global variables because of pickling.""" + return MIXTURE.likelihood( + use_complete=True, given_resps=MIXTURE.get_resps(norm=True) + ) -def check_convergence(params_history, likelihood_history, steps_back_list, absolute_tolerance = 0.01): +def check_convergence( + params_history, likelihood_history, steps_back_list, absolute_tolerance=0.01 +): current_params = params_history[-1] current_likelihood = likelihood_history[-1] for steps_back in steps_back_list: previous_params = params_history[-steps_back - 1] if np.allclose(to_numpy(current_params), to_numpy(previous_params)): - logger.info(f"Converged after {len(params_history)} steps. due to parameter similarity") + logger.info( + f"Converged after {len(params_history)} steps. due to parameter similarity" + ) return True # Return True if any of the steps is close - elif np.isclose(current_likelihood, likelihood_history[-steps_back - 1],rtol = 0, atol = absolute_tolerance): - logger.info(f"Converged after {len(params_history)} steps. due to likelihood similarity") + elif np.isclose( + current_likelihood, + likelihood_history[-steps_back - 1], + rtol=0, + atol=absolute_tolerance, + ): + logger.info( + f"Converged after {len(params_history)} steps. due to likelihood similarity" + ) return True return False -def run_EM(tolerance, history_dir = None): - """Run the EM algorithm to determine the optimal parameters. - """ +def run_EM(tolerance, history_dir=None): + """Run the EM algorithm to determine the optimal parameters.""" os.makedirs(history_dir, exist_ok=True) is_converged = False iteration = 0 @@ -107,140 +127,186 @@ def run_EM(tolerance, history_dir = None): params_history = [] likelihood_history = [] params_history.append(params.copy()) - likelihood_history.append(MIXTURE.likelihood(use_complete = False)) + likelihood_history.append(MIXTURE.likelihood(use_complete=False)) # Number of steps to look back for convergence look_back_steps = 3 while not is_converged: logger.info(f"Iteration: {iteration}") logger.info(f"Likelihood: {likelihood_history[-1]}") - latent = expectation(MIXTURE, params, log = True) + latent = expectation(MIXTURE, params, log=True) MIXTURE.set_resps(np.exp(latent)) params = maximization(MIXTURE, latent) - + # Append current params and likelihood to history params_history.append(params.copy()) - likelihood_history.append(MIXTURE.likelihood(use_complete = False)) + likelihood_history.append(MIXTURE.likelihood(use_complete=False)) if history_dir != None: llh_history = pd.DataFrame(likelihood_history) - llh_history.columns = ['likelihoods'] - llh_history.to_csv(history_dir + '/llh.csv', index=False) + llh_history.columns = ["likelihoods"] + llh_history.to_csv(history_dir + "/llh.csv", index=False) param_history = pd.DataFrame(params_history) - param_history.to_csv(history_dir + '/params.csv', index=False) - MIXTURE.get_mixture_coefs().to_csv(history_dir + '/mixture_coef.csv', index=False) + param_history.to_csv(history_dir + "/params.csv", index=False) + MIXTURE.get_mixture_coefs().to_csv( + history_dir + "/mixture_coef.csv", index=False + ) # Check if converged if iteration >= 3: # Ensure enough history is available - is_converged = check_convergence(params_history, likelihood_history,list(range(1,look_back_steps+1)),tolerance) + is_converged = check_convergence( + params_history, + likelihood_history, + list(range(1, look_back_steps + 1)), + tolerance, + ) iteration += 1 - df = pd.DataFrame.from_dict(MIXTURE.get_params(), orient='index', columns=['value']) - df.to_csv(history_dir + '/optimal_params.csv') + df = pd.DataFrame.from_dict(MIXTURE.get_params(), orient="index", columns=["value"]) + df.to_csv(history_dir + "/optimal_params.csv") return params_history, likelihood_history -def process_dataset(dataset, folder_path, initial_params, model_build_params, index, look_back_steps=3): +def process_dataset( + dataset, folder_path, initial_params, model_build_params, index, look_back_steps=3 +): os.makedirs(folder_path, exist_ok=True) - subpath_optimal_params = 'optimal_params' + subpath_optimal_params = "optimal_params" os.makedirs(os.path.join(folder_path, subpath_optimal_params), exist_ok=True) - subpath_params_history = 'params_history' + subpath_params_history = "params_history" os.makedirs(os.path.join(folder_path, subpath_params_history), exist_ok=True) - subpath_likelihood_history = 'likelihood_history' + subpath_likelihood_history = "likelihood_history" os.makedirs(os.path.join(folder_path, subpath_likelihood_history), exist_ok=True) - logger.info(f"Starting dataset {index}") mixture = create_mixture(model_build_params) mapping = model_build_params["model"].get("mapping", None) if isinstance(mixture.components[0], models.Unilateral): - mixture.load_patient_data(dataset, split_by= model_build_params["model"].get("split_by", ("tumor", "1", "subsite")), mapping=mapping) - assign_modalities(model=mixture, config=model_build_params.get("inference_modalities", {})) + mixture.load_patient_data( + dataset, + split_by=model_build_params["model"].get( + "split_by", ("tumor", "1", "subsite") + ), + mapping=mapping, + ) + assign_modalities( + model=mixture, config=model_build_params.get("inference_modalities", {}) + ) else: raise ValueError("Only Unilateral has been implemented so far") - + mixture.set_params(**initial_params) mixture.normalize_mixture_coefs() - tolerance = model_build_params['model'].get('likelihood_tolerance', 0.01) + tolerance = model_build_params["model"].get("likelihood_tolerance", 0.01) mixture.set_params(**initial_params) params = initial_params.copy() - + mixture.normalize_mixture_coefs() params_history = [params.copy()] likelihood_history = [mixture.likelihood(use_complete=False)] - + is_converged = False count = 0 logger.info(f"[Dataset {index}] started") file_prefix = f"dataset_{index}" while not is_converged: - - latent = expectation(mixture, params, log = True) + latent = expectation(mixture, params, log=True) mixture.set_resps(np.exp(latent)) params = maximization(mixture, latent) params_history.append(params.copy()) likelihood_history.append(mixture.likelihood(use_complete=False)) - + llh_history = pd.DataFrame(likelihood_history) - llh_history.columns = ['likelihoods'] - llh_history.to_csv(os.path.join(folder_path, subpath_likelihood_history, f"{file_prefix}_likelihood_history.csv"), index=False) + llh_history.columns = ["likelihoods"] + llh_history.to_csv( + os.path.join( + folder_path, + subpath_likelihood_history, + f"{file_prefix}_likelihood_history.csv", + ), + index=False, + ) param_history = pd.DataFrame(params_history) - param_history.to_csv(os.path.join(folder_path, subpath_params_history, f"{file_prefix}_param_history.csv"), index=False) + param_history.to_csv( + os.path.join( + folder_path, subpath_params_history, f"{file_prefix}_param_history.csv" + ), + index=False, + ) if count >= look_back_steps: - is_converged = check_convergence(params_history, likelihood_history, list(range(1, look_back_steps + 1)), tolerance) - + is_converged = check_convergence( + params_history, + likelihood_history, + list(range(1, look_back_steps + 1)), + tolerance, + ) + count += 1 - logger.info(f"[Dataset {index}] Converged after {count} steps") + logger.info(f"[Dataset {index}] Converged after {count+1} steps") + + df = pd.DataFrame.from_dict(mixture.get_params(), orient="index", columns=["value"]) + df.to_csv( + os.path.join( + folder_path, subpath_optimal_params, f"{file_prefix}_optimal_params.csv" + ) + ) - df = pd.DataFrame.from_dict(mixture.get_params(), orient='index', columns=['value']) - df.to_csv(os.path.join(folder_path, subpath_optimal_params, f"{file_prefix}_optimal_params.csv")) def main(args: argparse.Namespace) -> None: """Main function to run the EM algorithm for a mixture model""" - params = load_yaml_params(args.params) global MIXTURE MIXTURE = create_mixture(params) + original_data = load_patient_data(params["general"]["data"]) + + mapping = params["model"].get("mapping", None) + if isinstance(MIXTURE.components[0], models.Unilateral): + MIXTURE.load_patient_data( + original_data, + split_by=params["model"].get("split_by", ("tumor", "1", "subsite")), + mapping=mapping, + ) + assign_modalities(model=MIXTURE, config=params.get("inference_modalities", {})) + + else: + raise ValueError("Only Unilateral has been implemented so far") if args.starting_point is None: rng = np.random.default_rng(params["em"].get("seed", 42)) starting_values = {k: rng.uniform() for k in MIXTURE.get_params()} else: logger.info(f"Using starting point from {args.starting_point}") - starting_df = pd.read_csv(args.starting_point, index_col=0) # Use first column as index - starting_values = starting_df['value'].to_dict() + starting_df = pd.read_csv( + args.starting_point, index_col=0 + ) # Use first column as index + starting_values = starting_df["value"].to_dict() if args.multi_fit: datasets = [] - history_dir = params['sampling']['output_path'] + history_dir = params["sampling"]["output_path"] os.makedirs(history_dir, exist_ok=True) - for i in range(params['sampling']['n_bootstraps']): + for i in range(params["sampling"]["n_bootstraps"]): file_path = os.path.join(args.input, f"dataset_resample_{i}.csv") if os.path.exists(file_path): loaded_dataset = pd.read_csv(file_path, header=[0, 1, 2]) datasets.append(loaded_dataset) with ProcessPoolExecutor(max(1, os.cpu_count() - 2)) as executor: futures = [ - executor.submit(process_dataset, dataset, history_dir, starting_values, params, i) + executor.submit( + process_dataset, dataset, history_dir, starting_values, params, i + ) for i, dataset in enumerate(datasets) - ] + ] else: - inference_data = load_patient_data(args.input) - - mapping = params["model"].get("mapping", None) - if isinstance(MIXTURE.components[0], models.Unilateral): - MIXTURE.load_patient_data(inference_data, split_by= params["model"].get("split_by", ("tumor", "1", "subsite")), mapping=mapping) - assign_modalities(model=MIXTURE, config=params.get("inference_modalities", {})) - - else: - raise ValueError("Only Unilateral has been implemented so far") - MIXTURE.set_params(**starting_values) MIXTURE.normalize_mixture_coefs() - tolerance = params['model'].get('likelihood_tolerance', 0.01) - history_dir = params['fitting']['folder_path'] + tolerance = params["model"].get("likelihood_tolerance", 0.01) + history_dir = params["fitting"]["folder_path"] logger.info(f"Saving history to {history_dir}.") - params_history, likelihood_history = run_EM(tolerance = tolerance, history_dir = history_dir) + params_history, likelihood_history = run_EM( + tolerance=tolerance, history_dir=history_dir + ) + if __name__ == "__main__": parser = argparse.ArgumentParser(description=__doc__) diff --git a/lyscripts/mixture_sample_old.py b/lyscripts/mixture_sample_old.py index 11bb9cd..afb00fb 100644 --- a/lyscripts/mixture_sample_old.py +++ b/lyscripts/mixture_sample_old.py @@ -1,31 +1,27 @@ -""" -Learn the spread probabilities of the HMM for lymphatic tumor progression using +"""Learn the spread probabilities of the HMM for lymphatic tumor progression using the preprocessed data as input and the mixture model. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging -import os -from collections import namedtuple try: from multiprocess import Pool except ModuleNotFoundError: - from multiprocessing import Pool + pass from pathlib import Path import pandas as pd +from lymixture.em import expectation, sample_fixed_mixture, sample_model_params from lymph import models -from lymixture.em import sample_fixed_mixture, sample_model_params, expectation - from lyscripts.utils import ( + assign_modalities, create_mixture, load_patient_data, load_yaml_params, - to_numpy, - assign_modalities ) logger = logging.getLogger(__name__) @@ -51,45 +47,54 @@ def _add_arguments(parser: argparse.ArgumentParser): This is called by the parent module that is called via the command line. """ parser.add_argument( - "-m", "--mixture_coefs", type=Path, required=True, - help="File path of mixture coefficients" - ) - parser.add_argument( - "-p", "--params", default="./params.yaml", type=Path, - help="Path to parameter file." + "-m", + "--mixture_coefs", + type=Path, + required=True, + help="File path of mixture coefficients", ) parser.add_argument( - "-mp", "--model_params", type=Path, required = True, - help="File path of mixture coefficients" + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file.", ) parser.add_argument( - "--mode", type=str, default = "fixed_mixture", - help = "Mode of sampling. Use either 'fixed_mixture' or 'fixed_latent'" + "-mp", + "--model_params", + type=Path, + required=True, + help="File path of mixture coefficients", ) parser.add_argument( - "-o", "--output", type=Path, - help="Output path for samples" + "--mode", + type=str, + default="fixed_mixture", + help="Mode of sampling. Use either 'fixed_mixture' or 'fixed_latent'", ) + parser.add_argument("-o", "--output", type=Path, help="Output path for samples") parser.add_argument( - "-d", "--data", type=Path, required=True, - help="Path to the data file." + "-d", "--data", type=Path, required=True, help="Path to the data file." ) parser.add_argument( - "-c", "--continue_sampling", type=bool, default = False, - help="Continue sampling from previous run stored in the output backend" + "-c", + "--continue_sampling", + type=bool, + default=False, + help="Continue sampling from previous run stored in the output backend", ) - parser.set_defaults(run_main=main) MIXTURE = None + def main(args: argparse.Namespace) -> None: """Main function to sample parameters for a mixture model""" - params = load_yaml_params(args.params) - model_params = pd.read_csv(args.model_params,header = [0]) + model_params = pd.read_csv(args.model_params, header=[0]) inference_data = load_patient_data(args.data) param_dict = dict(model_params.iloc[-1]) # ugly, but necessary for pickling @@ -99,7 +104,11 @@ def main(args: argparse.Namespace) -> None: mapping = params["model"].get("mapping", None) if isinstance(MIXTURE.components[0], models.Unilateral): side = params["model"].get("side", "ipsi") - MIXTURE.load_patient_data(inference_data, split_by= params["model"].get("split_by", ("tumor", "1", "subsite")), mapping=mapping) + MIXTURE.load_patient_data( + inference_data, + split_by=params["model"].get("split_by", ("tumor", "1", "subsite")), + mapping=mapping, + ) assign_modalities(model=MIXTURE, config=params.get("inference_modalities", {})) else: @@ -108,10 +117,22 @@ def main(args: argparse.Namespace) -> None: MIXTURE.set_params(**param_dict) MIXTURE.set_resps(expectation(MIXTURE, param_dict)) if args.mode == "fixed_mixture": - backend, samples = sample_fixed_mixture(MIXTURE, steps = params["sampling"].get("steps",),filename = str(args.output)+"/fixed_mixture.hdf5", continue_sampling = args.continue_sampling) + backend, samples = sample_fixed_mixture( + MIXTURE, + steps=params["sampling"].get( + "steps", + ), + filename=str(args.output) + "/fixed_mixture.hdf5", + continue_sampling=args.continue_sampling, + ) elif args.mode == "fixed_latent": - backend, samples = sample_model_params(MIXTURE, steps = params["sampling"].get("steps"),filename = str(args.output)+"/fixed_latent.hdf5", continue_sampling = args.continue_sampling) - + backend, samples = sample_model_params( + MIXTURE, + steps=params["sampling"].get("steps"), + filename=str(args.output) + "/fixed_latent.hdf5", + continue_sampling=args.continue_sampling, + ) + if __name__ == "__main__": parser = argparse.ArgumentParser(description=__doc__) diff --git a/lyscripts/plot/__init__.py b/lyscripts/plot/__init__.py index 2f9a238..22343d1 100644 --- a/lyscripts/plot/__init__.py +++ b/lyscripts/plot/__init__.py @@ -1,12 +1,21 @@ -""" -Provide various plotting utilities for displaying results of e.g. the inference +"""Provide various plotting utilities for displaying results of e.g. the inference or prediction process. At the moment, three subcommands are grouped under :py:mod:`.plot`. """ + import argparse from pathlib import Path -from lyscripts.plot import corner, histograms, thermo_int, mixture_plot, simplex_plot, mixture_sampling_plotter +from lyscripts.plot import ( + corner, + histograms, + mixture_comp_uncertainty, + mixture_plot, + mixture_sampling_plotter, + simplex_plot, + thermo_int, +) + def _add_parser( subparsers: argparse._SubParsersAction, @@ -25,4 +34,9 @@ def _add_parser( thermo_int._add_parser(subparsers, help_formatter=parser.formatter_class) mixture_plot._add_parser(subparsers, help_formatter=parser.formatter_class) simplex_plot._add_parser(subparsers, help_formatter=parser.formatter_class) - mixture_sampling_plotter._add_parser(subparsers, help_formatter=parser.formatter_class) + mixture_sampling_plotter._add_parser( + subparsers, help_formatter=parser.formatter_class + ) + mixture_comp_uncertainty._add_parser( + subparsers, help_formatter=parser.formatter_class + ) diff --git a/lyscripts/plot/__main__.py b/lyscripts/plot/__main__.py index 81c8e26..9d89828 100644 --- a/lyscripts/plot/__main__.py +++ b/lyscripts/plot/__main__.py @@ -1,7 +1,15 @@ import argparse from lyscripts import RichDefaultHelpFormatter, exit_cli -from lyscripts.plot import corner, histograms, thermo_int, mixture_plot, simplex_plot, mixture_sampling_plotter +from lyscripts.plot import ( + corner, + histograms, + mixture_comp_uncertainty, + mixture_plot, + mixture_sampling_plotter, + simplex_plot, + thermo_int, +) # I need another __main__ guard here, because otherwise pdoc tries to run this if __name__ == "__main__": @@ -20,7 +28,12 @@ thermo_int._add_parser(subparsers, help_formatter=parser.formatter_class) mixture_plot._add_parser(subparsers, help_formatter=parser.formatter_class) simplex_plot._add_parser(subparsers, help_formatter=parser.formatter_class) - mixture_sampling_plotter._add_parser(subparsers, help_formatter=parser.formatter_class) + mixture_sampling_plotter._add_parser( + subparsers, help_formatter=parser.formatter_class + ) + mixture_comp_uncertainty._add_parser( + subparsers, help_formatter=parser.formatter_class + ) args = parser.parse_args() args.run_main(args) diff --git a/lyscripts/plot/corner.py b/lyscripts/plot/corner.py index cffd5d4..1134edf 100644 --- a/lyscripts/plot/corner.py +++ b/lyscripts/plot/corner.py @@ -1,5 +1,4 @@ -""" -Generate a corner plot of the drawn samples. +"""Generate a corner plot of the drawn samples. A corner plot is a combination of 1D and 2D marginals of probability distributions. The library I use for this is built on `matplotlib` and is called @@ -7,6 +6,7 @@ .. _corner: https://github.com/dfm/corner.py """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -37,17 +37,14 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" + parser.add_argument("model", type=Path, help="Path to model output files (HDF5).") + parser.add_argument("output", type=Path, help="Path to output corner plot (SVG).") parser.add_argument( - "model", type=Path, - help="Path to model output files (HDF5)." - ) - parser.add_argument( - "output", type=Path, - help="Path to output corner plot (SVG)." - ) - parser.add_argument( - "-p", "--params", default="./params.yaml", type=Path, - help="Path to parameter file" + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file", ) parser.set_defaults(run_main=main) diff --git a/lyscripts/plot/histograms.py b/lyscripts/plot/histograms.py index 5b202fc..f090221 100644 --- a/lyscripts/plot/histograms.py +++ b/lyscripts/plot/histograms.py @@ -1,6 +1,5 @@ -""" -Plot computed risks and prevalences into a beautiful histogram. -""" +"""Plot computed risks and prevalences into a beautiful histogram.""" + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -38,29 +37,27 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "input", type=Path, - help="File path of the computed risks or prevalences (HDF5)" - ) - parser.add_argument( - "output", type=Path, - help="Output path for the plot" + "input", type=Path, help="File path of the computed risks or prevalences (HDF5)" ) + parser.add_argument("output", type=Path, help="Output path for the plot") parser.add_argument( - "--names", nargs="+", - help="List of names of computed risks/prevalences to combine into one plot" - ) - parser.add_argument( - "--title", type=str, - help="Title of the plot" + "--names", + nargs="+", + help="List of names of computed risks/prevalences to combine into one plot", ) + parser.add_argument("--title", type=str, help="Title of the plot") parser.add_argument( - "--bins", default=60, type=int, - help="Number of bins to put the computed values into" + "--bins", + default=60, + type=int, + help="Number of bins to put the computed values into", ) parser.add_argument( - "--mplstyle", default="./.mplstyle", type=Path, - help="Path to the MPL stylesheet" + "--mplstyle", + default="./.mplstyle", + type=Path, + help="Path to the MPL stylesheet", ) parser.set_defaults(run_main=main) @@ -73,18 +70,22 @@ def main(args: argparse.Namespace): contents = [] for name in args.names: color = next(COLOR_CYCLE) - contents.append(Histogram.from_hdf5( - filename=args.input, - dataname=name, - color=color, - )) - logger.info(f"Added histogram {name} to figure") - try: - contents.append(BetaPosterior.from_hdf5( + contents.append( + Histogram.from_hdf5( filename=args.input, dataname=name, color=color, - )) + ) + ) + logger.info(f"Added histogram {name} to figure") + try: + contents.append( + BetaPosterior.from_hdf5( + filename=args.input, + dataname=name, + color=color, + ) + ) except KeyError: logger.warning(f"No observation data available for dataset {name}") else: @@ -95,7 +96,7 @@ def main(args: argparse.Namespace): axes=ax, contents=contents, hist_kwargs={"nbins": args.bins}, - percent_lims=(5., 5.) + percent_lims=(5.0, 5.0), ) ax.legend() logger.info("Drawn figure") diff --git a/lyscripts/plot/mixture_comp_uncertainty.py b/lyscripts/plot/mixture_comp_uncertainty.py new file mode 100644 index 0000000..e027e73 --- /dev/null +++ b/lyscripts/plot/mixture_comp_uncertainty.py @@ -0,0 +1,177 @@ +import argparse +import logging +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from lymph import models + +from lyscripts.plot.utils import COLORS, save_figure +from lyscripts.utils import create_mixture, load_patient_data, load_yaml_params + +logger = logging.getLogger(__name__) + + +def _add_parser( + subparsers: argparse._SubParsersAction, + help_formatter, +): + """Add an ``ArgumentParser`` to the subparsers action.""" + parser = subparsers.add_parser( + Path(__file__).name.replace(".py", ""), + description=__doc__, + help=__doc__, + formatter_class=help_formatter, + ) + _add_arguments(parser) + + +def _add_arguments(parser: argparse.ArgumentParser): + """Add arguments to the parser.""" + parser.add_argument( + "--input", type=Path, help="File path of resampled optimal parameters" + ) + parser.add_argument( + "--optimal_params", + type=Path, + help="File path of the initial optimal parameters", + ) + parser.add_argument("--output", type=Path, help="Output path for the plot") + parser.add_argument( + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file.", + ) + + parser.set_defaults(run_main=main) + + +def get_stats(keys, best_params_list, optima=None, percentile_range=0.68): + means = [] + lower_errors = [] + upper_errors = [] + for index, key in enumerate(keys): + values = np.array([bp[key] for bp in best_params_list]) + values = sorted(values) + n_values = len(values) + n_within_range = int(np.ceil(percentile_range * n_values)) + if optima is not None: + mean = optima[index] + else: + mean = np.mean(values) + distances = [abs(v - mean) for v in values] + sorted_indices = np.argsort(distances) + closest_indices = sorted_indices[:n_within_range] + # Extract the 68% range + percentile_values = [values[i] for i in closest_indices] + percentile_values.sort() + low_err = max(mean - min(percentile_values), 0) + high_err = max(max(percentile_values) - mean, 0) + means.append(mean) + lower_errors.append(low_err) + upper_errors.append(high_err) + return means, [lower_errors, upper_errors] + + +def plot_mixture_comp_uncertainty(initial_params_path, resampling_path): + all_keys = MIXTURE.get_params().keys() + initial_params = pd.read_csv(initial_params_path, index_col=0)["value"].to_dict() + # List all files in the optimal_params_path directory + optimal_params_files = os.listdir(resampling_path) + + # Open each file and load as a dictionary + best_params_list = [] + for fname in optimal_params_files: + fpath = os.path.join(resampling_path, fname) + params = pd.read_csv(fpath, index_col=0)["value"].to_dict() + best_params_list.append(params) + + num_components = len(MIXTURE.components) + mixture_keys = [key for key in all_keys if "coef" in key] + # Dynamically generate keys for each component based on num_components + keys_per_component = [ + [key for key in mixture_keys if f"{i}_C" in key] for i in range(num_components) + ] + + # Get the optimal parameters as means + MIXTURE.set_params(**initial_params) + optima_list = [] + mixture_coefs = MIXTURE.get_mixture_coefs() + for i in range(num_components): + optima_list.append(mixture_coefs.loc[i].to_list()) + + # Compute stats for each component + means_list = [] + yerr_list = [] + for keys, optima in zip(keys_per_component, optima_list, strict=False): + means, yerr = get_stats(keys, best_params_list, optima) + means_list.append(means) + yerr_list.append(yerr) + + color_list = [] + for lister in optima_list: + extreme_subsite = list(MIXTURE.subgroups.keys())[np.argmax(lister)] + if extreme_subsite in ["C02", "C03", "C04", "C05", "C06"]: + color_list.append(COLORS["blue"]) + elif extreme_subsite in ["C01", "C09", "C10"]: + color_list.append(COLORS["green"]) + elif extreme_subsite in ["C12", "C13"]: + color_list.append(COLORS["red"]) + elif extreme_subsite in ["C32.0", "C32.1", "C32.2"]: + color_list.append(COLORS["orange"]) + # Unpack for plotting for a dynamic number of components + plt.figure(figsize=(10, 6)) + x = np.arange(len(keys_per_component[0])) + for idx, (means, yerr, keys, color) in enumerate( + zip(means_list, yerr_list, keys_per_component, color_list, strict=False) + ): + plt.errorbar( + x, + means, + yerr=yerr, + fmt="o", + color=color, + ecolor=color, + capsize=5, + markersize=8, + label=f"component {idx}", + ) + + plt.xticks(x, MIXTURE.subgroups.keys(), rotation=0, ha="right", fontsize=10) + plt.legend(fontsize=10) + plt.ylabel("Values", fontsize=12) + plt.title("mixture Values Assignment", fontsize=14) + plt.grid(axis="y", linestyle="--", alpha=0.7) + plt.tight_layout() + return plt + + +def main(args: argparse.Namespace): + params = load_yaml_params(args.params) + global MIXTURE + MIXTURE = create_mixture(params) + inference_data = load_patient_data(params["general"]["data"]) + + mapping = params["model"].get("mapping", None) + if isinstance(MIXTURE.components[0], models.Unilateral): + MIXTURE.load_patient_data( + inference_data, + split_by=params["model"].get("split_by", ("tumor", "1", "subsite")), + mapping=mapping, + ) + + plot = plot_mixture_comp_uncertainty(args.optimal_params, args.input) + save_figure(args.output, plot, formats=["png", "svg"]) + logger.info(f"Mixture component uncertainty plot saved to {args.output}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + _add_arguments(parser) + + args = parser.parse_args() + args.run_main(args) diff --git a/lyscripts/plot/mixture_plot.py b/lyscripts/plot/mixture_plot.py index 1e92456..c574eff 100644 --- a/lyscripts/plot/mixture_plot.py +++ b/lyscripts/plot/mixture_plot.py @@ -2,14 +2,15 @@ import logging from pathlib import Path -import pandas as pd import matplotlib.pyplot as plt +import pandas as pd from matplotlib.colors import LinearSegmentedColormap -from lyscripts.plot.utils import COLORS, save_figure, SUBSITE_COLORS +from lyscripts.plot.utils import COLORS, SUBSITE_COLORS, save_figure logger = logging.getLogger(__name__) + def _add_parser( subparsers: argparse._SubParsersAction, help_formatter, @@ -26,19 +27,16 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" - parser.add_argument( - "--input", type=Path, - help="File path of mixture coefficients" - ) - parser.add_argument( - "--output", type=Path, - help="Output path for the plot" - ) + parser.add_argument("--input", type=Path, help="File path of mixture coefficients") + parser.add_argument("--output", type=Path, help="Output path for the plot") parser.set_defaults(run_main=main) + def main(args: argparse.Namespace): - tmp = LinearSegmentedColormap.from_list("tmp", [COLORS['green'], COLORS['red']], N=128) + tmp = LinearSegmentedColormap.from_list( + "tmp", [COLORS["green"], COLORS["red"]], N=128 + ) mixture_df = pd.read_csv(args.input) filtered_keys = [key for key in SUBSITE_COLORS if key in mixture_df.columns] mixture_df = mixture_df[filtered_keys] @@ -50,24 +48,30 @@ def main(args: argparse.Namespace): fig, ax = plt.subplots(figsize=(8, 10)) # Display the rotated matrix using imshow - cax = ax.imshow(matrix_rotated.values, cmap=tmp, origin='upper') + cax = ax.imshow(matrix_rotated.values, cmap=tmp, origin="upper") # Loop over the data and create text annotations for i in range(matrix_rotated.shape[0]): # Rows (previously columns) for j in range(matrix_rotated.shape[1]): # Columns (previously rows) value = matrix_rotated.iloc[i, j] - ax.text(j, i, f"{value:.2f}", ha="center", va="center", - color="white", fontsize=12) - + ax.text( + j, + i, + f"{value:.2f}", + ha="center", + va="center", + color="white", + fontsize=12, + ) # Optional: Set axis labels and title ax.set_xticks(range(matrix_rotated.shape[1])) - ax.set_xticklabels(mixture_df.index, fontsize = 12) # Original row labels + ax.set_xticklabels(mixture_df.index, fontsize=12) # Original row labels ax.set_yticks(range(matrix_rotated.shape[0])) - ax.set_yticklabels(mixture_df.columns, fontsize = 12) # Original column labels - ax.set_title("Mixture Coefficients per subsite", fontsize = 16) + ax.set_yticklabels(mixture_df.columns, fontsize=12) # Original column labels + ax.set_title("Mixture Coefficients per subsite", fontsize=16) save_figure(args.output, fig, formats=["png", "svg"]) - logger.info(f"Mixture parameter matrix saved") + logger.info("Mixture parameter matrix saved") if __name__ == "__main__": @@ -76,4 +80,3 @@ def main(args: argparse.Namespace): args = parser.parse_args() args.run_main(args) - diff --git a/lyscripts/plot/mixture_sampling_plotter.py b/lyscripts/plot/mixture_sampling_plotter.py index 5fd7156..1e7d232 100644 --- a/lyscripts/plot/mixture_sampling_plotter.py +++ b/lyscripts/plot/mixture_sampling_plotter.py @@ -1,27 +1,28 @@ import argparse +import json import logging from pathlib import Path -from cycler import cycler -import scipy as sp -import numpy as np import emcee -from lymph import models import matplotlib.pyplot as plt +import numpy as np import pandas as pd -import json -from lyscripts.plot.utils import COLORS, save_figure +import scipy as sp +from cycler import cycler from lymixture.em import _set_params, expectation +from lymph import models +from lyscripts.plot.utils import COLORS, save_figure from lyscripts.utils import ( + assign_modalities, create_mixture, load_patient_data, load_yaml_params, - assign_modalities ) logger = logging.getLogger(__name__) + def _add_parser( subparsers: argparse._SubParsersAction, help_formatter, @@ -39,100 +40,113 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "--input", type=Path, - help="File path with emcee backend of samples" - ) - parser.add_argument( - "--output", type=Path, - help="Output path for the plot" + "--input", type=Path, help="File path with emcee backend of samples" ) + parser.add_argument("--output", type=Path, help="Output path for the plot") parser.add_argument( - "-m", "--mode", type=str, default = "fixed_mixture", - help = "Mode of sampling. Use either 'fixed_mixture' or 'fixed_latent'" + "-m", + "--mode", + type=str, + default="fixed_mixture", + help="Mode of sampling. Use either 'fixed_mixture' or 'fixed_latent'", ) parser.add_argument( - "-s", "--size", type=int, default = 200, - help = "Number of samples to be used for plotting" + "-s", + "--size", + type=int, + default=200, + help="Number of samples to be used for plotting", ) parser.add_argument( - "-p", "--params", default="./params.yaml", type=Path, - help="Path to parameter file." + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file.", ) parser.add_argument( - "-d", "--data", type=Path, required=True, - help="Path to patient data." + "-d", "--data", type=Path, required=True, help="Path to patient data." ) parser.add_argument( - "-mp", "--model_params", type=Path, required = True, - help="File path of mixture coefficients" + "-mp", + "--model_params", + type=Path, + required=True, + help="File path of mixture coefficients", ) parser.set_defaults(run_main=main) - -def multiple_plotter(dataset, risk_dictionary_extended, subsite, stage = ''): - hist_cycl = ( - cycler(histtype=["stepfilled", "step"]) - * cycler(color=list(COLORS.values())) - ) - line_cycl = ( - cycler(linestyle=["-", "--"]) - * cycler(color=list(COLORS.values())) + + +def multiple_plotter(dataset, risk_dictionary_extended, subsite, stage=""): + hist_cycl = cycler(histtype=["stepfilled", "step"]) * cycler( + color=list(COLORS.values()) ) + line_cycl = cycler(linestyle=["-", "--"]) * cycler(color=list(COLORS.values())) dataset_staging = dataset.copy() - dataset_staging['tumor','1','t_stage'] = dataset_staging['tumor','1','t_stage'].replace([0,1,2], 'early') - dataset_staging['tumor','1','t_stage'] = dataset_staging['tumor','1','t_stage'].replace([3,4], 'late') - if stage == 'early' or stage == 'late': - data_selected = dataset_staging.loc[(dataset_staging['tumor']['1']['subsite'] == subsite) & (dataset_staging['tumor']['1']['t_stage'] == stage)] + dataset_staging["tumor", "1", "t_stage"] = dataset_staging[ + "tumor", "1", "t_stage" + ].replace([0, 1, 2], "early") + dataset_staging["tumor", "1", "t_stage"] = dataset_staging[ + "tumor", "1", "t_stage" + ].replace([3, 4], "late") + if stage == "early" or stage == "late": + data_selected = dataset_staging.loc[ + (dataset_staging["tumor"]["1"]["subsite"] == subsite) + & (dataset_staging["tumor"]["1"]["t_stage"] == stage) + ] else: - data_selected = dataset_staging.loc[dataset_staging['tumor']['1']['subsite'] == subsite] + data_selected = dataset_staging.loc[ + dataset_staging["tumor"]["1"]["subsite"] == subsite + ] min_value = 0 max_value = 100 - prevalence = {} + prevalence = {} number_of_patients = {} risks = {} risk_dictionary = risk_dictionary_extended[subsite] for key in risk_dictionary.keys(): - prevalence[key] = (data_selected['max_llh']['ipsi'][key] == True).sum() - number_of_patients[key] = data_selected['max_llh']['ipsi'][key].notna().sum() - risks[key] = np.array(risk_dictionary[key])*100 - + prevalence[key] = (data_selected["max_llh"]["ipsi"][key] == True).sum() + number_of_patients[key] = data_selected["max_llh"]["ipsi"][key].notna().sum() + risks[key] = np.array(risk_dictionary[key]) * 100 + num_matches = [prevalence[key] for key in risk_dictionary.keys()] num_totals = [number_of_patients[key] for key in risk_dictionary.keys()] values = [risks[key] for key in risk_dictionary.keys()] hist_kwargs = { - "bins": np.linspace(min_value, max_value, 80), - "density": True, - "alpha": 0.6, - "linewidth": 2., - } - fig, ax = plt.subplots(figsize=(12,4)) + "bins": np.linspace(min_value, max_value, 80), + "density": True, + "alpha": 0.6, + "linewidth": 2.0, + } + fig, ax = plt.subplots(figsize=(12, 4)) x = np.linspace(min_value, max_value, 200) - zipper = zip(values, risk_dictionary.keys(), num_matches, num_totals, hist_cycl, line_cycl) + zipper = zip( + values, + risk_dictionary.keys(), + num_matches, + num_totals, + hist_cycl, + line_cycl, + strict=False, + ) for vals, label, a, n, hstyle, lstyle in zipper: - ax.hist( - vals, - label=label, - **hist_kwargs, - **hstyle - ) + ax.hist(vals, label=label, **hist_kwargs, **hstyle) if not np.isnan(a): - post = sp.stats.beta.pdf(x / 100., a+1, n-a+1) / 100. + post = sp.stats.beta.pdf(x / 100.0, a + 1, n - a + 1) / 100.0 ax.plot(x, post, label=f"{int(a)}/{int(n)}", **lstyle) ax.legend() ax.set_xlabel("probability [%]") - fig.suptitle(f"Risk distributions {stage} for subsite {subsite}",fontsize = 16) + fig.suptitle(f"Risk distributions {stage} for subsite {subsite}", fontsize=16) return fig - -def multiple_plotter_component(risk_dictionary_extended, component, stage = ''): - hist_cycl = ( - cycler(histtype=["stepfilled", "step"]) - * cycler(color=list(COLORS.values())) - ) - line_cycl = ( - cycler(linestyle=["-", "--"]) - * cycler(color=list(COLORS.values())) + + +def multiple_plotter_component(risk_dictionary_extended, component, stage=""): + hist_cycl = cycler(histtype=["stepfilled", "step"]) * cycler( + color=list(COLORS.values()) ) + line_cycl = cycler(linestyle=["-", "--"]) * cycler(color=list(COLORS.values())) min_value = 0 max_value = 100 prevalence = {} @@ -140,38 +154,33 @@ def multiple_plotter_component(risk_dictionary_extended, component, stage = ''): risks = {} risk_dictionary = risk_dictionary_extended[component] for key in risk_dictionary.keys(): - risks[key] = np.array(risk_dictionary[key])*100 - + risks[key] = np.array(risk_dictionary[key]) * 100 + values = [risks[key] for key in risk_dictionary.keys()] hist_kwargs = { - "bins": np.linspace(min_value, max_value, 80), - "density": True, - "alpha": 0.6, - "linewidth": 2., - } - fig, ax = plt.subplots(figsize=(12,4)) + "bins": np.linspace(min_value, max_value, 80), + "density": True, + "alpha": 0.6, + "linewidth": 2.0, + } + fig, ax = plt.subplots(figsize=(12, 4)) x = np.linspace(min_value, max_value, 200) - zipper = zip(values, risk_dictionary.keys(), hist_cycl) + zipper = zip(values, risk_dictionary.keys(), hist_cycl, strict=False) for vals, label, hstyle in zipper: - ax.hist( - vals, - label=label, - **hist_kwargs, - **hstyle - ) + ax.hist(vals, label=label, **hist_kwargs, **hstyle) ax.legend() ax.set_xlabel("probability [%]") - fig.suptitle(f"Risk distributions {stage} for component {component}",fontsize = 16) + fig.suptitle(f"Risk distributions {stage} for component {component}", fontsize=16) return fig - + def main(args: argparse.Namespace): params = load_yaml_params(args.params) inference_data = load_patient_data(args.data) backend = emcee.backends.HDFBackend(args.input) - samples = backend.get_chain(flat = True) - model_params = pd.read_csv(args.model_params,header = [0]) + samples = backend.get_chain(flat=True) + model_params = pd.read_csv(args.model_params, header=[0]) # ugly, but necessary for pickling global MIXTURE @@ -180,7 +189,11 @@ def main(args: argparse.Namespace): mapping = params["model"].get("mapping", None) if isinstance(MIXTURE.components[0], models.Unilateral): side = params["model"].get("side", "ipsi") - MIXTURE.load_patient_data(inference_data, split_by= params["model"].get("split_by", ("tumor", "1", "subsite")), mapping=mapping) + MIXTURE.load_patient_data( + inference_data, + split_by=params["model"].get("split_by", ("tumor", "1", "subsite")), + mapping=mapping, + ) assign_modalities(model=MIXTURE, config=params.get("inference_modalities", {})) else: @@ -198,26 +211,25 @@ def main(args: argparse.Namespace): for component in component_list: component_dictionary_early_full_sampling[str(component)] = { - lnl: [] for lnl in lnls + lnl: [] for lnl in lnls } component_dictionary_late_full_sampling[str(component)] = { lnl: [] for lnl in lnls } - + subsite_dictionary_early_full_sampling = {} subsite_dictionary_late_full_sampling = {} subsite_list = list(MIXTURE.subgroups.keys()) for subsite in subsite_list: subsite_dictionary_early_full_sampling[subsite] = { - lnl: [] for lnl in lnls # Dynamically generate keys from the lnls list - } - subsite_dictionary_late_full_sampling[subsite] = { - lnl: [] for lnl in lnls + lnl: [] + for lnl in lnls # Dynamically generate keys from the lnls list } + subsite_dictionary_late_full_sampling[subsite] = {lnl: [] for lnl in lnls} involvement_dict = {lnl: {lnl: True} for lnl in lnls} - samples_thinned = samples[::int(np.round(len(samples)/args.size,0))] + samples_thinned = samples[:: int(np.round(len(samples) / args.size, 0))] for round, sample in enumerate(samples_thinned): if args.mode == "fixed_latent": MIXTURE.set_params(*sample) @@ -225,28 +237,67 @@ def main(args: argparse.Namespace): _set_params(MIXTURE, sample) else: raise ValueError("Invalid mode") - component_dictionary_early_full_sampling['0']['II'].append(MIXTURE.components[0].risk(involvement = involvement_dict['II'],t_stage = 'early')) + component_dictionary_early_full_sampling["0"]["II"].append( + MIXTURE.components[0].risk( + involvement=involvement_dict["II"], t_stage="early" + ) + ) for component in component_list: for lnl in lnls: - component_dictionary_early_full_sampling[str(component)][lnl].append(MIXTURE.components[component].risk(involvement = involvement_dict[lnl],t_stage = 'early')) - component_dictionary_late_full_sampling[str(component)][lnl].append(MIXTURE.components[component].risk(involvement = involvement_dict[lnl],t_stage = 'late')) - + component_dictionary_early_full_sampling[str(component)][lnl].append( + MIXTURE.components[component].risk( + involvement=involvement_dict[lnl], t_stage="early" + ) + ) + component_dictionary_late_full_sampling[str(component)][lnl].append( + MIXTURE.components[component].risk( + involvement=involvement_dict[lnl], t_stage="late" + ) + ) + for subsite in subsite_list: for lnl in lnls: - subsite_dictionary_early_full_sampling[subsite][lnl].append(MIXTURE.risk(subgroup = subsite, involvement = involvement_dict[lnl],t_stage = 'early')) - subsite_dictionary_late_full_sampling[subsite][lnl].append(MIXTURE.risk(subgroup = subsite, involvement = involvement_dict[lnl],t_stage = 'late')) - print(round, ' done') + subsite_dictionary_early_full_sampling[subsite][lnl].append( + MIXTURE.risk( + subgroup=subsite, + involvement=involvement_dict[lnl], + t_stage="early", + ) + ) + subsite_dictionary_late_full_sampling[subsite][lnl].append( + MIXTURE.risk( + subgroup=subsite, + involvement=involvement_dict[lnl], + t_stage="late", + ) + ) + print(round, " done") for component in component_list: - fig = multiple_plotter_component(component_dictionary_early_full_sampling, str(component), stage = 'early') - save_figure(args.output/f"component_{component}_early", fig, formats = ['png','svg']) - fig = multiple_plotter_component(component_dictionary_late_full_sampling, str(component), stage = 'late') - save_figure(args.output/f"component_{component}_late", fig, formats = ['png','svg']) + fig = multiple_plotter_component( + component_dictionary_early_full_sampling, str(component), stage="early" + ) + save_figure( + args.output / f"component_{component}_early", fig, formats=["png", "svg"] + ) + fig = multiple_plotter_component( + component_dictionary_late_full_sampling, str(component), stage="late" + ) + save_figure( + args.output / f"component_{component}_late", fig, formats=["png", "svg"] + ) for subsite in subsite_list: - fig = multiple_plotter(inference_data, subsite_dictionary_early_full_sampling, subsite, stage = 'early') - save_figure(args.output/f"{subsite}_early", fig, formats = ['png','svg']) - fig = multiple_plotter(inference_data,subsite_dictionary_late_full_sampling, subsite, stage = 'late') - save_figure(args.output/f"{subsite}_late", fig, formats = ['png','svg']) + fig = multiple_plotter( + inference_data, + subsite_dictionary_early_full_sampling, + subsite, + stage="early", + ) + save_figure(args.output / f"{subsite}_early", fig, formats=["png", "svg"]) + fig = multiple_plotter( + inference_data, subsite_dictionary_late_full_sampling, subsite, stage="late" + ) + save_figure(args.output / f"{subsite}_late", fig, formats=["png", "svg"]) # Save dictionary to a JSON file with open("subsite_early.json", "w") as file: @@ -257,10 +308,11 @@ def main(args: argparse.Namespace): json.dump(component_dictionary_early_full_sampling, file) with open("component_late.json", "w") as file: json.dump(component_dictionary_late_full_sampling, file) + + if __name__ == "__main__": parser = argparse.ArgumentParser(description=__doc__) _add_arguments(parser) args = parser.parse_args() args.run_main(args) - diff --git a/lyscripts/plot/simplex_plot.py b/lyscripts/plot/simplex_plot.py index 3db6a91..a67b5f2 100644 --- a/lyscripts/plot/simplex_plot.py +++ b/lyscripts/plot/simplex_plot.py @@ -1,24 +1,26 @@ -""" -Visualize the component assignments of the trained mixture model. -""" - +"""Visualize the component assignments of the trained mixture model.""" import argparse import logging -import numpy as np from pathlib import Path -import pandas as pd import matplotlib.pyplot as plt +import numpy as np +import pandas as pd from matplotlib.lines import Line2D +from matplotlib.ticker import StrMethodFormatter - -from lyscripts.plot.utils import COLORS, SUBSITE_COLORS, save_figure, add_perpendicular_ticks, p_to_xyz, add_perpendicular_crosses_3d +from lyscripts.plot.utils import ( + COLORS, + SUBSITE_COLORS, + add_perpendicular_ticks, + save_figure, +) from lyscripts.utils import load_yaml_params -from matplotlib.ticker import StrMethodFormatter logger = logging.getLogger(__name__) + def _add_parser( subparsers: argparse._SubParsersAction, help_formatter, @@ -36,138 +38,204 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "--input", type=Path, - help="File path of the mixture coefficients" + "--input", type=Path, help="File path of the mixture coefficients" ) + parser.add_argument("--output", type=Path, help="Output path for the plot") parser.add_argument( - "--output", type=Path, - help="Output path for the plot" - ) - parser.add_argument( - "-p", "--params", default="./params.yaml", type=Path, - help="Path to parameter file." + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file.", ) parser.set_defaults(run_main=main) -# def plot_2d_simplex(mixture_df, data): -# _, bottom_ax = plt.subplots() -# subsites = list(mixture_df.columns) -# cluster_x = mixture_df.loc[0] -# cluster_y = [0. for _ in subsites] -# annotations = [f"{label}\n({num})" for label, num in num_patients.items()] -# bottom_ax.scatter( -# cluster_x, cluster_y, -# s=[num for num in num_patients.values()], -# c=list(generate_location_colors(subsites)), -# alpha=0.7, -# linewidths=0., -# zorder=10, -# ) - -# sorted_idx = cluster_components.argsort() -# sorted_x = cluster_components[sorted_idx] -# sorted_annotations = [annotations[i] for i in sorted_idx] -# sorted_num = [list(num_patients.values())[i] for i in sorted_idx] -# for i, (x, num, annotation) in enumerate(zip(sorted_x, sorted_num, sorted_annotations)): -# bottom_ax.annotate( -# annotation, -# # sqrt, because marker's area grows linearly with patient num, not radius -# xy=(x, np.sqrt(0.0000003 * num) * (- 1)**i), -# xytext=(x, 0.025 * (- 1)**i), -# ha="center", -# va="bottom" if i % 2 == 0 else "top", -# fontsize="small", -# arrowprops={ -# "arrowstyle": "-", -# "color": USZ["gray"], -# "linewidth": 1., -# } -# ) - -# bottom_ax.set_xlabel("assignment to component A") -# bottom_ax.xaxis.set_major_formatter(StrMethodFormatter("{x:.0%}")) -# top_ax = bottom_ax.secondary_xaxis( -# location="top", -# functions=(lambda x: 1. - x, lambda x: 1. - x), -# ) -# top_ax.set_xlabel("assignment to component B") -# top_ax.xaxis.set_major_formatter(StrMethodFormatter("{x:.0%}")) -# bottom_ax.set_yticks([]) -# bottom_ax.grid(axis="x", alpha=0.5, color=USZ["gray"], linestyle=":") -# plt.savefig(args.output, bbox_inches="tight", dpi=300) - -def plot_2d_simplex(mixture_df, data): - -def plot_3d_simplex(mixture_df, data, output, component_names = False): +def plot_2d_simplex(mixture_df, data, output, component_names=False): + data = pd.read_csv(data, header=[0, 1, 2]) + subsites = list(mixture_df.columns) + colors_ordered = [SUBSITE_COLORS[subsite] for subsite in subsites] + num_patients = dict(data["tumor", "1", "subsite"].value_counts()) + annotations = [f"{key}\n({num_patients[key]})" for key in mixture_df.columns] + sizes = [num_patients[key] for key in mixture_df.columns] + + fig, bottom_ax = plt.subplots(figsize=(8, 2.5)) + bottom_ax.scatter( + mixture_df.loc[0], + np.zeros(len(mixture_df.loc[0])), + s=sizes, + c=colors_ordered, + alpha=0.7, + linewidths=0.0, + zorder=10, + ) + + # Sort annotations by x-position (left to right) + x_positions = mixture_df.loc[0].values + sorted_indices = np.argsort(x_positions) + + for i, idx in enumerate(sorted_indices): + x = x_positions[idx] + num = list(num_patients.values())[idx] + annotation = annotations[idx] + + offset = 0.025 * (-1) ** i + bottom_ax.annotate( + annotation, + xy=(x, np.sqrt(0.0000003 * num) * (-1) ** i), + xytext=(x, offset), + ha="center", + va="bottom" if i % 2 == 0 else "top", + fontsize="small", + arrowprops={ + "arrowstyle": "-", + "color": COLORS["gray"], + "linewidth": 1.0, + }, + ) + + # X-axis formatting + bottom_ax.set_xlabel("assignment to component 1") + bottom_ax.xaxis.set_major_formatter(StrMethodFormatter("{x:.0%}")) + + # Add secondary top axis + top_ax = bottom_ax.secondary_xaxis( + "top", functions=(lambda x: 1.0 - x, lambda x: 1.0 - x) + ) + top_ax.set_xlabel("assignment to component 2") + top_ax.xaxis.set_major_formatter(StrMethodFormatter("{x:.0%}")) + + # Clean up + bottom_ax.set_yticks([]) + bottom_ax.grid(axis="x", alpha=0.5, color=COLORS["gray"], linestyle=":") + + # Adjust layout to prevent clipping of labels + plt.tight_layout() + + save_figure(output, fig, formats=["png", "svg"]) + logger.info("Simplex plot saved") + + +def plot_3d_simplex(mixture_df, data, output, component_names=False): data = pd.read_csv(data, header=[0, 1, 2]) subsites = list(mixture_df.columns) colors_ordered = [SUBSITE_COLORS[subsite] for subsite in subsites] # Define the plane's normal vector - normal_vector = np.array([1,1,1])/np.sqrt(3) + normal_vector = np.array([1, 1, 1]) / np.sqrt(3) - v1 = np.array([1,-1,0])/np.sqrt(2) + v1 = np.array([1, -1, 0]) / np.sqrt(2) # Calculate the second orthogonal vector using the cross product - v2 = np.cross(normal_vector, v1) *-1 + v2 = np.cross(normal_vector, v1) * -1 # Project the point onto the new coordinate system origin = np.array([0, 1, 0]) - x_origin = origin @ v1 + x_origin = origin @ v1 y_origin = origin @ v2 x_vals = mixture_df.T @ v1 - x_origin y_vals = mixture_df.T @ v2 - y_origin - extremes = np.array([[1,0,0], - [0,1,0], - [0,0,1]]) + extremes = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) extremes_x = extremes @ v1 - x_origin extremes_y = extremes @ v2 - y_origin # Plot the point in 2D import matplotlib.pyplot as plt - odered_value_counts = data['tumor']['1']['subsite'].value_counts()[subsites] + odered_value_counts = data["tumor"]["1"]["subsite"].value_counts()[subsites] sizes = odered_value_counts * 3 fig, ax = plt.subplots(figsize=(8, 6.8)) # Plot the points with varying sizes and colors for i in range(len(x_vals)): - ax.scatter(x_vals[i], y_vals[i], s=sizes[i], color=colors_ordered[i], label=subsites[i]) - ax.text(x_vals[i], y_vals[i], subsites[i], fontsize=9, ha='center', va='center') + ax.scatter( + x_vals[i], y_vals[i], s=sizes[i], color=colors_ordered[i], label=subsites[i] + ) + ax.text(x_vals[i], y_vals[i], subsites[i], fontsize=9, ha="center", va="center") legend_text = [] for index in range(len(subsites)): - legend_text.append(subsites[index] + ', ' + str(odered_value_counts[index]) + ' patients') - - legend_elements = [Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=subsite) - for color, subsite in zip(colors_ordered, legend_text)] + legend_text.append( + subsites[index] + ", " + str(odered_value_counts[index]) + " patients" + ) + + legend_elements = [ + Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor=color, + markersize=10, + label=subsite, + ) + for color, subsite in zip(colors_ordered, legend_text, strict=False) + ] # Add a legend with fixed dot sizes - ax.legend(handles=legend_elements, loc='upper right', title='Subsites', fontsize='small') + ax.legend( + handles=legend_elements, loc="upper right", title="Subsites", fontsize="small" + ) # Connect the points - ax.plot(extremes_x, extremes_y, color='black', alpha=0.5) + ax.plot(extremes_x, extremes_y, color="black", alpha=0.5) # Close the triangle by connecting the last point to the first - ax.plot([extremes_x[-1], extremes_x[0]], [extremes_y[-1], extremes_y[0]], color='black', alpha=0.5) + ax.plot( + [extremes_x[-1], extremes_x[0]], + [extremes_y[-1], extremes_y[0]], + color="black", + alpha=0.5, + ) # Calculate midpoints of each side of the triangle - midpoints_x = (extremes_x[0] + extremes_x[1]) / 2, (extremes_x[1] + extremes_x[2]) / 2, (extremes_x[2] + extremes_x[0]) / 2 - midpoints_y = (extremes_y[0] + extremes_y[1]) / 2, (extremes_y[1] + extremes_y[2]) / 2, (extremes_y[2] + extremes_y[0]) / 2 + midpoints_x = ( + (extremes_x[0] + extremes_x[1]) / 2, + (extremes_x[1] + extremes_x[2]) / 2, + (extremes_x[2] + extremes_x[0]) / 2, + ) + midpoints_y = ( + (extremes_y[0] + extremes_y[1]) / 2, + (extremes_y[1] + extremes_y[2]) / 2, + (extremes_y[2] + extremes_y[0]) / 2, + ) # Draw lines from each vertex to the midpoint of the opposite side - ax.plot([extremes_x[0], midpoints_x[1]], [extremes_y[0], midpoints_y[1]], color='gray', linestyle='--', linewidth=1) - ax.plot([extremes_x[1], midpoints_x[2]], [extremes_y[1], midpoints_y[2]], color='gray', linestyle='--', linewidth=1) - ax.plot([extremes_x[2], midpoints_x[0]], [extremes_y[2], midpoints_y[0]], color='gray', linestyle='--', linewidth=1) + ax.plot( + [extremes_x[0], midpoints_x[1]], + [extremes_y[0], midpoints_y[1]], + color="gray", + linestyle="--", + linewidth=1, + ) + ax.plot( + [extremes_x[1], midpoints_x[2]], + [extremes_y[1], midpoints_y[2]], + color="gray", + linestyle="--", + linewidth=1, + ) + ax.plot( + [extremes_x[2], midpoints_x[0]], + [extremes_y[2], midpoints_y[0]], + color="gray", + linestyle="--", + linewidth=1, + ) # Add perpendicular ticks to each line with adjusted length - add_perpendicular_ticks(extremes_x[0], extremes_y[0], midpoints_x[1], midpoints_y[1]) - add_perpendicular_ticks(extremes_x[1], extremes_y[1], midpoints_x[2], midpoints_y[2]) - add_perpendicular_ticks(extremes_x[2], extremes_y[2], midpoints_x[0], midpoints_y[0]) + add_perpendicular_ticks( + extremes_x[0], extremes_y[0], midpoints_x[1], midpoints_y[1] + ) + add_perpendicular_ticks( + extremes_x[1], extremes_y[1], midpoints_x[2], midpoints_y[2] + ) + add_perpendicular_ticks( + extremes_x[2], extremes_y[2], midpoints_x[0], midpoints_y[0] + ) # Scaling factor to move the text farther from the vertices scaling_factor = 1.1 @@ -182,38 +250,67 @@ def plot_3d_simplex(mixture_df, data, output, component_names = False): if component_names: # Plot the text labels farther away from the triangle - if 'C32.0' in subsites: - component_larynx = mixture_df['C32.0'].argmax() - ax.text(scaled_extremes_x[component_larynx], scaled_extremes_y[component_larynx], "Larynx like", - fontsize=10, ha='right', va='top', c=COLORS['orange']) - - if 'C01' in subsites: - component_oropharynx = mixture_df['C01'].argmax() - ax.text(scaled_extremes_x[component_oropharynx], scaled_extremes_y[component_oropharynx], "Oropharynx like", - fontsize=10, ha='left', va='top', c=COLORS['green']) - - if 'C03' in subsites: - component_oral_cavity = mixture_df['C03'].argmax() - ax.text(scaled_extremes_x[component_oral_cavity], scaled_extremes_y[component_oral_cavity], "Oral cavity like", - fontsize=10, ha='left', va='top', c=COLORS['blue']) - - if 'C13' in subsites: - component_hypopharynx = mixture_df['C13'].argmax() - ax.text(scaled_extremes_x[component_hypopharynx], scaled_extremes_y[component_hypopharynx], "Hypopharynx like", - fontsize=10, ha='left', va='top', c=COLORS['red']) + if "C32.0" in subsites: + component_larynx = mixture_df["C32.0"].argmax() + ax.text( + scaled_extremes_x[component_larynx], + scaled_extremes_y[component_larynx], + "Larynx like", + fontsize=10, + ha="right", + va="top", + c=COLORS["orange"], + ) + + if "C01" in subsites: + component_oropharynx = mixture_df["C01"].argmax() + ax.text( + scaled_extremes_x[component_oropharynx], + scaled_extremes_y[component_oropharynx], + "Oropharynx like", + fontsize=10, + ha="left", + va="top", + c=COLORS["green"], + ) + + if "C03" in subsites: + component_oral_cavity = mixture_df["C03"].argmax() + ax.text( + scaled_extremes_x[component_oral_cavity], + scaled_extremes_y[component_oral_cavity], + "Oral cavity like", + fontsize=10, + ha="left", + va="top", + c=COLORS["blue"], + ) + + if "C13" in subsites: + component_hypopharynx = mixture_df["C13"].argmax() + ax.text( + scaled_extremes_x[component_hypopharynx], + scaled_extremes_y[component_hypopharynx], + "Hypopharynx like", + fontsize=10, + ha="left", + va="top", + c=COLORS["red"], + ) save_figure(output, fig, formats=["png", "svg"]) - logger.info(f"Simplex plot saved") + logger.info("Simplex plot saved") + def main(args: argparse.Namespace): mixture_df = pd.read_csv(args.input) nr_components = len(mixture_df) params = load_yaml_params(args.params) - data = params['general']['data'] + data = params["general"]["data"] if nr_components == 2: - plot_2d_simplex(mixture_df, data) + plot_2d_simplex(mixture_df, data, output=args.output) elif nr_components == 3: - plot_3d_simplex(mixture_df, data, output = args.output) + plot_3d_simplex(mixture_df, data, output=args.output) else: logger.info(f"Simplex not supported for {nr_components} components") @@ -224,4 +321,3 @@ def main(args: argparse.Namespace): args = parser.parse_args() args.run_main(args) - diff --git a/lyscripts/plot/thermo_int.py b/lyscripts/plot/thermo_int.py index e20f507..208840d 100644 --- a/lyscripts/plot/thermo_int.py +++ b/lyscripts/plot/thermo_int.py @@ -1,9 +1,9 @@ -""" -Plot how the accuracy develops over the course of a thermodynamic integration run. +"""Plot how the accuracy develops over the course of a thermodynamic integration run. This can also be used to compare how the accuracy of different models develops during thermdynamic integration. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -43,35 +43,39 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): """Add arguments to the parser.""" parser.add_argument( - "inputs", type=Path, nargs="+", - help="Paths to the CSV files containing the stored TI runs" + "inputs", + type=Path, + nargs="+", + help="Paths to the CSV files containing the stored TI runs", ) group = parser.add_mutually_exclusive_group() group.add_argument( - "-o", "--output", type=Path, - help="Path to where the plot should be stored (PNG and SVG)" + "-o", + "--output", + type=Path, + help="Path to where the plot should be stored (PNG and SVG)", ) group.add_argument( - "--show", action="store_true", - help="Show the plot instead of saving it" + "--show", action="store_true", help="Show the plot instead of saving it" ) + parser.add_argument("--title", default=None, help="Title of the plot") parser.add_argument( - "--title", default=None, - help="Title of the plot" - ) - parser.add_argument( - "--labels", type=str, nargs="+", default=[], - help="Labels for the individual data series" + "--labels", + type=str, + nargs="+", + default=[], + help="Labels for the individual data series", ) parser.add_argument( - "--power", default=5., type=float, - help="Scale the x-axis with this power" + "--power", default=5.0, type=float, help="Scale the x-axis with this power" ) parser.add_argument( - "--mplstyle", default="./.mplstyle", type=Path, - help="Path to the MPL stylesheet" + "--mplstyle", + default="./.mplstyle", + type=Path, + help="Path to the MPL stylesheet", ) parser.set_defaults(run_main=main) @@ -92,27 +96,25 @@ def main(args: argparse.Namespace): logger.info(f"+ read in {input}") logger.info("Loaded CSV file(s)") - fig, ax = plt.subplots(figsize=get_size()) if args.title is not None: fig.suptitle(args.title) ax.set_xlabel("inverse temperature $\\beta$") - xticks = np.linspace(0., 1., 7) + xticks = np.linspace(0.0, 1.0, 7) xticklabels = [f"{x**args.power:.2g}" for x in xticks] ax.set_xticks(ticks=xticks, labels=xticklabels) - ax.set_xlim(left=0., right=1.) + ax.set_xlim(left=0.0, right=1.0) ax.set_ylabel("accuracy $\\mathcal{A}(\\beta)$") ax.set_yscale("symlog") ax.get_yaxis().set_major_locator(matplotlib.ticker.MultipleLocator(800)) ax.get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter()) - ax.ticklabel_format(axis="y", style="sci", scilimits=(2,2)) + ax.ticklabel_format(axis="y", style="sci", scilimits=(2, 2)) logger.info("Prepared figure") - for i, series in enumerate(accuracy_series): - last_acc = series['accuracy'].values[-1] + last_acc = series["accuracy"].values[-1] try: label = args.labels[i] + " $\\mathcal{A}(1)$ = " + f"{last_acc:g}" except IndexError: @@ -120,14 +122,14 @@ def main(args: argparse.Namespace): if "stddev" in series: ax.errorbar( - series["β"]**(1./args.power), + series["β"] ** (1.0 / args.power), series["accuracy"], yerr=series["stddev"], label=label, ) else: ax.plot( - series["β"]**(1./args.power), + series["β"] ** (1.0 / args.power), series["accuracy"], label=label, ) @@ -136,7 +138,6 @@ def main(args: argparse.Namespace): ax.legend() logger.info("Plotted series") - if args.show: plt.show() logger.info("Showed the plot") diff --git a/lyscripts/plot/utils.py b/lyscripts/plot/utils.py index 29e72a2..f0d1da9 100644 --- a/lyscripts/plot/utils.py +++ b/lyscripts/plot/utils.py @@ -2,6 +2,7 @@ from __future__ import annotations +import math from abc import abstractmethod from collections.abc import Mapping from dataclasses import dataclass, field @@ -11,10 +12,9 @@ import h5py import matplotlib.pyplot as plt -from matplotlib.colors import LinearSegmentedColormap import numpy as np import scipy as sp -import math +from matplotlib.colors import LinearSegmentedColormap from lyscripts.decorators import ( check_input_file_exists, @@ -38,31 +38,33 @@ } COLOR_CYCLE = cycle(COLORS.values()) CM_PER_INCH = 2.54 -blue_to_white = LinearSegmentedColormap.from_list("blue to white", - [COLORS['blue'], "#ffffff"], - N=256) -green_to_white = LinearSegmentedColormap.from_list("green_to_white", - [COLORS['green'], "#ffffff"], - N=256) -red_to_white = LinearSegmentedColormap.from_list("red_to_white", - [COLORS['red'], "#ffffff"], - N=256) -orange_to_white = LinearSegmentedColormap.from_list("orange_to_white", - [COLORS['orange'], "#ffffff"], - N=256) -SUBSITE_COLORS = {'C03': blue_to_white(0), - 'C04':blue_to_white(0.15), - 'C06':blue_to_white(0.3), - 'C02':blue_to_white(0.45), - 'C05':blue_to_white(0.6), - 'C10':green_to_white(0), - 'C09':green_to_white(0.3), - 'C01':green_to_white(0.6), - 'C12':red_to_white(0), - 'C13':red_to_white(0.5), - 'C32.1':orange_to_white(0), - 'C32.2':orange_to_white(0.3), - 'C32.0':orange_to_white(0.6)} +blue_to_white = LinearSegmentedColormap.from_list( + "blue to white", [COLORS["blue"], "#ffffff"], N=256 +) +green_to_white = LinearSegmentedColormap.from_list( + "green_to_white", [COLORS["green"], "#ffffff"], N=256 +) +red_to_white = LinearSegmentedColormap.from_list( + "red_to_white", [COLORS["red"], "#ffffff"], N=256 +) +orange_to_white = LinearSegmentedColormap.from_list( + "orange_to_white", [COLORS["orange"], "#ffffff"], N=256 +) +SUBSITE_COLORS = { + "C03": blue_to_white(0), + "C04": blue_to_white(0.15), + "C06": blue_to_white(0.3), + "C02": blue_to_white(0.45), + "C05": blue_to_white(0.6), + "C10": green_to_white(0), + "C09": green_to_white(0.3), + "C01": green_to_white(0.6), + "C12": red_to_white(0), + "C13": red_to_white(0.5), + "C32.1": orange_to_white(0), + "C32.2": orange_to_white(0.3), + "C32.0": orange_to_white(0.6), +} def floor_at_decimal(value: float, decimal: int) -> float: @@ -107,6 +109,7 @@ def clean_and_check(filename: str | Path) -> Path: AbstractDistributionT = TypeVar("AbstractDistributionT", bound="AbstractDistribution") + @dataclass(kw_only=True) class AbstractDistribution: """Abstract class for distributions that should be plotted.""" @@ -437,75 +440,108 @@ def save_figure( for frmt in formats: figure.savefig(output_path.with_suffix(f".{frmt}")) + def p_to_xyz(p): - #Project 4D representation down to 3D representation - s3 = 1/math.sqrt(3.0) - s6 = 1/math.sqrt(6.0) - x = -1*p[0] + 1*p[1] + 0*p[2] + 0*p[3] - y = -s3*p[0] - s3*p[1] + 2*s3*p[2] + 0*p[3] - z = -s3*p[0] - s3*p[1] - s3*p[2] + 3*s6*p[3] + # Project 4D representation down to 3D representation + s3 = 1 / math.sqrt(3.0) + s6 = 1 / math.sqrt(6.0) + x = -1 * p[0] + 1 * p[1] + 0 * p[2] + 0 * p[3] + y = -s3 * p[0] - s3 * p[1] + 2 * s3 * p[2] + 0 * p[3] + z = -s3 * p[0] - s3 * p[1] - s3 * p[2] + 3 * s6 * p[3] return x, y, z + def add_perpendicular_crosses_3d(ax, x1, y1, z1, x2, y2, z2, tick_length=0.03): num_ticks = 6 # Number of ticks for i in range(num_ticks): t = i / (num_ticks - 1) - + # Interpolation point (x, y, z) on the line x_tick = x1 + t * (x2 - x1) y_tick = y1 + t * (y2 - y1) z_tick = z1 + t * (z2 - z1) - + # Vector along the line (direction of the line) line_vec = np.array([x2 - x1, y2 - y1, z2 - z1]) - + # First perpendicular vector (cross product with z-axis) perp_vec1 = np.cross(line_vec, [0, 0, 1]) length1 = np.linalg.norm(perp_vec1) - if length1 == 0: # Prevent division by zero (in rare cases when parallel to z-axis) + if ( + length1 == 0 + ): # Prevent division by zero (in rare cases when parallel to z-axis) perp_vec1 = np.cross(line_vec, [1, 0, 0]) # Cross with x-axis instead length1 = np.linalg.norm(perp_vec1) perp_vec1 /= length1 # Normalize - + # Second perpendicular vector (cross product with line_vec and perp_vec1) perp_vec2 = np.cross(line_vec, perp_vec1) perp_vec2 /= np.linalg.norm(perp_vec2) # Normalize - + # Scale the perpendicular vectors by tick length perp_vec1 *= tick_length perp_vec2 *= tick_length - + # Draw the cross (two perpendicular lines) - ax.plot([x_tick - perp_vec1[0], x_tick + perp_vec1[0]], - [y_tick - perp_vec1[1], y_tick + perp_vec1[1]], - [z_tick - perp_vec1[2], z_tick + perp_vec1[2]], color='gray', linewidth=0.8) - - ax.plot([x_tick - perp_vec2[0], x_tick + perp_vec2[0]], - [y_tick - perp_vec2[1], y_tick + perp_vec2[1]], - [z_tick - perp_vec2[2], z_tick + perp_vec2[2]], color='gray', linewidth=0.8) - - ax.text(x_tick, y_tick, z_tick, f'{int(100 - t * 100)}%', fontsize=6, ha='right', va='bottom') - + ax.plot( + [x_tick - perp_vec1[0], x_tick + perp_vec1[0]], + [y_tick - perp_vec1[1], y_tick + perp_vec1[1]], + [z_tick - perp_vec1[2], z_tick + perp_vec1[2]], + color="gray", + linewidth=0.8, + ) + + ax.plot( + [x_tick - perp_vec2[0], x_tick + perp_vec2[0]], + [y_tick - perp_vec2[1], y_tick + perp_vec2[1]], + [z_tick - perp_vec2[2], z_tick + perp_vec2[2]], + color="gray", + linewidth=0.8, + ) + + ax.text( + x_tick, + y_tick, + z_tick, + f"{int(100 - t * 100)}%", + fontsize=6, + ha="right", + va="bottom", + ) + + def add_perpendicular_ticks(x1, y1, x2, y2, tick_length=0.01): num_ticks = 6 # Number of ticks including 0% and 100% for i in range(num_ticks): t = i / (num_ticks - 1) x_tick = x1 + t * (x2 - x1) y_tick = y1 + t * (y2 - y1) - + # Vector along the line dx = x2 - x1 dy = y2 - y1 - + # Perpendicular vector perp_dx = -dy perp_dy = dx - + # Normalize the perpendicular vector length = np.sqrt(perp_dx**2 + perp_dy**2) perp_dx /= length perp_dy /= length - + # Draw tick as a short perpendicular line - plt.plot([x_tick - tick_length * perp_dx, x_tick + tick_length * perp_dx], [y_tick - tick_length * perp_dy, y_tick + tick_length * perp_dy], color='gray', linewidth=0.8) - plt.text(x_tick, y_tick, f'{int(100 - t * 100)}%', fontsize=8, ha='right', va='bottom') \ No newline at end of file + plt.plot( + [x_tick - tick_length * perp_dx, x_tick + tick_length * perp_dx], + [y_tick - tick_length * perp_dy, y_tick + tick_length * perp_dy], + color="gray", + linewidth=0.8, + ) + plt.text( + x_tick, + y_tick, + f"{int(100 - t * 100)}%", + fontsize=8, + ha="right", + va="bottom", + ) diff --git a/lyscripts/sample.py b/lyscripts/sample.py index 939ce13..ce2a77e 100644 --- a/lyscripts/sample.py +++ b/lyscripts/sample.py @@ -1,5 +1,4 @@ -""" -Learn the spread probabilities of the HMM for lymphatic tumor progression using +"""Learn the spread probabilities of the HMM for lymphatic tumor progression using the preprocessed data as input and MCMC as sampling method. This is the central script performing for our project on modelling lymphatic spread @@ -9,6 +8,7 @@ objective decisions with respect to defining the *elective clinical target volume* (CTV-N) in radiotherapy. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -58,60 +58,87 @@ def _add_arguments(parser: argparse.ArgumentParser): This is called by the parent module that is called via the command line. """ parser.add_argument( - "-i", "--input", type=Path, required=True, - help="Path to training data files" + "-i", "--input", type=Path, required=True, help="Path to training data files" ) parser.add_argument( - "-o", "--output", type=Path, required=True, - help="Path to the HDF5 file to store the results in" + "-o", + "--output", + type=Path, + required=True, + help="Path to the HDF5 file to store the results in", ) parser.add_argument( - "--history", type=Path, nargs="?", - help="Path to store the burnin history in (as CSV file)." + "--history", + type=Path, + nargs="?", + help="Path to store the burnin history in (as CSV file).", ) parser.add_argument( - "-w", "--walkers-per-dim", type=int, default=10, + "-w", + "--walkers-per-dim", + type=int, + default=10, help="Number of walkers per dimension", ) parser.add_argument( - "-b", "--burnin", type=int, nargs="?", - help="Number of burnin steps. If not provided, sampler runs until convergence." + "-b", + "--burnin", + type=int, + nargs="?", + help="Number of burnin steps. If not provided, sampler runs until convergence.", ) parser.add_argument( - "--check-interval", type=int, default=100, - help="Check convergence every `check_interval` steps." + "--check-interval", + type=int, + default=100, + help="Check convergence every `check_interval` steps.", ) parser.add_argument( - "--trust-fac", type=float, default=50., - help="Factor to trust the autocorrelation time for convergence." + "--trust-fac", + type=float, + default=50.0, + help="Factor to trust the autocorrelation time for convergence.", ) parser.add_argument( - "--rel-thresh", type=float, default=0.05, - help="Relative threshold for convergence." + "--rel-thresh", + type=float, + default=0.05, + help="Relative threshold for convergence.", ) parser.add_argument( - "-n", "--nsteps", type=int, default=100, - help="Number of MCMC samples to draw, irrespective of thinning." + "-n", + "--nsteps", + type=int, + default=100, + help="Number of MCMC samples to draw, irrespective of thinning.", ) parser.add_argument( - "-t", "--thin", type=int, default=10, - help="Thinning factor for the MCMC chain." + "-t", "--thin", type=int, default=10, help="Thinning factor for the MCMC chain." ) parser.add_argument( - "-p", "--params", default="./params.yaml", type=Path, - help="Path to parameter file." + "-p", + "--params", + default="./params.yaml", + type=Path, + help="Path to parameter file.", ) parser.add_argument( - "-c", "--cores", type=int, nargs="?", + "-c", + "--cores", + type=int, + nargs="?", help=( "Number of parallel workers (CPU cores/threads) to use. If not provided, " "it will use all cores. If set to zero, multiprocessing will not be used." - ) + ), ) parser.add_argument( - "-s", "--seed", type=int, default=42, - help="Seed value to reproduce the same sampling round." + "-s", + "--seed", + type=int, + default=42, + help="Seed value to reproduce the same sampling round.", ) parser.set_defaults(run_main=main) @@ -119,8 +146,9 @@ def _add_arguments(parser: argparse.ArgumentParser): MODEL = None + def log_prob_fn(theta: np.array) -> float: - """log probability function using global variables because of pickling.""" + """Log probability function using global variables because of pickling.""" return MODEL.likelihood(given_params=theta) @@ -183,11 +211,12 @@ def run_burnin( progress.update(task, advance=1) new_acor_time = sampler.get_autocorr_time(tol=0).mean() - old_acor_time = history.acor_times[-1] if len(history.acor_times) > 0 else np.inf + old_acor_time = ( + history.acor_times[-1] if len(history.acor_times) > 0 else np.inf + ) - new_accept_frac = ( - (np.sum(sampler.backend.accepted) - num_accepted) - / (sampler.nwalkers * check_interval) + new_accept_frac = (np.sum(sampler.backend.accepted) - num_accepted) / ( + sampler.nwalkers * check_interval ) num_accepted = np.sum(sampler.backend.accepted) @@ -198,7 +227,9 @@ def run_burnin( is_converged = burnin is None is_converged &= new_acor_time * trust_fac < sampler.iteration - is_converged &= np.abs(new_acor_time - old_acor_time) / new_acor_time < rel_thresh + is_converged &= ( + np.abs(new_acor_time - old_acor_time) / new_acor_time < rel_thresh + ) if is_converged: break @@ -236,6 +267,7 @@ def run_sampling( class DummyPool: """Dummy class to allow for no multiprocessing.""" + def __enter__(self): return None @@ -277,7 +309,9 @@ def main(args: argparse.Namespace) -> None: with real_or_dummy_pool as pool: sampler = emcee.EnsembleSampler( - nwalkers, ndim, log_prob_fn, + nwalkers, + ndim, + log_prob_fn, moves=moves_mix, backend=hdf5_backend, pool=pool, diff --git a/lyscripts/scenario.py b/lyscripts/scenario.py index f16a370..9955484 100644 --- a/lyscripts/scenario.py +++ b/lyscripts/scenario.py @@ -1,5 +1,4 @@ -""" -This module implements helpers and classes that help us deal with what we call a +"""This module implements helpers and classes that help us deal with what we call a *scenario*. A scenario is a set of parameters that determine how we compute priors, posteriors, prevalences, and risks. @@ -9,6 +8,7 @@ model) are relevant. But e.g. posteriors and risks also require us to provide a diagnosis, given which to compute the quantities of interest. """ + import argparse import hashlib import inspect @@ -33,8 +33,10 @@ class UninitializedProperty(Exception): in the getter when no private attribute is found. """ + ScenarioT = TypeVar("ScenarioT", bound="Scenario") + @dataclass class Scenario: """Dataclass for storing configuration of a scenario. @@ -52,7 +54,6 @@ class Scenario: is_uni: bool = False side: str = "ipsi" - @staticmethod def _defaults(property_name: str) -> Any: """Return the default value for a property. @@ -67,12 +68,11 @@ def _defaults(property_name: str) -> Any: {} """ return { - "t_stages_dist": np.array([1.]), + "t_stages_dist": np.array([1.0]), "involvement": {"ipsi": {}, "contra": {}}, "diagnosis": {"ipsi": {}, "contra": {}}, }[property_name] - def __post_init__(self) -> None: """Declate default value of properties. @@ -89,12 +89,11 @@ def __post_init__(self) -> None: if not self.is_uni: for side in ["ipsi", "contra"]: - if not side in self.diagnosis: + if side not in self.diagnosis: self.diagnosis[side] = {} - if not side in self.involvement: + if side not in self.involvement: self.involvement[side] = {} - @classmethod def fields(cls) -> dict[str, Any]: """Return a list of fields that may make up a scenario.""" @@ -132,11 +131,11 @@ def t_stages_dist(self) -> np.ndarray: self._t_stages_dist = self._defaults("t_stages_dist") if len(self._t_stages_dist) != len(self.t_stages): - new_x = np.linspace(0., 1., len(self.t_stages)) - old_x = np.linspace(0., 1., len(self._t_stages_dist)) + new_x = np.linspace(0.0, 1.0, len(self.t_stages)) + old_x = np.linspace(0.0, 1.0, len(self._t_stages_dist)) self._t_stages_dist = np.interp(new_x, old_x, self._t_stages_dist) - if not np.isclose(np.sum(self._t_stages_dist), 1.): + if not np.isclose(np.sum(self._t_stages_dist), 1.0): self._t_stages_dist /= np.sum(self._t_stages_dist) return np.array(self._t_stages_dist) @@ -146,7 +145,6 @@ def t_stages_dist(self, value: Iterable[float]) -> None: if not isinstance(value, property): self._t_stages_dist = value - @classmethod def from_namespace( cls, @@ -182,12 +180,16 @@ def from_namespace( scenario = cls(**kwargs) for side in ["ipsi", "contra"]: - pattern = getattr(namespace, f"{side}_involvement", None) or [None] * len(lnls) - tmp = {lnl: val for lnl, val in zip(lnls, pattern)} + pattern = getattr(namespace, f"{side}_involvement", None) or [None] * len( + lnls + ) + tmp = {lnl: val for lnl, val in zip(lnls, pattern, strict=False)} scenario._involvement[side] = tmp - pattern = getattr(namespace, f"{side}_diagnosis", None) or [None] * len(lnls) - tmp = {lnl: val for lnl, val in zip(lnls, pattern)} + pattern = getattr(namespace, f"{side}_diagnosis", None) or [None] * len( + lnls + ) + tmp = {lnl: val for lnl, val in zip(lnls, pattern, strict=False)} mod_name = getattr(namespace, "modality", "max_llh") scenario._diagnosis[side] = {mod_name: tmp} @@ -246,7 +248,6 @@ def list_from_params( return res - def as_dict( self, for_comp: Literal["priors", "posteriors", "prevalences", "risks"], @@ -260,21 +261,24 @@ def as_dict( if for_comp == "priors": return res - res.update({ - "midext": self.midext, - "diagnosis": self.diagnosis, - "side": self.side, - "is_uni": self.is_uni, - }) + res.update( + { + "midext": self.midext, + "diagnosis": self.diagnosis, + "side": self.side, + "is_uni": self.is_uni, + } + ) if for_comp == "risks": res["involvement"] = self.involvement return res - @property - def diagnosis(self) -> dict[str, dict[str, types.PatternType]] | dict[str, types.PatternType]: + def diagnosis( + self, + ) -> dict[str, dict[str, types.PatternType]] | dict[str, types.PatternType]: """Get bi- or unilateral diagosis, depending on attrs ``side`` and ``is_uni``.""" if not hasattr(self, "_diagnosis"): raise UninitializedProperty("diagnosis") @@ -289,7 +293,6 @@ def diagnosis(self, value: dict[str, dict[str, types.PatternType]]) -> None: if not isinstance(value, property): self._diagnosis = value - @property def involvement(self) -> dict[str, types.PatternType] | types.PatternType: """Get bi- or unilateral involvement, depending on attrs ``side`` and ``is_uni``.""" @@ -306,7 +309,6 @@ def involvement(self, value: dict[str, types.PatternType]) -> None: if not isinstance(value, property): self._involvement = value - def get_pattern( self, get_from: Literal["involvement", "diagnosis"], @@ -326,7 +328,6 @@ def get_pattern( return pattern - def md5_hash( self, for_comp: Literal["priors", "posteriors", "prevalences", "risks"], @@ -360,18 +361,24 @@ def add_scenario_arguments( {'mode': 'BN', 't_stages': ['early'], 't_stages_dist': array([1.])} """ parser.add_argument( - "--t-stages", nargs="+", default=["early"], + "--t-stages", + nargs="+", + default=["early"], help="T-stages to consider.", ) parser.add_argument( - "--t-stages-dist", nargs="+", type=float, + "--t-stages-dist", + nargs="+", + type=float, help=( "Distribution over T-stages. Prior distribution over hidden states will " "be marginalized over T-stages using this distribution." - ) + ), ) parser.add_argument( - "--mode", choices=["BN", "HMM"], default="HMM", + "--mode", + choices=["BN", "HMM"], + default="HMM", help="Mode to use for computing the scenario.", ) @@ -379,7 +386,9 @@ def add_scenario_arguments( return parser.add_argument( - "--midext", type=optional_bool, required=False, + "--midext", + type=optional_bool, + required=False, help=( "Use midline extention for computing the scenario. Only used with " "midline model." @@ -398,31 +407,41 @@ def add_scenario_arguments( if for_comp == "risks": parser.add_argument( - "--ipsi-involvement", nargs="+", type=optional_bool, + "--ipsi-involvement", + nargs="+", + type=optional_bool, help="Involvement to compute quantitty for (ipsilateral side).", ) parser.add_argument( - "--contra-involvement", nargs="+", type=optional_bool, + "--contra-involvement", + nargs="+", + type=optional_bool, help="Involvement to compute quantitty for (contralateral side).", ) if for_comp == "prevalences": parser.add_argument( - "--modality", default="max_llh", + "--modality", + default="max_llh", help="Modality name to compute predicted and observed prevalence for.", ) parser.add_argument( - "--ipsi-diagnosis", nargs="+", type=optional_bool, + "--ipsi-diagnosis", + nargs="+", + type=optional_bool, help="Diagnosis of ipsilateral side.", ) parser.add_argument( - "--contra-diagnosis", nargs="+", type=optional_bool, + "--contra-diagnosis", + nargs="+", + type=optional_bool, help="Diagnosis of contralateral side.", ) if __name__ == "__main__": - scenario = Scenario(t_stages=['a', 'b'], t_stages_dist=[0.2, 0.8]) + scenario = Scenario(t_stages=["a", "b"], t_stages_dist=[0.2, 0.8]) import doctest + doctest.testmod() diff --git a/lyscripts/temp_schedule.py b/lyscripts/temp_schedule.py index 09483e2..553c5b6 100644 --- a/lyscripts/temp_schedule.py +++ b/lyscripts/temp_schedule.py @@ -1,5 +1,4 @@ -""" -Generate inverse temperature schedules for thermodynamic integration using various +"""Generate inverse temperature schedules for thermodynamic integration using various different methods. Thermodynamic integration is quite sensitive to the specific schedule which is used. @@ -11,6 +10,7 @@ the interval $[0, 1]$ and then transform each point by computing $\\beta_i^k$ where $k$ could e.g. be 5. """ + # pylint: disable=logging-fstring-interpolation import argparse import logging @@ -27,9 +27,7 @@ def _add_parser( subparsers: argparse._SubParsersAction, help_formatter, ): - """ - Add an `ArgumentParser` to the subparsers action. - """ + """Add an `ArgumentParser` to the subparsers action.""" parser = subparsers.add_parser( Path(__file__).name.replace(".py", ""), description=__doc__, @@ -40,21 +38,26 @@ def _add_parser( def _add_arguments(parser: argparse.ArgumentParser): - """ - Add arguments needed to run this script to a `subparsers` instance + """Add arguments needed to run this script to a `subparsers` instance and run the respective main function when chosen. """ parser.add_argument( - "--method", choices=SCHEDULES.keys(), default=list(SCHEDULES.keys())[0], - help="Choose the method to distribute the inverse temperature." + "--method", + choices=SCHEDULES.keys(), + default=list(SCHEDULES.keys())[0], + help="Choose the method to distribute the inverse temperature.", ) parser.add_argument( - "--num", default=32, type=int, - help="Number of inverse temperatures in the schedule" + "--num", + default=32, + type=int, + help="Number of inverse temperatures in the schedule", ) parser.add_argument( - "--pow", default=4, type=float, - help="If a power schedule is chosen, use this as power" + "--pow", + default=4, + type=float, + help="If a power schedule is chosen, use this as power", ) parser.set_defaults(run_main=main) @@ -62,36 +65,39 @@ def _add_arguments(parser: argparse.ArgumentParser): def tolist(func: Callable) -> Callable: """Decorator to make sure the returned value is a list of floats.""" + def inner(*args) -> np.ndarray | list[float]: res = func(*args) if isinstance(res, np.ndarray): return res.tolist() return res + return inner @tolist def geometric_schedule(n: int, *_a) -> np.ndarray: """Create a geometric sequence of `n` numbers from 0. to 1.""" - log_seq = np.logspace(0., 1., n) - shifted_seq = log_seq - 1. - geom_seq = shifted_seq / 9. + log_seq = np.logspace(0.0, 1.0, n) + shifted_seq = log_seq - 1.0 + geom_seq = shifted_seq / 9.0 return geom_seq @tolist def linear_schedule(n: int, *_a) -> np.ndarray: """Create a linear sequence of `n` numbers from 0. to 1.""" - return np.linspace(0., 1., n) + return np.linspace(0.0, 1.0, n) @tolist def power_schedule(n: int, power: float, *_a) -> np.ndarray: """Create a power sequence of `n` numbers from 0. to 1.""" - lin_seq = np.linspace(0., 1., n) + lin_seq = np.linspace(0.0, 1.0, n) power_seq = lin_seq**power return power_seq + SCHEDULES = { "geometric": geometric_schedule, "linear": linear_schedule, diff --git a/lyscripts/utils.py b/lyscripts/utils.py index eaeee49..cd43f2e 100644 --- a/lyscripts/utils.py +++ b/lyscripts/utils.py @@ -1,11 +1,11 @@ -""" -This module contains frequently used functions and decorators that are used throughout +"""This module contains frequently used functions and decorators that are used throughout the subcommands to load e.g. YAML specifications or model definitions. It also contains helpers for reporting the script's progress via a slightly customized `rich` console and a custom `Exception` called `LyScriptsWarning` that can propagate occuring issues to the right place. """ + import warnings from logging import LogRecord from pathlib import Path @@ -16,8 +16,8 @@ import yaml from deprecated import deprecated from emcee.backends import HDFBackend -from lymph import diagnosis_times, models, types from lymixture import LymphMixture +from lymph import diagnosis_times, models, types from rich.console import Console from rich.logging import RichHandler from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn @@ -32,8 +32,10 @@ try: import streamlit from streamlit.runtime.scriptrunner import get_script_run_ctx + streamlit.status = streamlit.spinner except ImportError: + def get_script_run_ctx() -> bool: """A mock for the `get_script_run_ctx` function of `streamlit`.""" return None @@ -49,13 +51,13 @@ def get_script_run_ctx() -> bool: class LyScriptsWarning(Warning): - """ - Exception that can be raised by methods if they want the `LyScriptsReport` instance + """Exception that can be raised by methods if they want the `LyScriptsReport` instance to not stop and print a traceback, but display some message appropriately. Essentially, this is a way for decorated functions to propagate messages through the `report_state` decorator. """ + def __init__(self, *args: object, level: str = "info") -> None: """Extract the `level` of the message (can be "info", "warning" or "error").""" self.level = level @@ -70,7 +72,8 @@ def is_streamlit_running() -> bool: class CustomProgress(Progress): """Small wrapper around rich's `Progress` initializing my custom columns.""" - def __init__( self, **kwargs: dict): + + def __init__(self, **kwargs: dict): columns = [ SpinnerColumn(finished_text=CHECK), *Progress.get_default_columns(), @@ -81,6 +84,7 @@ def __init__( self, **kwargs: dict): class CustomRichHandler(RichHandler): """Uses `func_filepath` from the `extra` dict to modify `pathname`.""" + def emit(self, record: LogRecord) -> None: """Emit a log record.""" if ( @@ -100,11 +104,13 @@ def emit(self, record: LogRecord) -> None: def binom_pmf(support: list[int] | np.ndarray, p: float = 0.5): """Binomial PMF""" max_time = len(support) - 1 - if p > 1. or p < 0.: + if p > 1.0 or p < 0.0: raise ValueError("Binomial prob must be btw. 0 and 1") - q = 1. - p - binom_coeff = factorial(max_time) / (factorial(support) * factorial(max_time - support)) - return binom_coeff * p**support * q**(max_time - support) + q = 1.0 - p + binom_coeff = factorial(max_time) / ( + factorial(support) * factorial(max_time - support) + ) + return binom_coeff * p**support * q ** (max_time - support) FUNCS = { @@ -116,7 +122,7 @@ def graph_from_config(graph_params: dict) -> dict[tuple[str, str], list[str]]: """Build graph dictionary for the `lymph` models from the YAML params.""" lymph_graph = {} - if not "tumor" in graph_params and "lnl" in graph_params: + if "tumor" not in graph_params and "lnl" in graph_params: raise KeyError("Parameters must define tumors and LNLs") for node_type, node_dict in graph_params.items(): @@ -151,7 +157,7 @@ def _create_model_from_v0(params: dict[str, Any]) -> types.Model: if "model" in params: model_cls = getattr(models, params["model"]["class"]) - if not "is_symmetric" in params["model"]["kwargs"]: + if "is_symmetric" not in params["model"]["kwargs"]: warnings.warn( "The keywords `base_symmetric`, `trans_symmetric`, and `use_mixing` " "have been deprecated. Please use `is_symmetric` instead.", @@ -201,7 +207,7 @@ def assign_modalities( to the ``model``. Example: - + ------- >>> from_config = { ... "CT": {"spec": 0.76, "sens": 0.81}, ... "MRI": [0.63, 0.86, "pathological"], @@ -218,6 +224,7 @@ def assign_modalities( >>> assign_modalities(model, from_config, subset=["CT"]) >>> model.get_all_modalities() # doctest: +NORMALIZE_WHITESPACE {'CT': Clinical(spec=0.76, sens=0.81, is_trinary=False)} + """ if clear: model.clear_modalities() @@ -242,7 +249,7 @@ def create_distribution(config: dict[str, Any]) -> diagnosis_times.Distribution: kwargs = config.get("kwargs", {}) if (type_ := config.get("frozen")) is not None: - kwargs.update({"support": np.arange(max_time+1)}) + kwargs.update({"support": np.arange(max_time + 1)}) distribution = diagnosis_times.Distribution(FUNCS[type_](**kwargs)) elif (type_ := config.get("parametric")) is not None: distribution = diagnosis_times.Distribution(FUNCS[type_], max_time, **kwargs) @@ -294,28 +301,38 @@ def create_mixture(config: dict[str, Any], config_version: int = 0) -> types.Mod raise LyScriptsWarning("No graph definition found in YAML file", level="error") if (model_config := config.get("model")) is None: - raise LyScriptsWarning("No mixture definition found in YAML file", level="error") + raise LyScriptsWarning( + "No mixture definition found in YAML file", level="error" + ) graph_dict = graph_from_config(graph_config) model_cls_name, _, cls_meth_name = model_config["class"].partition(".") - if model_cls_name != 'Unilateral': - raise LyScriptsWarning("The mixture model has only been implemented for Unilateral so far", level = "error") + if model_cls_name != "Unilateral": + raise LyScriptsWarning( + "The mixture model has only been implemented for Unilateral so far", + level="error", + ) model_cls = getattr(models, model_cls_name) model_kwargs = model_config.get("kwargs", {}) - if not isinstance(model_kwargs,dict): + if not isinstance(model_kwargs, dict): model_kwargs = {} - model_num_components = model_config.get('num_components') - model_kwargs['graph_dict'] = graph_dict - mixture = LymphMixture(model_cls = model_cls, model_kwargs = model_kwargs, num_components = model_num_components) + model_num_components = model_config.get("num_components") + model_kwargs["graph_dict"] = graph_dict + mixture = LymphMixture( + model_cls=model_cls, + model_kwargs=model_kwargs, + num_components=model_num_components, + ) + + # note: modalities can't be set here, as we need to add the data first to define the number of subgroups - #note: modalities can't be set here, as we need to add the data first to define the number of subgroups - for t_stage, dist_config in model_config.get("distributions", {}).items(): distribution = create_distribution(dist_config) mixture.set_distribution(t_stage, distribution) return mixture + def get_dict_depth(nested: dict) -> int: """Get the depth of a nested dictionary. @@ -437,7 +454,7 @@ def load_patient_data( ) -> pd.DataFrame: """Load patient data from a CSV file stored at ``file``.""" if header is None: - header = [0,1,2] + header = [0, 1, 2] return pd.read_csv(file_path, header=header) @@ -491,6 +508,7 @@ def initialize_backend( FalseChoices = Literal["false", "f", "no", "n", "healthy", "benign"] """Type alias for what is interpreted as healthy/benign involvement of an LNL.""" + def optional_bool(value: NoneChoices | TrueChoices | FalseChoices) -> bool | None: """Convert a string to a boolean or ``None``. @@ -514,8 +532,8 @@ def make_pattern( lnls: list[str], ) -> dict[str, bool | None]: """Create a dictionary from a list of bools and Nones.""" - return dict(zip(lnls, from_list or [None] * len(lnls))) + return dict(zip(lnls, from_list or [None] * len(lnls), strict=False)) def to_numpy(params: dict[str, float]) -> np.ndarray: - return np.array([p for p in params.values()]) \ No newline at end of file + return np.array([p for p in params.values()]) diff --git a/tests/data/join_test.py b/tests/data/join_test.py index 879c96f..628cfa9 100644 --- a/tests/data/join_test.py +++ b/tests/data/join_test.py @@ -1,6 +1,5 @@ -""" -Test the correct joining of datasets. -""" +"""Test the correct joining of datasets.""" + from pathlib import Path from lyscripts.data.join import load_and_join_tables @@ -8,12 +7,11 @@ def test_load_and_join_tables(): """Test the correct joining of datasets.""" - input_paths = [Path("./tests/data/b.csv"), Path("./tests/data/a.csv")] joined = load_and_join_tables(input_paths) - assert joined.shape == (13,3), "Wrong concatenation shape." - assert joined['x','z','b'].isna().sum() == 6, "Wrong number of NaNs." - assert (joined['x','z','b'] == True).sum() == 4, "Wrong number of True values." - assert (joined['x','z','b'] == False).sum() == 3, "Wrong number of False values." - assert ('x', 'z', 'c') in joined.columns, "Column 'c' not found in joined table." + assert joined.shape == (13, 3), "Wrong concatenation shape." + assert joined["x", "z", "b"].isna().sum() == 6, "Wrong number of NaNs." + assert (joined["x", "z", "b"] == True).sum() == 4, "Wrong number of True values." + assert (joined["x", "z", "b"] == False).sum() == 3, "Wrong number of False values." + assert ("x", "z", "c") in joined.columns, "Column 'c' not found in joined table." diff --git a/tests/plot/plot_utils_test.py b/tests/plot/plot_utils_test.py index e1794f7..d5b350f 100644 --- a/tests/plot/plot_utils_test.py +++ b/tests/plot/plot_utils_test.py @@ -1,6 +1,5 @@ -""" -Testing of the utilities implemented for the plotting routines. -""" +"""Testing of the utilities implemented for the plotting routines.""" + from pathlib import Path import matplotlib.pyplot as plt @@ -21,20 +20,18 @@ @pytest.fixture def beta_samples(): - """ - Filename of an HDF5 file where some samples from a Beta distribution are stored - """ + """Filename of an HDF5 file where some samples from a Beta distribution are stored""" return "./tests/plot/data/beta_samples.hdf5" def test_floor_to_step(): """Check correct rounding down to a given step size.""" - numbers = np.array([0., 3., 7.4, 2.01, np.pi, 12.7, 12.7, 17.3 ]) - steps = np.array([2 , 2 , 5 , 2 , 3 , 3 , 5 , 0.17 ]) - exp_res = np.array([0., 2., 5. , 2. , 3. , 12. , 10. , 17.17]) + numbers = np.array([0.0, 3.0, 7.4, 2.01, np.pi, 12.7, 12.7, 17.3]) + steps = np.array([2, 2, 5, 2, 3, 3, 5, 0.17]) + exp_res = np.array([0.0, 2.0, 5.0, 2.0, 3.0, 12.0, 10.0, 17.17]) comp_res = np.zeros_like(exp_res) - for i, (num, step) in enumerate(zip(numbers, steps)): + for i, (num, step) in enumerate(zip(numbers, steps, strict=False)): comp_res[i] = floor_to_step(num, step) assert all(np.isclose(comp_res, exp_res)), "Floor to step did not work properly." @@ -42,12 +39,12 @@ def test_floor_to_step(): def test_ceil_to_step(): """Check correct rounding up to a given step size.""" - numbers = np.array([0., 3., 7.4, 2.01, np.pi, 12.7, 12.7, 17.3 ]) - steps = np.array([2 , 2 , 5 , 2 , 3 , 3 , 5 , 0.17 ]) - exp_res = np.array([2., 4., 10., 4. , 6. , 15. , 15. , 17.34]) + numbers = np.array([0.0, 3.0, 7.4, 2.01, np.pi, 12.7, 12.7, 17.3]) + steps = np.array([2, 2, 5, 2, 3, 3, 5, 0.17]) + exp_res = np.array([2.0, 4.0, 10.0, 4.0, 6.0, 15.0, 15.0, 17.34]) comp_res = np.zeros_like(exp_res) - for i, (num, step) in enumerate(zip(numbers, steps)): + for i, (num, step) in enumerate(zip(numbers, steps, strict=False)): comp_res[i] = ceil_to_step(num, step) assert all(np.isclose(comp_res, exp_res)), "Ceil to step did not work properly." @@ -64,30 +61,34 @@ def test_histogram_cls(beta_samples): hist_from_path = Histogram.from_hdf5( filename=path_filename, dataname="beta", - scale=10., + scale=10.0, label=custom_label, ) with pytest.raises(FileNotFoundError): Histogram.from_hdf5(filename=non_existent_filename, dataname="does_not_matter") - assert np.all(np.isclose(hist_from_str.values, 10. * hist_from_path.values)), ( - "Scaling of data does not work correclty" - ) - assert np.all(np.isclose( - hist_from_str.left_percentile(50.), - hist_from_str.right_percentile(50.), - )), "50% percentiles should be the same from the left and from the right." - assert np.all(np.isclose( - hist_from_path.left_percentile(10.), - hist_from_path.right_percentile(90.), - )), "10% from the left is not the same as 90% from the right" - assert hist_from_str.kwargs["label"] == "beta | mega scan | 100 | ext", ( - "Label extraction did not work" - ) - assert hist_from_path.kwargs["label"] == custom_label, ( - "Keyword override did not work" - ) + assert np.all( + np.isclose(hist_from_str.values, 10.0 * hist_from_path.values) + ), "Scaling of data does not work correclty" + assert np.all( + np.isclose( + hist_from_str.left_percentile(50.0), + hist_from_str.right_percentile(50.0), + ) + ), "50% percentiles should be the same from the left and from the right." + assert np.all( + np.isclose( + hist_from_path.left_percentile(10.0), + hist_from_path.right_percentile(90.0), + ) + ), "10% from the left is not the same as 90% from the right" + assert ( + hist_from_str.kwargs["label"] == "beta | mega scan | 100 | ext" + ), "Label extraction did not work" + assert ( + hist_from_path.kwargs["label"] == custom_label + ), "Keyword override did not work" def test_inverted_histogram_cls(beta_samples): @@ -100,28 +101,32 @@ def test_inverted_histogram_cls(beta_samples): hist_from_path = Histogram.from_hdf5( filename=path_filename, dataname="beta", - scale=-100., - offset=100., + scale=-100.0, + offset=100.0, label=custom_label, ) - assert np.all(np.isclose(100. - hist_from_str.values, hist_from_path.values)), ( - "Scaling and offsetting of data does not work correclty" - ) - assert np.all(np.isclose( - hist_from_str.left_percentile(50.), - hist_from_str.right_percentile(50.), - )), "50% percentiles should be the same from the left and from the right." - assert np.all(np.isclose( - hist_from_path.left_percentile(10.), - hist_from_path.right_percentile(90.), - )), "10% from the left is not the same as 90% from the right" - assert hist_from_str.kwargs["label"] == "beta | mega scan | 100 | ext", ( - "Label extraction did not work" - ) - assert hist_from_path.kwargs["label"] == custom_label, ( - "Keyword override did not work" - ) + assert np.all( + np.isclose(100.0 - hist_from_str.values, hist_from_path.values) + ), "Scaling and offsetting of data does not work correclty" + assert np.all( + np.isclose( + hist_from_str.left_percentile(50.0), + hist_from_str.right_percentile(50.0), + ) + ), "50% percentiles should be the same from the left and from the right." + assert np.all( + np.isclose( + hist_from_path.left_percentile(10.0), + hist_from_path.right_percentile(90.0), + ) + ), "10% from the left is not the same as 90% from the right" + assert ( + hist_from_str.kwargs["label"] == "beta | mega scan | 100 | ext" + ), "Label extraction did not work" + assert ( + hist_from_path.kwargs["label"] == custom_label + ), "Keyword override did not work" def test_posterior_cls(beta_samples): @@ -130,41 +135,49 @@ def test_posterior_cls(beta_samples): path_filename = Path(str_filename) non_existent_filename = "non_existent.hdf5" custom_label = "Lorem ipsum" - x_10 = np.linspace(0., 10., 100) - x_100 = np.linspace(0., 100., 100) + x_10 = np.linspace(0.0, 10.0, 100) + x_100 = np.linspace(0.0, 100.0, 100) post_from_str = BetaPosterior.from_hdf5(filename=str_filename, dataname="beta") post_from_path = BetaPosterior.from_hdf5( filename=path_filename, dataname="beta", - scale=10., + scale=10.0, label=custom_label, ) with pytest.raises(FileNotFoundError): - BetaPosterior.from_hdf5(filename=non_existent_filename, dataname="does_not_matter") - - assert post_from_str.num_success == post_from_path.num_success == 20, ( - "Number of successes not correctly extracted" - ) - assert post_from_str.num_total == post_from_path.num_total == 40, ( - "Total number of trials not correctly extracted" - ) - assert post_from_str.num_fail == post_from_path.num_fail == 20, ( - "Number of failures not correctly computed" - ) - assert np.all(np.isclose( - 10 * post_from_str.pdf(x_100), - post_from_path.pdf(x_10), - )), "PDFs with different scaling do not match" - assert np.all(np.isclose( - post_from_str.left_percentile(50.), - post_from_str.right_percentile(50.), - )), "50% percentiles should be the same from the left and from the right." - assert np.all(np.isclose( - post_from_path.left_percentile(10.), - post_from_path.right_percentile(90.), - )), "10% from the left is not the same as 90% from the right" + BetaPosterior.from_hdf5( + filename=non_existent_filename, dataname="does_not_matter" + ) + + assert ( + post_from_str.num_success == post_from_path.num_success == 20 + ), "Number of successes not correctly extracted" + assert ( + post_from_str.num_total == post_from_path.num_total == 40 + ), "Total number of trials not correctly extracted" + assert ( + post_from_str.num_fail == post_from_path.num_fail == 20 + ), "Number of failures not correctly computed" + assert np.all( + np.isclose( + 10 * post_from_str.pdf(x_100), + post_from_path.pdf(x_10), + ) + ), "PDFs with different scaling do not match" + assert np.all( + np.isclose( + post_from_str.left_percentile(50.0), + post_from_str.right_percentile(50.0), + ) + ), "50% percentiles should be the same from the left and from the right." + assert np.all( + np.isclose( + post_from_path.left_percentile(10.0), + post_from_path.right_percentile(90.0), + ) + ), "10% from the left is not the same as 90% from the right" @pytest.mark.mpl_image_compare @@ -175,7 +188,7 @@ def test_draw(beta_samples): hist = Histogram.from_hdf5(filename, dataname) post = BetaPosterior.from_hdf5(filename, dataname) fig, ax = plt.subplots() - ax = draw(axes=ax, contents=[hist, post], percent_lims=(2.,2.)) + ax = draw(axes=ax, contents=[hist, post], percent_lims=(2.0, 2.0)) return fig @@ -197,7 +210,9 @@ def test_draw_hist_kwargs(beta_samples): global_kwargs_path = "./tests/plot/results/global_kwargs" fig, global_kwargs_ax = plt.subplots() - global_kwargs_ax = draw(global_kwargs_ax, contents=[hist], hist_kwargs={"alpha": 0.3}) + global_kwargs_ax = draw( + global_kwargs_ax, contents=[hist], hist_kwargs={"alpha": 0.3} + ) save_figure(global_kwargs_path, fig, ["png"]) hist = Histogram.from_hdf5(filename, dataname, alpha=0.3) @@ -206,45 +221,55 @@ def test_draw_hist_kwargs(beta_samples): local_kwargs_ax = draw(local_kwargs_ax, contents=[hist], hist_kwargs={"alpha": 1.0}) save_figure(local_kwargs_path, fig, ["png"]) - assert mpl_comp.compare_images( - expected=default_kwargs_path + ".png", - actual=bins_kwargs_path + ".png", - tol=0.001, - ) is not None, "Changing bin number did not result in different plot" - - assert mpl_comp.compare_images( - expected=default_kwargs_path + ".png", - actual=global_kwargs_path + ".png", - tol=0.001, - ) is not None, "Changing global kwargs in `draw` did not result in different plot" - - assert mpl_comp.compare_images( - expected=local_kwargs_path + ".png", - actual=global_kwargs_path + ".png", - tol=0.001, - ) is None, "Overriding global with `Histogram` specific kwargs did not work" + assert ( + mpl_comp.compare_images( + expected=default_kwargs_path + ".png", + actual=bins_kwargs_path + ".png", + tol=0.001, + ) + is not None + ), "Changing bin number did not result in different plot" + + assert ( + mpl_comp.compare_images( + expected=default_kwargs_path + ".png", + actual=global_kwargs_path + ".png", + tol=0.001, + ) + is not None + ), "Changing global kwargs in `draw` did not result in different plot" + + assert ( + mpl_comp.compare_images( + expected=local_kwargs_path + ".png", + actual=global_kwargs_path + ".png", + tol=0.001, + ) + is None + ), "Overriding global with `Histogram` specific kwargs did not work" def test_save_figure(capsys): """Check that figures get stored correctly.""" - x = np.linspace(0., 2*np.pi, 200) + x = np.linspace(0.0, 2 * np.pi, 200) y = np.sin(x) fig, ax = plt.subplots(figsize=get_size()) - ax.plot(x,y) + ax.plot(x, y) output_path = "./tests/plot/results/sine" formats = ["png", "svg"] - expected_output = ( - "✓ Saved matplotlib figure.\n" - ) + expected_output = "✓ Saved matplotlib figure.\n" save_figure(output_path, fig, formats) save_figure_capture = capsys.readouterr() - assert mpl_comp.compare_images( - expected="./tests/plot/baseline/sine.png", - actual="./tests/plot/results/sine.png", - tol=0., - ) is None, "PNG of figure was not stored correctly." + assert ( + mpl_comp.compare_images( + expected="./tests/plot/baseline/sine.png", + actual="./tests/plot/results/sine.png", + tol=0.0, + ) + is None + ), "PNG of figure was not stored correctly." # Commented out, because I recently got the following message from matplotlib: # `SKIPPED (Don't know how to convert .svg files to png)` diff --git a/tests/predict/predict_utils_test.py b/tests/predict/predict_utils_test.py index a578f6f..06fb7c7 100644 --- a/tests/predict/predict_utils_test.py +++ b/tests/predict/predict_utils_test.py @@ -1,12 +1,10 @@ -""" -Test utilities of the predict submodule. -""" +"""Test utilities of the predict submodule.""" + from lyscripts.compute.utils import complete_pattern def test_clean_pattern(): - """ - Test the utility function that cleans the involvement patterns from the + """Test the utility function that cleans the involvement patterns from the `params.yaml` file """ empty_pattern = {} @@ -20,13 +18,13 @@ def test_clean_pattern(): assert empty_cleaned == { "ipsi": {"I": None, "II": None, "III": None}, - "contra": {"I": None, "II": None, "III": None} + "contra": {"I": None, "II": None, "III": None}, }, "Empty pattern does not get filled correctly." assert one_pos_cleaned == { "ipsi": {"I": None, "II": True, "III": None}, - "contra": {"I": None, "II": None, "III": None} + "contra": {"I": None, "II": None, "III": None}, }, "Pattern with one positive LNL not cleaned properly." assert nums_cleaned == { "ipsi": {"I": True, "II": None, "III": None}, - "contra": {"I": None, "II": None, "III": False} + "contra": {"I": None, "II": None, "III": False}, }, "Number pattern cleaned wrongly." diff --git a/tests/predict/prevalences_test.py b/tests/predict/prevalences_test.py index 4ae9b63..19ac051 100644 --- a/tests/predict/prevalences_test.py +++ b/tests/predict/prevalences_test.py @@ -1,6 +1,5 @@ -""" -Test the functions of the prevalence prediction submodule. -""" +"""Test the functions of the prevalence prediction submodule.""" + import pandas as pd import pytest @@ -12,16 +11,16 @@ def test_get_match_idx(): """Test if the pattern dictionaries & pandas data are compared correctly.""" oneside_pattern = {"I": False, "II": True, "III": None} ignorant_pattern = {"I": None, "II": None, "III": None} - three_patients = pd.DataFrame.from_dict({ - "I": [False, False, True], - "II": [True , True , True], - "III": [True , False, True], - }) + three_patients = pd.DataFrame.from_dict( + { + "I": [False, False, True], + "II": [True, True, True], + "III": [True, False, True], + } + ) lnls = list(oneside_pattern.keys()) - matching_idxs = get_match_idx( - True, oneside_pattern, three_patients, invert=False - ) + matching_idxs = get_match_idx(True, oneside_pattern, three_patients, invert=False) inverted_matching_idxs = get_match_idx( False, oneside_pattern, three_patients, invert=True ) @@ -41,19 +40,16 @@ def test_get_match_idx(): def test_does_midline_ext_match(): - """ - Test the function that returns indices of a `DataFrame` where the midline + """Test the function that returns indices of a `DataFrame` where the midline extension matches. """ - midline_data = pd.DataFrame({ - ("tumor", "1", "extension"): [True, False, None] - }) + midline_data = pd.DataFrame({("tumor", "1", "extension"): [True, False, None]}) keyerr_data = pd.DataFrame({("way", "too", "many", "levels"): [True, False, None]}) midline_match = does_midext_match(midline_data, midext=False) - assert all(midline_match == pd.Series([False, True, False])), ( - "Matching midline extension with correct data does not work." - ) + assert all( + midline_match == pd.Series([False, True, False]) + ), "Matching midline extension with correct data does not work." with pytest.raises(KeyError): _ = does_midext_match(keyerr_data, midext=False) diff --git a/tests/run_doctests.py b/tests/run_doctests.py index dd603fc..3726fad 100644 --- a/tests/run_doctests.py +++ b/tests/run_doctests.py @@ -1,6 +1,5 @@ -""" -Script to run doctests in the modules of `lyscripts`. -""" +"""Script to run doctests in the modules of `lyscripts`.""" + import doctest from lyscripts import utils diff --git a/tests/sample_test.py b/tests/sample_test.py index f526c65..42573f2 100644 --- a/tests/sample_test.py +++ b/tests/sample_test.py @@ -1,11 +1,11 @@ -""" -Test the sampling command with some example patients. +"""Test the sampling command with some example patients. Originally, I wanted to test that the sampling procedure is reproducible, but the `emcee` package does not seem to work with any kind of seed in a reproducible manner. Maybe I am doing something wrong... """ + # pylint: disable=redefined-outer-name import numpy as np import pytest @@ -66,10 +66,14 @@ def test_burnin(sampler: EnsembleSampler): assert sampler.iteration == 100, "Burnin di not run 100 iterations." assert len(burnin_history.steps) == 10, "Burnin history does not have 10 entries." assert np.all( - np.array([ - 0.7147557514447068, 0.9227188150264771, - 0.2629624184410706, 0.6001184115584288, - ]) + np.array( + [ + 0.7147557514447068, + 0.9227188150264771, + 0.2629624184410706, + 0.6001184115584288, + ] + ) == sampler.get_last_sample().coords[0] ), "Not reproducible." @@ -87,6 +91,5 @@ def test_get_starting_state(sampler: EnsembleSampler): check_interval=2, ) assert np.all( - get_starting_state(sampler).coords - == sampler.get_last_sample().coords + get_starting_state(sampler).coords == sampler.get_last_sample().coords ), "State is not the same." diff --git a/tests/utils_test.py b/tests/utils_test.py index fd6c510..c60bcba 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -1,6 +1,5 @@ -""" -Test the core utility functions of the package. -""" +"""Test the core utility functions of the package.""" + import pytest from lymph import models @@ -66,9 +65,13 @@ def test_create_model(params_v1): assert not model.use_central, "Model should not use central" assert model.use_midext_evo, "Model should use midext evolution" assert "early" in model.get_all_distributions(), "Early distribution is missing" - assert not model.get_distribution("early").is_updateable, "Early distribution should not be updateable" + assert not model.get_distribution( + "early" + ).is_updateable, "Early distribution should not be updateable" assert "late" in model.get_all_distributions(), "Late distribution is missing" - assert model.get_distribution("late").is_updateable, "Late distribution should be updateable" + assert model.get_distribution( + "late" + ).is_updateable, "Late distribution should be updateable" assert "CT" in model.get_all_modalities(), "CT modality is missing" assert model.get_modality("CT").spec == 0.76, "CT modality has wrong specificity" assert model.get_modality("CT").sens == 0.81, "CT modality has wrong sensitivity"