Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 1525b4f

Browse filesBrowse files
committed
fix test
1 parent fcf3155 commit 1525b4f
Copy full SHA for 1525b4f

File tree

Expand file treeCollapse file tree

2 files changed

+13
-2
lines changed
Open diff view settings
Filter options
Expand file treeCollapse file tree

2 files changed

+13
-2
lines changed
Open diff view settings
Collapse file

‎torchmeta/datasets/bach.py‎

Copy file name to clipboardExpand all lines: torchmeta/datasets/bach.py
+4Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@ def __init__(self, root, meta_train=False, meta_val=False, meta_test=False, meta
172172
if download:
173173
self.download(process_features, min_num_samples_per_class)
174174

175+
if min_num_samples_per_class != self.meta_data["min_num_data_per_class"]:
176+
raise ValueError("min_num_samples_per_class given ({0}) does not match existing value"
177+
"({1}).".format(min_num_samples_per_class, self.meta_data["min_num_data_per_class"]))
178+
175179
if not self._check_integrity():
176180
raise RuntimeError('Bach integrity check failed')
177181
self._num_classes = len(self.labels)
Collapse file

‎torchmeta/tests/datasets/test_datasets_helpers_tabular.py‎

Copy file name to clipboardExpand all lines: torchmeta/tests/datasets/test_datasets_helpers_tabular.py
+9-2Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,17 @@
88

99
is_local = (os.getenv('TORCHMETA_DATA_FOLDER') is not None)
1010

11+
list_of_datasets = helpers_tabular.__all__
12+
# `bach' is a bit more tricky to test as it has classes with very little data which are
13+
# dropped during pre-processing according to the parameter `min_num_samples_per_class'.
14+
# Hence the test is only applicable if the downloaded dataset was processed with
15+
# min_num_samples_per_class = shots + test_shots.
16+
list_of_datasets.remove('bach')
17+
1118

1219
@pytest.mark.skipif(not is_local, reason='Requires datasets downloaded locally')
13-
@pytest.mark.parametrize('name', helpers_tabular.__all__)
14-
@pytest.mark.parametrize('shots', [5, 1]) # large number first for `bach' dataset.
20+
@pytest.mark.parametrize('name', list_of_datasets)
21+
@pytest.mark.parametrize('shots', [1, 5])
1522
@pytest.mark.parametrize('split', ['train', 'val', 'test'])
1623
def test_datasets_helpers_tabular(name, shots, split):
1724
function = getattr(helpers_tabular, name)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.