Test if models are torch scriptable #714
Conversation
Almost all the models... I will continue tomorrow.
Almost there @PierreGtch, I will need some help to improve the parser of the chs_info, and then it will be done later.
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #714      +/-   ##
==========================================
- Coverage   87.66%   87.50%   -0.16%
==========================================
  Files          82       83       +1
  Lines        7558     7614      +56
==========================================
+ Hits         6626     6663      +37
- Misses        932      951      +19
```
… into torchscript
Almost here @PierreGtch...
I played a little today, and it looks like it is possible to have TorchScript, i.e., a braindecode model in production. Still, we need to improve the split of the chs_info from MNE into two separate variable spaces: one with the ch_names and a second with a torch tensor variable. A small detail is that we need to implement some validation in the __init__; otherwise, the TorchScript compiler gets very annoyed.

```python
from typing import Any, Dict, List, Optional

import torch
from torch import nn


def extract_ch_names(chs_info: List[Dict[str, Any]]) -> List[str]:
    """Extract the channel-name strings from MNE-style chs_info dicts."""
    channel_names: List[str] = []
    for ch in chs_info:
        ch_name = ch.get("ch_name")
        if isinstance(ch_name, str):
            channel_names.append(ch_name)
        else:
            raise TypeError(f"Expected 'ch_name' to be a string in dict: {ch}")
    return channel_names


class EEGModuleMixin:
    _ch_names: List[str]
    _sfreq: Optional[float]
    _n_outputs: Optional[int]
    _n_chans: Optional[int]
    _n_times: Optional[int]
    _input_window_seconds: Optional[float]

    # Constants can help JIT optimization
    __constants__ = ['_n_outputs', '_n_chans', '_n_times', '_input_window_seconds', '_sfreq']

    def __init__(
        self,
        n_outputs: Optional[int] = None,
        n_chans: Optional[int] = None,
        chs_info: Optional[List[Dict[str, Any]]] = None,  # Input remains List[Dict]
        n_times: Optional[int] = None,
        input_window_seconds: Optional[float] = None,
        sfreq: Optional[float] = None,
    ):
        if n_chans is not None and chs_info is not None and len(chs_info) != n_chans:
            raise ValueError(
                f"{n_chans=} different from len(chs_info)={len(chs_info)}"
            )
        if (
            n_times is not None
            and input_window_seconds is not None
            and sfreq is not None
            # Use tolerance for float comparison if needed, but int conversion is fine here
            and n_times != int(input_window_seconds * sfreq)
        ):
            raise ValueError(
                f"{n_times=} different from {input_window_seconds=} * {sfreq=}"
            )
        self._n_outputs = n_outputs
        self._n_chans = n_chans
        self._n_times = n_times
        self._input_window_seconds = input_window_seconds
        self._sfreq = sfreq
        if self._n_chans is None:
            if chs_info is not None:
                self._n_chans = len(chs_info)
            else:
                raise ValueError(
                    "Either n_chans or chs_info must be provided "
                    "to determine channel count."
                )
        final_ch_names: List[str]
        if chs_info is not None:
            final_ch_names = extract_ch_names(chs_info)
            if len(final_ch_names) != self._n_chans:
                raise ValueError(f"Length of chs_info ({len(final_ch_names)}) "
                                 f"does not match n_chans ({self._n_chans}).")
        else:
            final_ch_names = [f"{i}" for i in range(self._n_chans)]
        # The class-level annotation `_ch_names: List[str]` already tells the
        # TorchScript compiler the type; a plain assignment also keeps eager-mode
        # access working (torch.jit.Attribute is meant for ScriptModule subclasses).
        self._ch_names = final_ch_names
        super().__init__()

    @property
    def n_outputs(self) -> int:
        if self._n_outputs is None:
            raise ValueError("n_outputs not specified.")
        return self._n_outputs

    @property
    def n_chans(self) -> int:
        if self._n_chans is None:
            # This should theoretically not happen due to __init__ logic
            raise RuntimeError("Internal error: _n_chans was not set.")
        return self._n_chans

    @property
    def ch_names(self) -> List[str]:
        return self._ch_names

    @property
    def n_times(self) -> int:
        if self._n_times is None:
            if self._input_window_seconds is not None and self._sfreq is not None:
                return int(self._input_window_seconds * self._sfreq)
            else:
                raise ValueError("n_times could not be inferred.")
        return self._n_times

    @property
    def input_window_seconds(self) -> float:
        if self._input_window_seconds is None:
            if self._n_times is not None and self._sfreq is not None:
                if self._sfreq == 0:
                    raise ValueError("sfreq is zero.")
                return self._n_times / self._sfreq
            else:
                raise ValueError("input_window_seconds could not be inferred.")
        return self._input_window_seconds

    @property
    def sfreq(self) -> float:
        if self._sfreq is None:
            if self._input_window_seconds is not None and self._n_times is not None:
                if self._input_window_seconds == 0:
                    raise ValueError("input_window_seconds is zero.")
                return self._n_times / self._input_window_seconds
            else:
                raise ValueError("sfreq could not be inferred.")
        return self._sfreq


class DummyModel(EEGModuleMixin, nn.Module):
    def __init__(self,
                 n_chans: Optional[int] = None,
                 n_outputs: Optional[int] = None,
                 n_times: Optional[int] = None,
                 chs_info: Optional[List[Dict[str, Any]]] = None,
                 input_window_seconds: Optional[float] = None,
                 sfreq: Optional[float] = None,
                 ):
        EEGModuleMixin.__init__(
            self,
            n_chans=n_chans,
            n_outputs=n_outputs,
            n_times=n_times,
            chs_info=chs_info,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        self.dummy_layer = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.dummy_layer(x)
        return x


# Build a toy model and compare eager vs. scripted outputs.
chs_info_data = [{"ch_name": f"ch{i}"} for i in range(10)]
N_CHANS = 10
N_TIMES = 3000  # 30 seconds * 100 Hz

model = DummyModel(
    n_chans=N_CHANS,
    n_outputs=4,
    n_times=N_TIMES,
    chs_info=chs_info_data,
    input_window_seconds=30.0,
    sfreq=100.0,
)

x = torch.randn(1, N_CHANS, N_TIMES)
output = model(x)
print(f"Output shape: {output.shape}")

print("\nAttempting to script the model...")
scripted_model = torch.jit.script(model)
print("Model scripted successfully!")

# Save and load check
torch.jit.save(scripted_model, "dummy_model_scripted.pt")
loaded_scripted_model = torch.jit.load("dummy_model_scripted.pt")
print("\nLoaded scripted model and tested.")
output_loaded = loaded_scripted_model(x)
assert torch.equal(output, output_loaded)
```
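If anyone wants to check what the compiler actually produced, the scripted module exposes the generated TorchScript source (just a quick inspection, nothing braindecode-specific):

```python
# Print the TorchScript source compiled for forward().
print(scripted_model.code)
```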
But it is reeeeally fun to play with this....
@bruAristimunha The sample you shared is changing the current behaviour. At the moment we can have models that require neither `n_chans` nor `chs_info`. Let's discuss this PR on Tuesday!
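To make the difference concrete (an illustration, not from the thread): with the sample above, constructing a model without any channel information fails eagerly in `__init__`, whereas today that error only surfaces when the attribute is actually needed. A minimal sketch using the `DummyModel` from the previous comment:

```python
# With the posted sample, neither n_chans nor chs_info is given,
# so __init__ raises immediately instead of deferring the error.
try:
    DummyModel(n_outputs=4, n_times=3000, sfreq=100.0)
except ValueError as err:
    print(f"Raises at construction time: {err}")
```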
Just a toy example, I don't plan on changing the behaviour of chs_info. But we need to break chs_info internally into two things: a list of numeric variables and another list of strings. TorchScript doesn't handle variables with very complex types like `info["chs"]`. But super happy to chat tomorrow =)
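A minimal sketch of what that internal split could look like, assuming MNE-style dicts with a `'ch_name'` string and the numeric 12-element `'loc'` array (the helper name `split_chs_info` is hypothetical):

```python
from typing import Any, Dict, List, Tuple

import torch


def split_chs_info(chs_info: List[Dict[str, Any]]) -> Tuple[List[str], torch.Tensor]:
    """Split MNE-style chs_info into TorchScript-friendly pieces:
    a plain list of channel-name strings and one numeric tensor."""
    names: List[str] = [str(ch["ch_name"]) for ch in chs_info]
    # Stack the per-channel numeric fields (here: the 12-element 'loc'
    # array) into a single float tensor that TorchScript can type.
    locs = torch.stack(
        [torch.as_tensor(ch["loc"], dtype=torch.float32) for ch in chs_info]
    )
    return names, locs


# Hypothetical usage with two fake channels:
chs = [
    {"ch_name": "C3", "loc": [0.0] * 12},
    {"ch_name": "C4", "loc": [1.0] * 12},
]
ch_names, ch_locs = split_chs_info(chs)
print(ch_names, ch_locs.shape)  # ['C3', 'C4'] torch.Size([2, 12])
```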
I will start again here...
Cf. #713