Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Add type hints to augmentation/functional.py #709

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
Loading
from

Conversation

itsaphel
Copy link
Contributor

TODOs:

  • A couple of places where mypy/pyright get confused, due to reusing the same variable name with a different type.
  • Maybe a couple of places where the type should be tightened?

@itsaphel itsaphel force-pushed the typing/augmentation_functional branch from 5450390 to 4615b4b Compare March 21, 2025 14:48
Comment on lines 909 to 924
def _make_rotation_matrix(
axis: Literal["x", "y", "z"],
angle: float | int | np.ndarray | list,
degrees: bool = True,
) -> torch.Tensor:
assert axis in ["x", "y", "z"], "axis should be either x, y or z."

if isinstance(angle, (float, int, np.ndarray, list)):
angle = torch.as_tensor(angle)
# TODO: else?

if degrees:
angle = angle * np.pi / 180

# TODO: mypy is confused about types here
device = angle.device
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe need some stronger type assertions here

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean by stronger type assertion?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what angle is supposed to be. Presumably torch.Tensor is also valid, but not sure if there's other possible types too. That said, do we want to accept so many possible types on this function?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm. okay got it your point.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made one suggestion for this case, I did not understand in the beginning.

@itsaphel itsaphel mentioned this pull request Mar 21, 2025
77 tasks
@bruAristimunha
Copy link
Collaborator

Can you fix the tests @itsaphel?

@itsaphel
Copy link
Contributor Author

itsaphel commented Mar 24, 2025

@bruAristimunha: I was confused about _make_rotation_matrix, specifically what type is angle supposed to be? Line 916 suggests float | int | np.ndarray | list are permissible, but there's no assert or else, so are other types also OK?

For the tests, can fix. I think it's unhappy with Python 3.10+ syntax that I used, so I can fall back to typing.Union (etc)

@bruAristimunha
Copy link
Collaborator

Good question! hmmmmm, let me think...

I think it is okay to create some assertions to create the type if it is super necessary. I would recommend building or using some util function that informs the type or something and checks for you if this will make mypy happy. Very similar to the scikit-learn approach.

braindecode/augmentation/functional.py Outdated Show resolved Hide resolved
@PierreGtch
Copy link
Collaborator

For the tests, can fix. I think it's unhappy with Python 3.10+ syntax that I used, so I can fall back to typing.Union (etc)

@itsaphel no need to fall back to the typing.Union syntax; you can simply add from __future__ import annotations to the imports (see forum)!

Copy link

codecov bot commented Mar 27, 2025

Codecov Report

Attention: Patch coverage is 94.44444% with 3 lines in your changes missing coverage. Please review.

Project coverage is 87.29%. Comparing base (dbf1647) to head (d98c397).
Report is 1 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #709      +/-   ##
==========================================
- Coverage   87.31%   87.29%   -0.02%     
==========================================
  Files          78       78              
  Lines        7235     7244       +9     
==========================================
+ Hits         6317     6324       +7     
- Misses        918      920       +2     
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@itsaphel itsaphel marked this pull request as ready for review March 27, 2025 18:58
Copy link
Collaborator

@PierreGtch PierreGtch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @itsaphel for this work! Just a recuring comment regarding the random state

y: torch.Tensor,
phase_noise_magnitude: float,
channel_indep: bool,
random_state: int | np.random.Generator | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
random_state: int | np.random.Generator | None = None,
random_state: int | np.random.RandomState | None = None,

see check_random_state

X: torch.Tensor,
y: torch.Tensor,
p_drop: float,
random_state: int | np.random.Generator | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

@@ -175,7 +196,9 @@ def ft_surrogate(X, y, phase_noise_magnitude, channel_indep, random_state=None):
return transformed_X, y


def _pick_channels_randomly(X, p_pick, random_state):
def _pick_channels_randomly(
X: torch.Tensor, p_pick: float, random_state: int | np.random.Generator | None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

c: int,
n: int,
device: torch.device,
random_state: int | np.random.Generator | None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

c: int,
n: int,
device: torch.device,
random_state: int | np.random.Generator | None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

@@ -222,7 +250,9 @@ def channels_dropout(X, y, p_drop, random_state=None):
return X * mask.unsqueeze(-1), y


def _make_permutation_matrix(X, mask, random_state):
def _make_permutation_matrix(
X: torch.Tensor, mask: torch.Tensor, random_state: int | np.random.Generator | None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

X: torch.Tensor,
y: torch.Tensor,
p_shuffle: float,
random_state: int | np.random.Generator | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

X: torch.Tensor,
y: torch.Tensor,
std: float,
random_state: int | np.random.Generator | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

@@ -160,7 +160,7 @@ class ChannelsDropout(Transform):
----------
probability: float
Float setting the probability of applying the operation.
proba_drop: float | None, optional
p_drop: float | None, optional
Float between 0 and 1 setting the probability of dropping each channel.
Defaults to 0.2.
random_state: int | numpy.random.Generator, optional
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

def _make_rotation_matrix(axis, angle, degrees=True):
def _make_rotation_matrix(
axis: Literal["x", "y", "z"],
angle: float | int | np.ndarray | list,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
angle: float | int | np.ndarray | list,
angle: float | int | np.ndarray | list | torch.Tensor,

@itsaphel
Copy link
Contributor Author

itsaphel commented Mar 27, 2025

Thanks @PierreGtch - I took the randomstate type from the docstring. Should the docstring change here as well?

Looks like np.random.RandomState is the legacy generator? (according to the doc you linked) though I guess if that's what sklearn uses...

@PierreGtch
Copy link
Collaborator

Thanks @PierreGtch - I took the randomstate type from the docstring. Should the docstring change here as well?

Yes, the docstring should be updated for consistency, thanks!

Looks like np.random.RandomState is the legacy generator? (according to the doc you linked) though I guess if that's what sklearn uses...

Indeed, sklearn uses the leagacy RandomState and passing it the new Generator will fail:

from sklearn.utils import check_random_state
import numpy as np
random_state = np.random.default_rng(1) # numpy.random.Generator instance
check_random_state(random_state)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File “~/miniforge3/envs/mne-bids-pipeline/lib/python3.12/site-packages/sklearn/utils/validation.py", line 1518, in check_random_state
    raise ValueError(
ValueError: Generator(PCG64) at 0x13A8166C0 cannot be used to seed a numpy.random.RandomState instance

So we must type accordingly.

@bruAristimunha
Copy link
Collaborator

Hey @itsaphel,

I will solve the model part on the #714. Do you need help here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants
Morty Proxy This is a proxified and sanitized view of the page, visit original site.