Test if models are torch scriptable #714

Status: Closed. Wants to merge 49 commits.

Commits (49)
96e8fca  Test if models are torch scriptable (PierreGtch, Apr 3, 2025)
a6c6fab  testing things (bruAristimunha, Apr 18, 2025)
7d06958  playing with syncnet (bruAristimunha, Apr 18, 2025)
40325de  updating (bruAristimunha, Apr 18, 2025)
431773a  more tests (bruAristimunha, Apr 18, 2025)
c2bd486  sqrt native and fixing the first model (bruAristimunha, Apr 18, 2025)
a51efc0  ATCNet is torch_script compatible! (bruAristimunha, Apr 18, 2025)
59fee80  changing the module (bruAristimunha, Apr 18, 2025)
45faf53  changing the name (bruAristimunha, Apr 18, 2025)
deae1b4  atc is torch script compatibly (bruAristimunha, Apr 18, 2025)
c3a17e9  fixing eegsimpleconv (bruAristimunha, Apr 18, 2025)
4209663  stable models working with torch script (bruAristimunha, Apr 18, 2025)
57e97e6  making more models work (bruAristimunha, Apr 18, 2025)
e37ed4b  fixing (bruAristimunha, Apr 18, 2025)
9083d41  small adjust (bruAristimunha, Apr 18, 2025)
1c1b284  more fix (bruAristimunha, Apr 18, 2025)
3c2fe57  more one model... (bruAristimunha, Apr 18, 2025)
342628f  Merge branch 'master' into torchscript (bruAristimunha, Apr 18, 2025)
d393031  final adjustments (bruAristimunha, Apr 18, 2025)
b37edf1  Merge branch 'torchscript' of https://github.com/PierreGtch/braindeco… (bruAristimunha, Apr 18, 2025)
788312d  Apply suggestions from code review (bruAristimunha, Apr 18, 2025)
f2a68c3  resolving eegnet (bruAristimunha, Apr 18, 2025)
c697426  Merge branch 'master' into torchscript (bruAristimunha, Apr 18, 2025)
bb3d412  more wip (bruAristimunha, Apr 19, 2025)
272a66d  dont usleep! (bruAristimunha, Apr 19, 2025)
fa3918c  wip (bruAristimunha, Apr 19, 2025)
d74a6dd  replace to already implement module (bruAristimunha, Apr 19, 2025)
643c42e  typing (bruAristimunha, Apr 19, 2025)
9af6623  bye clone function (bruAristimunha, Apr 19, 2025)
8b59034  removing lambda (bruAristimunha, Apr 19, 2025)
ccf8dab  done eldele (bruAristimunha, Apr 19, 2025)
ed34c2b  finish sleep models (bruAristimunha, Apr 19, 2025)
23e6276  small adjustment (bruAristimunha, Apr 19, 2025)
a2a5e61  done with the ifnet (bruAristimunha, Apr 19, 2025)
33372c4  final attempt (bruAristimunha, Apr 19, 2025)
3278460  reverting eegminer (bruAristimunha, Apr 19, 2025)
3b44bda  from future back (bruAristimunha, Apr 19, 2025)
166a7c6  too complicated for now (bruAristimunha, Apr 19, 2025)
b421ac6  updating the sys (bruAristimunha, Apr 20, 2025)
58a14e0  Merge branch 'master' into torchscript (bruAristimunha, Apr 20, 2025)
aed627d  updating the whats new file (bruAristimunha, Apr 20, 2025)
8ac5ecd  Merge branch 'torchscript' of https://github.com/PierreGtch/braindeco… (bruAristimunha, Apr 20, 2025)
f76e437  skip if windows (bruAristimunha, Apr 20, 2025)
13090a4  removing eegnetv4 (bruAristimunha, Apr 20, 2025)
fd6e706  tensordict (bruAristimunha, Apr 20, 2025)
b9c2196  improving the parser (bruAristimunha, Apr 20, 2025)
7aeffa0  updating base (bruAristimunha, Apr 20, 2025)
503bb4a  last modification to discuss later (bruAristimunha, Apr 21, 2025)
a764362  updating typing (bruAristimunha, Apr 22, 2025)
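
Taken together, the commits drive toward one invariant: each stable model should compile under torch.jit.script without modification. A check in that spirit looks like the following (a minimal sketch, not the PR's actual test suite; the model choice and constructor arguments are illustrative):

import torch

from braindecode.models import ATCNet

# Build a model with concrete dimensions, then ask TorchScript to compile it.
model = ATCNet(n_chans=22, n_outputs=4, n_times=1125, sfreq=250)
model.eval()

scripted = torch.jit.script(model)  # raises if the model is not scriptable

# The scripted and eager forward passes should agree on the same input.
x = torch.randn(2, 22, 1125)
with torch.no_grad():
    assert torch.allclose(model(x), scripted(x), atol=1e-5)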

braindecode/models/__init__.py (12 changes: 6 additions, 6 deletions)

@@ -38,12 +38,12 @@
 from .ctnet import CTNet
 from .sinc_shallow import SincShallowNet
 from .sccnet import SCCNet
-from .signal_jepa import (
-    SignalJEPA_Contextual,
-    SignalJEPA_PostLocal,
-    SignalJEPA_PreLocal,
-    SignalJEPA,
-)
+from .signal_jepa import (  # type: ignore
+    SignalJEPA_Contextual,  # type: ignore
+    SignalJEPA_PostLocal,  # type: ignore
+    SignalJEPA_PreLocal,  # type: ignore
+    SignalJEPA,  # type: ignore
+)  # type: ignore
 from .fbcnet import FBCNet
 from .fbmsnet import FBMSNet
 from .fblightconvnet import FBLightConvNet

braindecode/models/atcnet.py (33 changes: 19 additions, 14 deletions)

@@ -1,7 +1,9 @@
 # Authors: Cedric Rommel <cedric.rommel@inria.fr>
 #
 # License: BSD (3-clause)
-import numpy as np
+import math
+from typing import List
 
+import torch
 from einops.layers.torch import Rearrange
 from torch import nn
@@ -247,39 +249,42 @@ def forward(self, X):
         # Dimension: (batch_size, F2, Tc)
 
         # ----- Sliding window -----
-        sw_concat = []  # to store sliding window outputs
-        for w in range(self.n_windows):
-            conv_feat_w = conv_feat[..., w : w + self.Tw]
+        sw_concat: List[torch.Tensor] = []  # to store sliding window outputs
+        # for w in range(self.n_windows):
+        for idx, (attention, tcn_module, final_layer) in enumerate(
+            zip(self.attention_blocks, self.temporal_conv_nets, self.final_layer)
+        ):
+            conv_feat_w = conv_feat[..., idx : idx + self.Tw]
             # Dimension: (batch_size, F2, Tw)
 
             # ----- Attention block -----
-            att_feat = self.attention_blocks[w](conv_feat_w)
+            att_feat = attention(conv_feat_w)
             # Dimension: (batch_size, F2, Tw)
 
             # ----- Temporal convolutional network (TCN) -----
-            tcn_feat = self.temporal_conv_nets[w](att_feat)[..., -1]
+            tcn_feat = tcn_module(att_feat)[..., -1]
             # Dimension: (batch_size, F2)
 
             # Outputs of sliding window can be either averaged after being
             # mapped by dense layer or concatenated then mapped by a dense
             # layer
             if not self.concat:
-                tcn_feat = self.final_layer[w](tcn_feat)
+                tcn_feat = final_layer(tcn_feat)
 
             sw_concat.append(tcn_feat)
 
         # ----- Aggregation and prediction -----
         if self.concat:
-            sw_concat = torch.cat(sw_concat, dim=1)
-            sw_concat = self.final_layer[0](sw_concat)
+            sw_concat_agg = torch.cat(sw_concat, dim=1)
+            sw_concat_agg = self.final_layer[0](sw_concat_agg)
         else:
             if len(sw_concat) > 1:  # more than one window
-                sw_concat = torch.stack(sw_concat, dim=0)
-                sw_concat = torch.mean(sw_concat, dim=0)
+                sw_concat_agg = torch.stack(sw_concat, dim=0)
+                sw_concat_agg = torch.mean(sw_concat_agg, dim=0)
             else:  # one window (# windows = 1)
-                sw_concat = sw_concat[0]
+                sw_concat_agg = sw_concat[0]
 
-        return self.out_fun(sw_concat)
+        return self.out_fun(sw_concat_agg)
 
 
 class _ConvBlock(nn.Module):
@@ -629,7 +634,7 @@ def forward(
         # Attention weights of size (num_heads * batch_size, n, m):
         # measures how similar each pair of Q and K is.
         W = torch.softmax(
-            Q_.bmm(K_.transpose(-2, -1)) / np.sqrt(self.head_dim),
+            Q_.bmm(K_.transpose(-2, -1)) / math.sqrt(self.head_dim),
             -1,  # (B', D', S)
         )  # (B', N, M)
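
The heart of this diff is a TorchScript workaround: scripted code cannot index an nn.ModuleList with a runtime loop variable (the old self.attention_blocks[w]), but it can iterate the list directly, hence the zip over the three parallel ModuleLists. The np.sqrt to math.sqrt swap answers the same constraint, since TorchScript compiles math builtins but not NumPy calls. A stripped-down illustration of the iteration pattern (a hypothetical module, not code from this diff):

import torch
from torch import nn
from typing import List

class Windows(nn.Module):
    def __init__(self, n_windows: int, dim: int):
        super().__init__()
        self.heads = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_windows))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outs: List[torch.Tensor] = []  # annotated, as in the ATCNet change above
        # Indexing self.heads[w] with a loop variable fails under
        # torch.jit.script; iterating the ModuleList compiles fine.
        for idx, head in enumerate(self.heads):
            outs.append(head(x[..., idx]))
        return torch.stack(outs, dim=0).mean(dim=0)

scripted = torch.jit.script(Windows(3, 8))
print(scripted(torch.randn(2, 8, 3)).shape)  # torch.Size([2, 8])

The rename from sw_concat to sw_concat_agg after the loop also serves the compiler: TorchScript requires a variable to keep one static type, and the original code rebound a List[Tensor] to a Tensor.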

braindecode/models/base.py (90 changes: 55 additions, 35 deletions)

@@ -2,17 +2,19 @@
 # Maciej Sliwowski
 #
 # License: BSD-3
-
+from __future__ import annotations
 import warnings
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Any
 
 from collections import OrderedDict
 
 import numpy as np
 import torch
+from tensordict import TensorDict
 
 from docstring_inheritance import NumpyDocstringInheritanceInitMeta
 from torchinfo import ModelStatistics, summary
+from braindecode.util import chs_to_torch
 
 
 def deprecated_args(obj, *old_new_args):
@@ -73,18 +75,37 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
     there will be an attempt to infer them from the other parameters.
     """
 
+    # _chs_info: List[str]  # ict[str, np.array | str | int | float]]  # Stores List[str] after conversion
+    _n_outputs: int
+    _n_chans: int
+    _n_times: int
+    # _sfreq: float
+    # _input_window_seconds: float
+    _add_log_softmax: bool
+
+    __constants__ = [
+        "_n_outputs",
+        "_n_chans",
+        "_n_times",
+        "_input_window_seconds",
+        # "_sfreq",
+        "_add_log_softmax",
+    ]
+
     def __init__(
         self,
         n_outputs: Optional[int] = None,
         n_chans: Optional[int] = None,
-        chs_info: Optional[List[Dict]] = None,
+        chs_info=None,
         n_times: Optional[int] = None,
         input_window_seconds: Optional[float] = None,
-        sfreq: Optional[float] = None,
+        sfreq: float = None,  # type: ignore
        add_log_softmax: Optional[bool] = False,
    ):
        if n_chans is not None and chs_info is not None and len(chs_info) != n_chans:
-            raise ValueError(f"{n_chans=} different from {chs_info=} length")
+            raise ValueError(
+                f"{n_chans=} different from {chs_info}={len(chs_info)} length"
+            )
         if (
             n_times is not None
             and input_window_seconds is not None
@@ -94,23 +115,31 @@ def __init__(
             raise ValueError(
                 f"{n_times=} different from {input_window_seconds=} * {sfreq=}"
             )
-        self._n_outputs = n_outputs
-        self._n_chans = n_chans
-        self._chs_info = chs_info
-        self._n_times = n_times
-        self._input_window_seconds = input_window_seconds
-        self._sfreq = sfreq
-        self._add_log_softmax = add_log_softmax
+        # if chs_info is not None:
+
+        if input_window_seconds is not None:
+            self._input_window_seconds = float(input_window_seconds)  # type: ignore[assignment]
+        else:
+            self._input_window_seconds = input_window_seconds  # type: ignore[assignment]
+
+        self._chs_info = torch.jit.Attribute("[]", List[str])
+        # chs_info  # torch.jit.Attribute(chs_to_torch(), List[str])  # type: ignore[assignment]
+
+        self._n_outputs = n_outputs  # type: ignore[assignment]
+        self._n_chans = n_chans  # type: ignore[assignment]
+        self._n_times = n_times  # type: ignore[assignment]
+        self._sfreq = float(sfreq)  # type: ignore[assignment]
+        self._add_log_softmax = add_log_softmax  # type: ignore[assignment]
         super().__init__()
 
     @property
-    def n_outputs(self):
+    def n_outputs(self) -> int:
         if self._n_outputs is None:
             raise ValueError("n_outputs not specified.")
         return self._n_outputs
 
     @property
-    def n_chans(self):
+    def n_chans(self) -> int:
         if self._n_chans is None and self._chs_info is not None:
             return len(self._chs_info)
         elif self._n_chans is None:
@@ -120,13 +149,13 @@ def n_chans(self):
         return self._n_chans
 
     @property
-    def chs_info(self):
+    def chs_info(self) -> List[str]:
         if self._chs_info is None:
             raise ValueError("chs_info not specified.")
-        return self._chs_info
+        return self._chs_info  # chs_to_torch(
 
     @property
-    def n_times(self):
+    def n_times(self) -> int:
         if (
             self._n_times is None
             and self._input_window_seconds is not None
@@ -141,13 +170,13 @@ def n_times(self):
         return self._n_times
 
     @property
-    def input_window_seconds(self):
+    def input_window_seconds(self) -> float:
         if (
             self._input_window_seconds is None
             and self._n_times is not None
             and self._sfreq is not None
         ):
-            return self._n_times / self._sfreq
+            return float(self._n_times / self._sfreq)
         elif self._input_window_seconds is None:
             raise ValueError(
                 "input_window_seconds could not be inferred. "
@@ -156,13 +185,13 @@ def input_window_seconds(self):
         return self._input_window_seconds
 
     @property
-    def sfreq(self):
+    def sfreq(self) -> float:
         if (
             self._sfreq is None
             and self._input_window_seconds is not None
             and self._n_times is not None
         ):
-            return self._n_times // self._input_window_seconds
+            return float(self._n_times / self._input_window_seconds)
         elif self._sfreq is None:
             raise ValueError(
                 "sfreq could not be inferred. "
@@ -171,7 +200,7 @@ def sfreq(self):
         return self._sfreq
 
     @property
-    def add_log_softmax(self):
+    def add_log_softmax(self) -> Optional[bool]:
         if self._add_log_softmax:
             warnings.warn(
                 "LogSoftmax final layer will be removed! "
@@ -195,11 +224,11 @@ def get_output_shape(self) -> Tuple[int, ...]:
         with torch.inference_mode():
             try:
                 return tuple(
-                    self.forward(
+                    self.forward(  # type: ignore[attr-defined]
                         torch.zeros(
                             self.input_shape,
-                            dtype=next(self.parameters()).dtype,
-                            device=next(self.parameters()).device,
+                            dtype=next(self.parameters()).dtype,  # type: ignore[attr-defined]
+                            device=next(self.parameters()).device,  # type: ignore[attr-defined]
                         )
                     ).shape
                 )
@@ -257,7 +286,7 @@ def to_dense_prediction_model(self, axis: Tuple[int, ...] | int = (2, 3)) -> None:
         assert all([ax in [2, 3] for ax in axis]), "Only 2 and 3 allowed for axis"  # type: ignore[union-attr]
         axis = np.array(axis) - 2
         stride_so_far = np.array([1, 1])
-        for module in self.modules():
+        for module in self.modules():  # type: ignore[attr-defined]
             if hasattr(module, "dilation"):
                 assert module.dilation == 1 or (module.dilation == (1, 1)), (
                     "Dilation should equal 1 before conversion, maybe the model is "
@@ -312,12 +341,3 @@ def get_torchinfo_statistics(
 
     def __str__(self) -> str:
         return str(self.get_torchinfo_statistics())
-
-    def forward(self, *args, **kwargs):
-        return super().forward(*args, **kwargs)
-
-    def parameters(self):
-        return super().parameters()
-
-    def modules(self):
-        return super().modules()
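
Most of the base.py churn gives every attribute a single static type. TorchScript infers attribute types from the values assigned in __init__, so an attribute that may start as None needs a class-level annotation (or a torch.jit.Attribute), and names listed in __constants__ are baked into the compiled graph as constants. A reduced sketch of the mechanism (a hypothetical module, not braindecode code):

import torch
from torch import nn
from typing import Optional

class Typed(nn.Module):
    # The annotation keeps the attribute Optional[float]; without it,
    # TorchScript infers NoneType from __init__ and scripting fails
    # as soon as a float is assigned or used.
    scale: Optional[float]
    __constants__ = ["n_keep"]  # treated as a compile-time constant

    def __init__(self, n_keep: int):
        super().__init__()
        self.scale = None
        self.n_keep = n_keep

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x[..., : self.n_keep]
        if self.scale is not None:  # None-check refines Optional[float] to float
            out = out * self.scale
        return out

scripted = torch.jit.script(Typed(3))
print(scripted(torch.randn(2, 8)).shape)  # torch.Size([2, 3])

The deleted forward/parameters/modules methods at the bottom were pure pass-through delegation; the new # type: ignore[attr-defined] markers compensate for the mixin no longer declaring them.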

braindecode/models/biot.py (6 changes: 4 additions, 2 deletions)

@@ -241,7 +241,9 @@ def stft(self, sample):
         )
         return torch.abs(spectral)
 
-    def forward(self, x, n_channel_offset=0, perturb=False):
+    def forward(
+        self, x: torch.Tensor, n_channel_offset: int = 0, perturb: bool = False
+    ) -> torch.Tensor:
         """
         Forward pass of the BIOT encoder.
 
@@ -478,6 +480,6 @@ def forward(self, x):
         x = self.final_layer(emb)
 
         if self.return_feature:
-            return x, emb
+            return emb
         else:
             return x
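
Both biot.py changes are type-driven. TorchScript treats every unannotated forward argument as a Tensor, so n_channel_offset and perturb must be annotated as int and bool to be used as such, and a scripted method must have one consistent return type, which the old branch pair (return x, emb versus return x) violated by mixing a tuple with a tensor. A small sketch of the annotation rule (a hypothetical module):

import torch
from torch import nn

class Shift(nn.Module):
    # Without annotations, offset and flip would be compiled as Tensors;
    # int and bool make the arithmetic and branching below legal.
    def forward(self, x: torch.Tensor, offset: int = 0, flip: bool = False) -> torch.Tensor:
        if flip:
            x = -x
        return x.roll(offset, dims=-1)

scripted = torch.jit.script(Shift())
print(scripted(torch.randn(2, 4), 1, True).shape)  # torch.Size([2, 4])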

braindecode/models/contrawr.py (20 changes: 10 additions, 10 deletions)

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, List
+from typing import Any, List, Optional
 
 import torch
 import torch.nn as nn
@@ -162,18 +162,18 @@ class ContraWR(EEGModuleMixin, nn.Module):
 
     def __init__(
         self,
-        n_chans: int | None = None,
-        n_outputs: int | None = None,
-        sfreq: int | None = None,
+        n_chans=None,
+        n_outputs=None,
+        sfreq=None,
         emb_size: int = 256,
         res_channels: list[int] = [32, 64, 128],
-        steps=20,
+        steps: int = 20,
         activation: nn.Module = nn.ELU,
         drop_prob: float = 0.5,
         # Another way to pass the EEG parameters
-        chs_info: list[dict[Any, Any]] | None = None,
-        n_times: int | None = None,
-        input_window_seconds: float | None = None,
+        chs_info=None,
+        n_times=None,
+        input_window_seconds=None,
     ):
         super().__init__(
             n_outputs=n_outputs,
@@ -258,7 +258,7 @@ def forward(self, X):
         """
         X = self.torch_stft(X)
 
-        for conv in self.convs[:-1]:
-            X = conv(X)
+        for _, conv in enumerate(self.convs[:-1]):
+            X = conv.forward(X)
         emb = self.convs[-1](X).squeeze(-1).squeeze(-1)
         return self.final_layer(emb)
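
Once a model scripts cleanly, it can be serialized and reloaded with no Python class definition on the loading side, which is the practical payoff this PR is chasing. A usage sketch (the file name is illustrative, and the exact constructor arguments ContraWR needs may differ):

import torch
from braindecode.models import ContraWR

model = ContraWR(n_chans=22, n_outputs=4, sfreq=200.0, n_times=600).eval()
scripted = torch.jit.script(model)
torch.jit.save(scripted, "contrawr_scripted.pt")

# In a fresh process, no braindecode import is needed at load time.
restored = torch.jit.load("contrawr_scripted.pt")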

braindecode/models/ctnet.py (10 changes: 5 additions, 5 deletions)

@@ -294,7 +294,7 @@ def __init__(self, module: nn.Module, emb_size: int, drop_p: float):
         self.drop = nn.Dropout(drop_p)
         self.layernorm = nn.LayerNorm(emb_size)
 
-    def forward(self, x: Tensor, **kwargs) -> Tensor:
+    def forward(self, x: Tensor) -> Tensor:
         """
         Forward pass with residual connection.
 
@@ -310,7 +310,7 @@ def forward(self, x: Tensor, **kwargs) -> Tensor:
         Tensor
             Output tensor after applying residual connection.
         """
-        res = self.module(x, **kwargs)
+        res = self.module(x)
         out = self.layernorm(self.drop(res) + x)
         return out
 
@@ -346,7 +346,7 @@ def __init__(
             drop_prob,
         )
 
-    def forward(self, x: Tensor, **kwargs) -> Tensor:
+    def forward(self, x: Tensor) -> Tensor:
         """
         Forward pass of the transformer encoder block.
 
@@ -362,8 +362,8 @@ def forward(self, x: Tensor, **kwargs) -> Tensor:
         Tensor
             Output tensor after transformer encoder block.
         """
-        x = self.attention(x, **kwargs)
-        x = self.feed_forward(x, **kwargs)
+        x = self.attention(x)
+        x = self.feed_forward(x)
         return x
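
Every ctnet.py edit removes **kwargs, which TorchScript cannot compile: a scripted method needs a fixed, fully typed signature, so variadic keyword arguments have to go. A minimal before-and-after sketch (a hypothetical module):

import torch
from torch import nn

class Residual(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.inner = nn.Linear(dim, dim)

    # def forward(self, x, **kwargs):   # rejected by torch.jit.script
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.inner(x)

scripted = torch.jit.script(Residual(8))
print(scripted(torch.randn(2, 8)).shape)  # torch.Size([2, 8])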