Per-parameter Learning Rate (PPLR) Implementation, Tensor Decomposition Models, and Tomography Updates #213

Open

cedriclim1 wants to merge 30 commits into electronmicroscopy/quantem:dev from electronmicroscopy/quantem:optmixin_pplr

Conversation

@cedriclim1
Collaborator

@cedriclim1 cedriclim1 commented Apr 24, 2026

What does this PR do?

First-pass implementation of K-Planes (https://brentyi.github.io/tilted/, https://sarafridov.github.io/K-Planes/). These tensor decomposition models require optimizing multiple parameter groups, which is not handled by OptimizerMixin.

The solution @arthurmccray and I came up with is to set up a per-parameter learning rate (PPLR) abstract class, which is subclassed by models that have multiple parameter groups to be optimized. This is implemented in core/ml/models/model_base.py. The most important part of this class is the get_params method, which is used for parsing the trainable parameters (see ObjectTensorDecomp's get_optimization_parameters).

from abc import ABC, abstractmethod

import torch.nn as nn


class PPLR(ABC):
    """
    Abstract base class for models that require multi-scale parameter optimization.
    """

    @abstractmethod
    def get_params(self) -> dict[str, list[nn.Parameter]]:
        """
        Return a dictionary of parameters grouped by key.

        For example, if your nn.Module has multiple optimizable parameter groups,
        you can return a dictionary with the keys "grids" and "sigma_net"
        (KPlanes example).
        """
        ...
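As an illustration of how a subclass might satisfy this interface, here is a minimal, torch-free sketch; ToyPlanarModel and its attribute names are hypothetical stand-ins, not the actual KPlanes code:

```python
from abc import ABC, abstractmethod


class PPLR(ABC):
    """Repeat of the ABC above so this sketch is self-contained."""

    @abstractmethod
    def get_params(self) -> dict[str, list]:
        ...


class ToyPlanarModel(PPLR):
    """Hypothetical model with two parameter groups, standing in for KPlanes.
    Plain strings stand in for nn.Parameter objects."""

    def __init__(self):
        self.grids = ["grid_xy", "grid_xz", "grid_yz"]
        self.sigma_net = ["w0", "w1"]

    def get_params(self) -> dict[str, list]:
        # Group parameters by key so each group can get its own learning rate.
        return {"grids": self.grids, "sigma_net": self.sigma_net}


model = ToyPlanarModel()
print(sorted(model.get_params()))  # → ['grids', 'sigma_net']
```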

The KPlanes and TILTED implementations are in core/ml/models/kplanes.py as KPlanes, KPlanesTILTED, and CPTilted. Note that CPTilted is used for the two-phase warmup the paper recommends: pretraining the SO3 rotations of the planes in a lower-dimensional representation space. so3params.py also has both quaternion and R9+SVD implementations, which can be swapped via a parameter of KPlanesTILTED.
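The R9+SVD parameterization stores each rotation as an unconstrained 3x3 matrix and projects it onto SO(3) via SVD. A minimal NumPy sketch of that projection (the actual so3params.py code works on torch tensors; the function name here is illustrative):

```python
import numpy as np


def project_to_so3(M: np.ndarray) -> np.ndarray:
    """Project an unconstrained 3x3 matrix M to the nearest rotation:
    SVD+(M) = U diag(1, 1, det(U V^T)) V^T."""
    U, _, Vt = np.linalg.svd(M)
    # Flip the last singular direction if needed so det(R) = +1 (no reflections).
    d = np.sign(np.linalg.det(U @ Vt))
    R = U @ np.diag([1.0, 1.0, d]) @ Vt
    return R


M = np.random.default_rng(0).normal(size=(3, 3))
R = project_to_so3(M)
print(np.allclose(R @ R.T, np.eye(3)))    # orthogonal
print(np.isclose(np.linalg.det(R), 1.0))  # proper rotation, det = +1
```

The same formula appears in the so3params.py docstring quoted further down in the review.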

In tomography/object_models.py there is now a new object that handles the new tensor decomposition methods, ObjectTensorDecomp, which inherits from ObjectINR; the following overloads are needed:

There are a few optional overrides that I've also included preemptively, just in case I want to add additional soft/hard constraints to these methods: both apply_soft_constraints and apply_hard_constraints are still there.

Key change to OptimizerMixin: I have changed set_optimizer to handle a dictionary of parameters. I think this is better than my original idea of implementing this only in object_models.py, since there you would have to overload that function again. This still works for single-optimizer reconstructions.
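A rough sketch of what handling both cases could look like; this is a pure-Python stand-in with an invented helper name, not the actual set_optimizer signature:

```python
def build_param_groups(params, lr):
    """Normalize either a flat parameter list or a dict of named groups
    into torch.optim-style param groups.

    params: list of parameters, or dict mapping group name -> parameter list
    lr:     single float, or dict mapping group name -> learning rate
    """
    if isinstance(params, dict):
        # PPLR case: one param group (and one learning rate) per key.
        return [
            {"params": plist, "lr": lr[name] if isinstance(lr, dict) else lr}
            for name, plist in params.items()
        ]
    # Single-optimizer case: one group with one learning rate.
    return [{"params": list(params), "lr": lr}]


groups = build_param_groups(
    {"grids": ["g0", "g1"], "sigma_net": ["w0"]},
    {"grids": 1e-2, "sigma_net": 1e-3},
)
print([g["lr"] for g in groups])  # → [0.01, 0.001]
```

The resulting list is exactly the per-parameter-group format that torch.optim optimizers accept, which is why a single code path can serve both PPLR models and the existing single-optimizer workflows.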

There are also small changes to other parts of the Tomography module, mostly contained in object_models.py. Listing out the changes and rationale in the tomography files:

  • tomography.py: I've disabled torch.bfloat16 autocasting, since F.grid_sample's autograd is not implemented for bfloat16 and errors out. I've also put the TILTED tensor decomposition pretraining check here, which looks at the convergence of the SO3 rotations. I don't like this being here, but I'm not quite sure where else to put it.
  • tomography_base.py: Type-hinting change for setting up distributed training.
  • tomography_opt.py: Parsing is already handled in object_models.py.
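One way to phrase the SO3 convergence check mentioned above is to measure how far each rotation in the bank moved between checks using the geodesic angle, and stop pretraining when every rotation is nearly stationary. This is a hypothetical sketch of such a check (function names and the tolerance are illustrative, not taken from tomography.py):

```python
import numpy as np


def rotation_angle(R_prev: np.ndarray, R_curr: np.ndarray) -> float:
    """Geodesic angle (radians) between two rotations: the magnitude of the
    relative rotation R_prev^T R_curr, recovered from its trace."""
    R_rel = R_prev.T @ R_curr
    cos_theta = (np.trace(R_rel) - 1.0) / 2.0
    return float(np.arccos(np.clip(cos_theta, -1.0, 1.0)))


def rotations_converged(prev, curr, tol=1e-3):
    """Declare SO(3) pretraining converged once every rotation in the bank
    has moved by less than `tol` radians since the previous check."""
    return all(rotation_angle(p, c) < tol for p, c in zip(prev, curr))


I = np.eye(3)
print(rotations_converged([I], [I]))  # → True
```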

What the reviewer should do

Attached is a notebook that tests tensor decomposition methods in tomography: 0415_PPLR_testing.ipynb. The dataset is from quantem-tutorials, so just point the notebook at your copy of that dataset. The runtime can be shortened by increasing the batch size, but the learning rates then have to be scaled accordingly. Setting the batch size to 4096 seems to be fine, so feel free to change it to that and scale the learning rates by 2.
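The batch-size/learning-rate coupling described above matches the common linear scaling heuristic (learning rate proportional to batch size). A tiny sketch; the base values here are illustrative and not stated in the PR:

```python
def scale_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Linear scaling rule: scale the learning rate in proportion to the
    change in batch size."""
    return base_lr * new_batch / base_batch


# Doubling the batch size doubles the learning rate.
print(scale_lr(1e-2, 2048, 4096))  # → 0.02
```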

  • Double-check the implementations across the board, especially optimizer_mixin.py and object_models.py.
  • Check that the organization makes sense for the tensor decomposition stuff. I'm actually not very happy with how core/ml is looking. Where do we draw the line on what lives in models (I think everything model-related should live in there tbh: INR, CNN, etc.)?
  • Check that the reconstruction workflows are still working, e.g. ptychography, due to the OptimizerMixin change.

cedriclim1 and others added 21 commits April 15, 2026 17:06
…eeds to be overloaded is set_optimizer for PPLR cases
…to do the matching in set_optimizer instead of parsing in optimizer_params maybe?
… to check: Look at object_models.py and see how the optimizer matching should be handled. It seems like set_optimizers doesn't really do what it's supposed to do.
… probably have to do TV loss computation within the model?
…well. Only things to ask Corneel about is multiscale res since this adds a significant amount of compute. Should I be doing variable num_samples_per_ray?
…t DDP, clean-up KPlanes, fix up object_models.py since it's insanely cluttered now
…ositionModel ABC, make sure to have a property for which kind of tensor decomposition method is being used. SO3Params are moved to a different file, thinking of making a kplanes_utils.py. Starting reorganization of object_models.py to have ObjectINR and ObjectTensorDecomp
… of parameters now that helps with type-setting. The main reason for having model_base.py as is it is right now is if we ever wanted to go do TensoRF or something just to validate
@cedriclim1 cedriclim1 requested a review from arthurmccray April 24, 2026 00:58
@bobleesj
Collaborator

@cedriclim1 the default VS Code setting has been removed, so please pull from upstream/dev

@cedriclim1 cedriclim1 closed this Apr 28, 2026
@cedriclim1 cedriclim1 reopened this Apr 28, 2026
Collaborator

@arthurmccray arthurmccray left a comment


Overall looks good! I made the few necessary changes to ptycho so it works with this, and I'm happy with it. Hopefully we're one step closer to not having to mess with OptimizerMixin ever again...

As to /core/ml organization, yes, cnn.py, cnn_dense.py, inr.py, etc. should all be in /ml/models. I don't think it should be that bad to move them, actually? An automatic refactor should handle everything in quantem fine, and they'll be accessible at the same level of the namespace (we still want to be able to do from quantem.core.ml import CNN2d), so it shouldn't really affect the tutorials either. If you move them (and whatever else you think), we can just re-run all the tutorials and tests to make sure nothing breaks.
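Keeping the old import path working after such a move would just need re-exports in the package's __init__.py, along these lines (a sketch only; the module paths and class names besides CNN2d are my guesses):

```python
# src/quantem/core/ml/__init__.py (sketch, hypothetical module paths)
# Re-export the moved classes so `from quantem.core.ml import CNN2d` keeps working.
from quantem.core.ml.models.cnn import CNN2d

__all__ = ["CNN2d"]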

Comment thread .vscode/settings.json Outdated
"yticks"
],
"basedpyright.analysis.typeCheckingMode": "standard",
"python.REPL.enableREPLSmartSend": false,
Collaborator


what's this for?

Collaborator Author


This is outdated, per Bob's comment.

import torch
from numpy.typing import NDArray

from quantem.tomography.tomography_context import ReconstructionContext
Collaborator


If we're going to have the ReconstructionContext used in our Constraints ABC, then it shouldn't live in tomography_context.

It's definitely not going to be fully backwards compatible, but hopefully things should run without errors, and we'll only have problems with linting?


class PPLR(ABC):
    """
    Abstract base class for models that require multi-scale parameter optimization.
Collaborator


is "multi-scale" the right way of describing this?

Comment on lines +30 to +50
class TensorDecompositionModel(nn.Module, ABC):
"""
Base class for factored tensor-decomposition models.

Subclasses must set ``td_type`` as a normal attribute in ``__init__``.
"""

td_type: str


class PlanarDecompositionModel(TensorDecompositionModel, PPLR):
"""
Planar factored-grid models: K-Planes, K-Planes-TILTED, HexPlane, tri-planes.

Subclasses must set ``grids``, ``tilted``, and ``resolution`` as normal
attributes in ``__init__``.
"""

grids: nn.ParameterList
tilted: bool
resolution: list[int]
Collaborator


what's the motivation for having the TensorDecompositionModel base class? Shouldn't there just be a td_type in PlanarDecompositionModel? Unless there are going to be other, non-planar decomp models...

Comment on lines +113 to +115
SO(3) rotation bank using R9+SVD parameterization.
Each rotation is stored as an unconstrained 3x3 matrix M,
projected to SO(3) via SVD+(M) = U diag(1,1,det(UVt)) Vt.
Collaborator


I would add a link to the paper in the docstring

Comment thread src/quantem/__init__.py
Collaborator


okay, the fact that your formatter is changing this file is a little suspicious; are you using a weird ruff setup or anything?

Collaborator


Please cite the relevant papers for the models.

elif not isinstance(self._model, torch.nn.parallel.DistributedDataParallel):
self.distribute_model(self._model)
self.reconnect_optimizer_to_parameters()
def _get_plane_tv_loss(self) -> torch.Tensor:
Collaborator


Some explanation of what this is would be helpful, and also why you use this rather than the normal INR-style volume TV.
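For context, a plane TV loss in K-Planes-style models typically penalizes differences between neighboring cells of each 2D feature grid, rather than of a sampled 3D volume; operating on the planes directly avoids materializing the volume. A NumPy sketch of that idea (the actual _get_plane_tv_loss implementation may differ):

```python
import numpy as np


def plane_tv_loss(grids) -> float:
    """Mean squared difference between neighboring grid cells, summed over
    all feature planes. Constant planes incur zero penalty; rapidly varying
    planes are penalized, encouraging smooth reconstructions."""
    total = 0.0
    for g in grids:  # each g: (H, W) feature plane
        dh = np.diff(g, axis=0)  # vertical neighbor differences
        dw = np.diff(g, axis=1)  # horizontal neighbor differences
        total += np.mean(dh**2) + np.mean(dw**2)
    return float(total)


print(plane_tv_loss([np.ones((4, 4))]))  # → 0.0
```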

)


def _unwrap(model: nn.Module | nn.parallel.DistributedDataParallel) -> PlanarDecompositionModel:
Collaborator


I'm seeing linter errors for this function and in a couple of other places. Not a huge deal, but wanted to let you know.

Collaborator


Just did a pretty quick skim over this and it looks good. The main thing is that, since this will be new to most folks, include some resources and more explanation for what's going on, i.e. reference the relevant papers and add more code comments or docstrings.
