-
Notifications
You must be signed in to change notification settings - Fork 211
Add type hints to augmentation/functional.py
#709
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Add type hints to augmentation/functional.py
#709
Conversation
5450390
to
4615b4b
Compare
def _make_rotation_matrix( | ||
axis: Literal["x", "y", "z"], | ||
angle: float | int | np.ndarray | list, | ||
degrees: bool = True, | ||
) -> torch.Tensor: | ||
assert axis in ["x", "y", "z"], "axis should be either x, y or z." | ||
|
||
if isinstance(angle, (float, int, np.ndarray, list)): | ||
angle = torch.as_tensor(angle) | ||
# TODO: else? | ||
|
||
if degrees: | ||
angle = angle * np.pi / 180 | ||
|
||
# TODO: mypy is confused about types here | ||
device = angle.device |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe need some stronger type assertions here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean by stronger type assertion?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure what angle
is supposed to be. Presumably torch.Tensor
is also valid, but not sure if there's other possible types too. That said, do we want to accept so many possible types on this function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmmm. okay got it your point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made one suggestion for this case, I did not understand in the beginning.
Can you fix the tests @itsaphel? |
@bruAristimunha: I was confused about For the tests, can fix. I think it's unhappy with Python 3.10+ syntax that I used, so I can fall back to |
Good question! hmmmmm, let me think... I think it is okay to create some assertions to create the type if it is super necessary. I would recommend building or using some util function that informs the type or something and checks for you if this will make mypy happy. Very similar to the scikit-learn approach. |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #709 +/- ##
==========================================
- Coverage 87.31% 87.29% -0.02%
==========================================
Files 78 78
Lines 7235 7244 +9
==========================================
+ Hits 6317 6324 +7
- Misses 918 920 +2 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @itsaphel for this work! Just a recuring comment regarding the random state
y: torch.Tensor, | ||
phase_noise_magnitude: float, | ||
channel_indep: bool, | ||
random_state: int | np.random.Generator | None = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
random_state: int | np.random.Generator | None = None, | |
random_state: int | np.random.RandomState | None = None, |
X: torch.Tensor, | ||
y: torch.Tensor, | ||
p_drop: float, | ||
random_state: int | np.random.Generator | None = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
@@ -175,7 +196,9 @@ def ft_surrogate(X, y, phase_noise_magnitude, channel_indep, random_state=None): | ||
return transformed_X, y | ||
|
||
|
||
def _pick_channels_randomly(X, p_pick, random_state): | ||
def _pick_channels_randomly( | ||
X: torch.Tensor, p_pick: float, random_state: int | np.random.Generator | None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
c: int, | ||
n: int, | ||
device: torch.device, | ||
random_state: int | np.random.Generator | None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
c: int, | ||
n: int, | ||
device: torch.device, | ||
random_state: int | np.random.Generator | None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
@@ -222,7 +250,9 @@ def channels_dropout(X, y, p_drop, random_state=None): | ||
return X * mask.unsqueeze(-1), y | ||
|
||
|
||
def _make_permutation_matrix(X, mask, random_state): | ||
def _make_permutation_matrix( | ||
X: torch.Tensor, mask: torch.Tensor, random_state: int | np.random.Generator | None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
X: torch.Tensor, | ||
y: torch.Tensor, | ||
p_shuffle: float, | ||
random_state: int | np.random.Generator | None = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
X: torch.Tensor, | ||
y: torch.Tensor, | ||
std: float, | ||
random_state: int | np.random.Generator | None = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
@@ -160,7 +160,7 @@ class ChannelsDropout(Transform): | ||
---------- | ||
probability: float | ||
Float setting the probability of applying the operation. | ||
proba_drop: float | None, optional | ||
p_drop: float | None, optional | ||
Float between 0 and 1 setting the probability of dropping each channel. | ||
Defaults to 0.2. | ||
random_state: int | numpy.random.Generator, optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
def _make_rotation_matrix(axis, angle, degrees=True): | ||
def _make_rotation_matrix( | ||
axis: Literal["x", "y", "z"], | ||
angle: float | int | np.ndarray | list, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
angle: float | int | np.ndarray | list, | |
angle: float | int | np.ndarray | list | torch.Tensor, |
Thanks @PierreGtch - I took the randomstate type from the docstring. Should the docstring change here as well? Looks like |
Yes, the docstring should be updated for consistency, thanks!
Indeed, sklearn uses the leagacy from sklearn.utils import check_random_state
import numpy as np
random_state = np.random.default_rng(1) # numpy.random.Generator instance
check_random_state(random_state)
So we must type accordingly. |
TODOs: