Test if models are torch scriptable #714


Closed
wants to merge 49 commits into from

Conversation

PierreGtch
Collaborator

Cf. #713

@bruAristimunha
Collaborator

bruAristimunha commented Apr 18, 2025

Labram Failed scripted Labram: 
Class Rearrange does not have an __init__ function defined:
  File "/Users/baristim/Projects/braindecode/braindecode/models/labram.py", line 645
            # as output, which keeps channel information
            # This treats each patch embedding as a feature alongside channels
            x = Rearrange(
                ~~~~~~~~~ <--- HERE
                pattern="(batch nchans) embed npatchs -> batch nchans npatchs embed",
                batch=batch_size,
SignalJEPA_Contextual Failed scripted SignalJEPA_Contextual: not enough values to unpack (expected 2, got 0)
SignalJEPA_PostLocal Failed scripted SignalJEPA_PostLocal: Unknown type annotation: '_ConvFeatureEncoder | None' at 
SignalJEPA_PreLocal Failed scripted SignalJEPA_PreLocal: Unknown type annotation: '_ConvFeatureEncoder | None' at 

BIOT Failed scripted BIOT: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/Users/baristim/miniforge3/envs/braindecode/lib/python3.10/site-packages/linear_attention_transformer/linear_attention_transformer.py", line 437
    def forward(self, x, **kwargs):
                          ~~~~~~~ <--- HERE
        return self.layers(x, **kwargs)
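As a hedged workaround sketch for the BIOT failure (not necessarily the fix adopted in this PR): TorchScript rejects `**kwargs` in `forward`, so one option is to hide the offending layer behind a wrapper with an explicit positional signature. `NoKwargsWrapper` is a hypothetical name, and kwargs-based features of the wrapped layer are simply dropped:

```python
import torch
from torch import nn

class NoKwargsWrapper(nn.Module):
    # Hypothetical wrapper: exposes an explicit positional forward so the
    # module can be scripted; any kwargs the inner layer accepted are lost.
    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

wrapped = torch.jit.script(NoKwargsWrapper(nn.Identity()))
x = torch.randn(2, 3)
y = wrapped(x)
```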

EEGMiner Failed scripted EEGMiner: 
Module 'EEGMiner' has no attribute '_chs_info' (This attribute exists on the Python module, but we failed to convert Python type: 'list' to a TorchScript type. List trace inputs must have elements. Its type was inferred; try adding a type annotation for the attribute.):
  File "/Users/baristim/Projects/braindecode/braindecode/models/base.py", line 147
    @torch.jit.ignore
    def chs_info(self) -> List[Dict[str, torch.Tensor]]:
        if self._chs_info is None:
           ~~~~~~~~~~~~~~ <--- HERE
            raise ValueError("chs_info not specified.")
        return self._chs_info

Almost all the models... I will continue tomorrow
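A sketch of how the Labram failure above could be avoided: instantiating einops' `Rearrange` inside `forward` is not scriptable, but the same permutation can be written with plain tensor ops. The shapes and function name below are illustrative, not the actual braindecode code:

```python
import torch

def rearrange_patches(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    # (batch * nchans, embed, npatchs) -> (batch, nchans, npatchs, embed)
    embed = x.size(1)
    npatchs = x.size(2)
    nchans = x.size(0) // batch_size
    return x.view(batch_size, nchans, embed, npatchs).permute(0, 1, 3, 2)

scripted = torch.jit.script(rearrange_patches)
x = torch.randn(6, 8, 4)  # batch=2, nchans=3, embed=8, npatchs=4
y = scripted(x, 2)
print(y.shape)  # torch.Size([2, 3, 4, 8])
```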

test/unit_tests/models/test_integration.py (outdated review thread, resolved)
@bruAristimunha
Collaborator

Almost there, @PierreGtch. I will need some help to improve the parser of the chs_info, and then it will be done

@bruAristimunha
Collaborator

later


codecov bot commented Apr 19, 2025

Codecov Report

Attention: Patch coverage is 90.64516% with 29 lines in your changes missing coverage. Please review.

Project coverage is 87.50%. Comparing base (65d8aa3) to head (f76e437).

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #714      +/-   ##
==========================================
- Coverage   87.66%   87.50%   -0.16%     
==========================================
  Files          82       83       +1     
  Lines        7558     7614      +56     
==========================================
+ Hits         6626     6663      +37     
- Misses        932      951      +19     

@bruAristimunha
Collaborator

almost here @PierreGtch...

@bruAristimunha
Collaborator

I played a little today, and it looks like it is possible to have TorchScript, aka a braindecode model in production. Still, we need to improve the split of the chs_info from MNE into two separate variables: one with the ch_names and a second with a torch tensor. A small detail is that we need to implement some validation in the init; otherwise, TorchScript gets very annoyed.

from typing import Optional, Dict, List, Any
import torch
from torch import nn

def extract_ch_names(chs_info: List[Dict[str, Any]]) -> List[str]:
    channel_names: List[str] = []
    for ch in chs_info:
        ch_name = ch.get("ch_name")
        if isinstance(ch_name, str):
            channel_names.append(ch_name)
        else:
            raise TypeError(f"Expected 'ch_name' to be a string in dict: {ch}")
    return channel_names

class EEGModuleMixin:

    _ch_names: List[str]
    _sfreq: Optional[float]
    _n_outputs: Optional[int]
    _n_chans: Optional[int]
    _n_times: Optional[int]
    _input_window_seconds: Optional[float]

    # Constants can help JIT optimization
    __constants__ = ['_n_outputs', '_n_chans', '_n_times', '_input_window_seconds', '_sfreq']

    def __init__(
        self,
        n_outputs: Optional[int] = None,
        n_chans: Optional[int] = None,
        chs_info: Optional[List[Dict[str, Any]]] = None, # Input remains List[Dict]
        n_times: Optional[int] = None,
        input_window_seconds: Optional[float] = None,
        sfreq: Optional[float] = None,
    ):
        if n_chans is not None and chs_info is not None and len(chs_info) != n_chans:
            raise ValueError(
                f"{n_chans=} different from len(chs_info)={len(chs_info)}"
            )
        if (
            n_times is not None
            and input_window_seconds is not None
            and sfreq is not None
            # Use tolerance for float comparison if needed, but int conversion is fine here
            and n_times != int(input_window_seconds * sfreq)
        ):
            raise ValueError(
                f"{n_times=} different from {input_window_seconds=} * {sfreq=}"
            )

        self._n_outputs = n_outputs
        self._n_chans = n_chans 
        self._n_times = n_times
        self._input_window_seconds = input_window_seconds
        self._sfreq = sfreq

        if self._n_chans is None:
            if chs_info is not None:
                self._n_chans = len(chs_info)
            else:
                raise ValueError("Either n_chans or chs_info must be provided to determine channel count.")

        final_ch_names: List[str]
        if chs_info is not None:
            final_ch_names = extract_ch_names(chs_info)
            if len(final_ch_names) != self._n_chans:
                 raise ValueError(f"Length of chs_info ({len(final_ch_names)}) "
                                  f"does not match n_chans ({self._n_chans}).")
        else:
            final_ch_names = [f"{i}" for i in range(self._n_chans)] 

        # The class-level annotation `_ch_names: List[str]` already tells
        # TorchScript the type; wrapping in torch.jit.Attribute would leave
        # a namedtuple (not a list) on the instance in eager mode.
        self._ch_names = final_ch_names


        super().__init__()

    @property
    def n_outputs(self) -> int:
        if self._n_outputs is None:
            raise ValueError("n_outputs not specified.")
        return self._n_outputs

    @property
    def n_chans(self) -> int:
        if self._n_chans is None:
             # This should theoretically not happen due to __init__ logic
             raise RuntimeError("Internal error: _n_chans was not set.")
        return self._n_chans

    @property
    def ch_names(self) -> List[str]: # Renamed property
        return self._ch_names

    @property
    def n_times(self) -> int:
        if self._n_times is None:
            if self._input_window_seconds is not None and self._sfreq is not None:
                return int(self._input_window_seconds * self._sfreq)
            else:
                 raise ValueError("n_times could not be inferred.")
        return self._n_times

    @property
    def input_window_seconds(self) -> float:
        if self._input_window_seconds is None:
             if self._n_times is not None and self._sfreq is not None:
                 if self._sfreq == 0: raise ValueError("sfreq is zero.")
                 return self._n_times / self._sfreq
             else:
                 raise ValueError("input_window_seconds could not be inferred.")
        return self._input_window_seconds

    @property
    def sfreq(self) -> float:
        if self._sfreq is None:
            if self._input_window_seconds is not None and self._n_times is not None:
                 if self._input_window_seconds == 0: raise ValueError("input_window_seconds is zero.")
                 return self._n_times / self._input_window_seconds
            else:
                 raise ValueError("sfreq could not be inferred.")
        return self._sfreq


class DummyModel(EEGModuleMixin, nn.Module):
    def __init__(self,
        n_chans: Optional[int]=None,
        n_outputs: Optional[int]=None,
        n_times: Optional[int]=None,
        chs_info: Optional[List[Dict[str, Any]]]=None, 
        input_window_seconds: Optional[float]=None,
        sfreq: Optional[float]=None,
    ):

        EEGModuleMixin.__init__(
            self, 
            n_chans=n_chans,
            n_outputs=n_outputs,
            n_times=n_times,
            chs_info=chs_info,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        self.dummy_layer = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.dummy_layer(x)
        return x

chs_info_data = [{"ch_name": f"ch{i}"} for i in range(10)]
N_CHANS = 10
N_TIMES = 3000 # 30 seconds * 100 Hz

model = DummyModel(
    n_chans=N_CHANS,
    n_outputs=4,
    n_times=N_TIMES,
    chs_info=chs_info_data,
    input_window_seconds=30.0,
    sfreq=100.0,
)
x = torch.randn(1, N_CHANS, N_TIMES)


output = model(x)
print(f"Output shape: {output.shape}")

print("\nAttempting to script the model...")
scripted_model = torch.jit.script(model)
print("Model scripted successfully!")

# Save and load check
torch.jit.save(scripted_model, "dummy_model_scripted.pt")
loaded_scripted_model = torch.jit.load("dummy_model_scripted.pt")
print("\nLoaded scripted model and tested.")
output_loaded = loaded_scripted_model(x)
assert torch.equal(output, output_loaded)

@bruAristimunha
Collaborator

but reeeeally fun to play with this....

@PierreGtch
Collaborator Author

@bruAristimunha The sample you shared changes the current behaviour. At the moment, we can have models that require neither the chs_info nor the n_chans, and I think we should leave this possibility open

Let's discuss this PR on Tuesday!

@bruAristimunha
Collaborator

Just a toy example, I don't plan on changing the behaviour of chs_info. But we need to break chs_info internally into two things: a list of numeric variables, and another list of strings. TorchScript can't handle variables with very complex types like chs_info. But super happy to chat tomorrow =)
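A minimal sketch of that split, assuming MNE-style dicts with `ch_name` and `loc` keys; the helper name and the zero-position fallback are hypothetical, not braindecode API:

```python
from typing import Any, Dict, List, Tuple
import torch

def split_chs_info(chs_info: List[Dict[str, Any]]) -> Tuple[List[str], torch.Tensor]:
    # Separate the mixed-type chs_info into two TorchScript-friendly parts:
    # a list of channel name strings and an (n_chans, 3) float tensor built
    # from the first three entries of each channel's "loc" (x, y, z).
    names: List[str] = []
    locs: List[torch.Tensor] = []
    for ch in chs_info:
        names.append(str(ch["ch_name"]))
        loc = ch.get("loc", [0.0, 0.0, 0.0])  # hypothetical fallback
        locs.append(torch.as_tensor(list(loc)[:3], dtype=torch.float32))
    positions = torch.stack(locs) if locs else torch.zeros(0, 3)
    return names, positions

chs = [{"ch_name": "Cz", "loc": [0.0, 0.0, 0.09]}, {"ch_name": "Pz"}]
names, pos = split_chs_info(chs)
print(names, pos.shape)
```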

@bruAristimunha
Collaborator

I will start again here...
