Commit 9ddfc1a

adding letter dataset
1 parent 5f6f700

6 files changed (+331, -2)

torchmeta/datasets/__init__.py (2 additions, 0 deletions)

@@ -8,6 +8,7 @@
 from torchmeta.datasets.tcga import TCGA
 from torchmeta.datasets.pascal5i import Pascal5i
 from torchmeta.datasets.covertype_task_id_2118 import Covertype
+from torchmeta.datasets.letter_task_id_6 import Letter
 
 from torchmeta.datasets import helpers
 from torchmeta.datasets import helpers_tabular
@@ -25,5 +26,6 @@
     'Pascal5i',
     'helpers',
     'Covertype',
+    'Letter',
     'helpers_tabular'
 ]

torchmeta/datasets/assets/letter_task_id_6/test.json (1 addition, 0 deletions)

@@ -0,0 +1 @@
+["20", "7", "10", "14", "19", "6"]

torchmeta/datasets/assets/letter_task_id_6/train.json (1 addition, 0 deletions)

@@ -0,0 +1 @@
+["8", "16", "0", "24", "11", "9", "13", "1", "23", "5", "2", "12", "15", "3", "4"]

torchmeta/datasets/assets/letter_task_id_6/val.json (1 addition, 0 deletions)

@@ -0,0 +1 @@
+["21", "17", "22", "18", "25"]

torchmeta/datasets/helpers_tabular.py (30 additions, 2 deletions)

@@ -1,11 +1,12 @@
 import warnings
 
-from torchmeta.datasets import Covertype
+from torchmeta.datasets import Covertype, Letter
 from torchmeta.transforms import Categorical, ClassSplitter
 from torchmeta.transforms.tabular_transforms import NumpyToTorch
 
 __all__ = [
-    'covertype'
+    'covertype',
+    'letter'
 ]
 
 
@@ -104,3 +105,30 @@ def covertype(folder: str, shots: int, ways: int, shuffle: bool=True,
                       "enough classes in the split.".format(ways))
     return helper_with_default_tabular(Covertype, folder, shots, ways, shuffle=shuffle,
                                        test_shots=test_shots, seed=seed, defaults=None, **kwargs)
+
+
+def letter(folder: str, shots: int, ways: int, shuffle: bool=True,
+           test_shots: int=None, seed: int=None, **kwargs) -> Letter:
+    """
+    Wrapper that creates a meta-dataset for the Letter Image Recognition dataset.
+
+    Notes
+    -----
+    Letter has 26 classes in total, split into train/val/test : 15/5/6.
+
+    The ClassDataset currently uses benchlib to load the original dataset.
+    It might be better to directly load it from open-ml in the future.
+    https://code.amazon.com/packages/Benchlib/trees/mainline
+
+    See also
+    --------
+    `datasets.Letter` : Meta-dataset for the Letter dataset.
+    """
+    if ways > 5:
+        warnings.warn("The number of ways is {0}, but the default splits train/val/test "
+                      "contain only 15/5/6 classes respectively. Unless you use a custom "
+                      "split or label augmentation, it may be possible that there are not "
+                      "enough classes in the split.".format(ways))
+    return helper_with_default_tabular(Letter, folder, shots, ways, shuffle=shuffle,
+                                       test_shots=test_shots, seed=seed, defaults=None, **kwargs)
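
The new helper mirrors the existing covertype wrapper. A minimal usage sketch, assuming (as with the other torchmeta helpers) that keyword arguments such as meta_train and download are forwarded to the Letter constructor, and that a writable data folder exists:

from torchmeta.datasets.helpers_tabular import letter
from torchmeta.utils.data import BatchMetaDataLoader

# 3-way, 5-shot tasks over the meta-train classes; download on first use
dataset = letter("data", shots=5, ways=3, test_shots=15, meta_train=True, download=True)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, shuffle=True)

for batch in dataloader:
    train_inputs, train_targets = batch["train"]  # (batch, ways * shots, num_features)
    test_inputs, test_targets = batch["test"]     # (batch, ways * test_shots, num_features)
    break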

torchmeta/datasets/letter_task_id_6.py (296 additions, 0 deletions)

@@ -0,0 +1,296 @@
+import numpy as np
+import os
+import json
+import h5py
+from tqdm import tqdm
+
+from torchmeta.utils.data import Dataset, ClassDataset, CombinationMetaDataset
+from torchmeta.datasets.utils import get_asset
+
+
+class Letter(CombinationMetaDataset):
+    """The Letter Image Recognition Dataset."""
+    def __init__(self, root, num_classes_per_task=None, meta_train=False, meta_val=False, meta_test=False,
+                 meta_split=None, transform=None, target_transform=None, dataset_transform=None,
+                 class_augmentations=None, download=False):
+        """
+        Letter Image Recognition Data [1]:
+
+        https://archive.ics.uci.edu/ml/datasets/Letter+Recognition - 01-01-1991
+
+        The objective is to identify each of a large number of black-and-white
+        rectangular pixel displays as one of the 26 capital letters in the English
+        alphabet. The character images were based on 20 different fonts and each
+        letter within these 20 fonts was randomly distorted to produce a file of
+        20,000 unique stimuli. Each stimulus was converted into 16 primitive
+        numerical attributes (statistical moments and edge counts) which were then
+        scaled to fit into a range of integer values from 0 through 15. We
+        typically train on the first 16000 items and then use the resulting model
+        to predict the letter category for the remaining 4000. See the article
+        cited above for more details.
+
+        The dataset is loaded and processed with benchlib. Originally it is from open-ml:
+        https://www.openml.org/d/6
+
+        Parameters
+        ----------
+        root : string
+            Root directory where the dataset folder `letter_task_id_6` exists.
+
+        num_classes_per_task : int
+            Number of classes per task. This corresponds to "N" in "N-way"
+            classification.
+
+        meta_train : bool (default: `False`)
+            Use the meta-train split of the dataset. If set to `True`, then the
+            arguments `meta_val` and `meta_test` must be set to `False`. Exactly one
+            of these three arguments must be set to `True`.
+
+        meta_val : bool (default: `False`)
+            Use the meta-validation split of the dataset. If set to `True`, then the
+            arguments `meta_train` and `meta_test` must be set to `False`. Exactly
+            one of these three arguments must be set to `True`.
+
+        meta_test : bool (default: `False`)
+            Use the meta-test split of the dataset. If set to `True`, then the
+            arguments `meta_train` and `meta_val` must be set to `False`. Exactly
+            one of these three arguments must be set to `True`.
+
+        meta_split : string in {'train', 'val', 'test'}, optional
+            Name of the split to use. This overrides the arguments `meta_train`,
+            `meta_val` and `meta_test` if all three are set to `False`.
+
+        transform : callable, optional
+            A function/transform that takes a numpy array or a pytorch tensor
+            (depending on when the transform is applied), and returns a transformed
+            version.
+
+        target_transform : callable, optional
+            A function/transform that takes a target, and returns a transformed
+            version.
+
+        dataset_transform : callable, optional
+            A function/transform that takes a dataset (i.e. a task), and returns a
+            transformed version of it. E.g. `torchmeta.transforms.ClassSplitter()`.
+
+        class_augmentations : list of callable, optional
+            A list of functions that augment the dataset with new classes. These
+            classes are transformations of existing classes.
+
+        download : bool (default: `False`)
+            If `True`, downloads the original files and processes the dataset in the
+            root directory (under the `letter_task_id_6` folder). If the dataset
+            is already available, this does not download/process the dataset again.
+
+        References
+        ----------
+        [1] P. W. Frey and D. J. Slate. "Letter Recognition Using Holland-style
+            Adaptive Classifiers". Machine Learning 6(2), 1991.
+        """
+        dataset = LetterClassDataset(root,
+                                     meta_train=meta_train,
+                                     meta_val=meta_val,
+                                     meta_test=meta_test,
+                                     meta_split=meta_split,
+                                     transform=transform,
+                                     class_augmentations=class_augmentations,
+                                     download=download)
+        super(Letter, self).__init__(dataset,
+                                     num_classes_per_task,
+                                     target_transform=target_transform,
+                                     dataset_transform=dataset_transform)
+
+
+class LetterClassDataset(ClassDataset):
+
+    benchlib_namespace = "openml_datasets"
+    benchlib_dataset_name = "letter_task_id_6"
+
+    folder = "letter_task_id_6"
+    filename = '{0}_data.hdf5'
+    filename_labels = '{0}_labels.json'
+
+    def __init__(self, root, meta_train=False, meta_val=False, meta_test=False, meta_split=None, transform=None,
+                 class_augmentations=None, download=False):
+        super(LetterClassDataset, self).__init__(meta_train=meta_train, meta_val=meta_val, meta_test=meta_test,
+                                                 meta_split=meta_split, class_augmentations=class_augmentations)
+
+        self.root = os.path.join(os.path.expanduser(root), self.folder)
+        self.transform = transform
+
+        self.split_filename = os.path.join(self.root, self.filename.format(self.meta_split))
+        self.split_filename_labels = os.path.join(self.root, self.filename_labels.format(self.meta_split))
+
+        self._data_file = None
+        self._data = None
+        self._labels = None
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError('Letter integrity check failed')
+        self._num_classes = len(self.labels)
+
+    def __getitem__(self, index):
+        label = self.labels[index % self.num_classes]
+        data = self.data[label]
+        transform = self.get_transform(index, self.transform)
+        target_transform = self.get_target_transform(index)
+
+        return LetterDataset(index, data, label, transform=transform, target_transform=target_transform)
+
+    @property
+    def num_classes(self):
+        return self._num_classes
+
+    @property
+    def data(self):
+        if self._data is None:
+            self._data_file = h5py.File(self.split_filename, 'r')
+            self._data = self._data_file['datasets']
+        return self._data
+
+    @property
+    def labels(self):
+        if self._labels is None:
+            with open(self.split_filename_labels, 'r') as f:
+                self._labels = json.load(f)
+        return self._labels
+
+    def _check_integrity(self):
+        return (os.path.isfile(self.split_filename)
+                and os.path.isfile(self.split_filename_labels))
+
+    def close(self):
+        # close the underlying hdf5 file rather than the group handle
+        if self._data_file is not None:
+            self._data_file.close()
+            self._data_file = None
+            self._data = None
+
+    def download(self):
+
+        if self._check_integrity():
+            return
+
+        from benchlib.datasets.syne_datasets import get_syne_dataset
+        from benchlib.datasets.data_detergent import DataDetergent
+
+        # feature transforms are performed by the DataDetergent
+        d = DataDetergent(get_syne_dataset(namespace=self.benchlib_namespace,
+                                           dataset_name=self.benchlib_dataset_name + "/"),
+                          do_impute_nans=True,
+                          do_normalize_cols=True,
+                          do_remove_const_features=True)
+
+        # stack the features and targets into one big numpy array each, since we want a new split.
+        features = []
+        targets = []
+        for split in ['train', 'val', 'test']:
+            if split == 'train':
+                data = d.get_training_data()
+            elif split == 'val':
+                data = d.get_validation_data()
+            elif split == 'test':
+                data = d.get_test_data()
+            else:
+                raise ValueError(f"split {split} not found.")
+            features.append(data[0])
+            targets.append(data[1])
+        data = None
+        features = np.concatenate(features, axis=0)
+        targets = np.concatenate(targets, axis=0)
+
+        # make sure the dataset folder exists before writing the processed files
+        os.makedirs(self.root, exist_ok=True)
+
+        # for each meta-data split, get the labels, then check which data points belong to the split (via a mask).
+        # then retrieve the features and targets belonging to the split, and create an hdf5 file for these features.
+        for split in ['train', 'val', 'test']:
+            label_set = get_asset(self.folder, '{0}.json'.format(split))
+            label_set_integers = [int(l) for l in label_set]
+
+            is_in_set = [t in label_set_integers for t in targets]
+            features_set = features[is_in_set, :]
+            targets_set = targets[is_in_set]
+            assert targets_set.shape[0] == features_set.shape[0]
+
+            unique_targets_set = np.sort(np.unique(targets_set))
+            if len(label_set_integers) > unique_targets_set.shape[0]:
+                print(f"unique set of labels is smaller ({unique_targets_set.shape[0]}) than the set of labels "
+                      f"given by the assets ({len(label_set_integers)}). Proceeding with the unique set of labels.")
+
+            # write the unique targets with enough data to a json file. this is not necessarily
+            # the same as the label set given by the assets.
+            len_str = int(np.ceil(np.log10(unique_targets_set.shape[0] + 1)))
+            unique_targets_str = [str(i).zfill(len_str) for i in unique_targets_set]
+
+            labels_filename = os.path.join(self.root, self.filename_labels.format(split))
+            with open(labels_filename, 'w') as f:
+                json.dump(unique_targets_str, f)
+
+            # write data (features and class labels)
+            filename = os.path.join(self.root, self.filename.format(split))
+            with h5py.File(filename, 'w') as f:
+                group = f.create_group('datasets')
+
+                for label in tqdm(unique_targets_str, desc=filename):
+                    data_class = features_set[targets_set == int(label), :]
+                    group.create_dataset(label, data=data_class)
+
+
+class LetterDataset(Dataset):
+    def __init__(self, index, data, label, transform=None, target_transform=None):
+        super(LetterDataset, self).__init__(index, transform=transform, target_transform=target_transform)
+        self.data = data
+        self.label = label
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        features = self.data[index, :]
+        target = self.label
+
+        if self.transform is not None:
+            features = self.transform(features)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return features, target
+
+
+def create_asset(root='data', number_of_classes_per_split=None, numpy_seed=42):
+    """This method creates the assets of the letter dataset, i.e. the meta-dataset splits of the
+    original classes. Only run this method in case you want to create new assets. Once created, copy the
+    assets to this directory: torchmeta.datasets.assets.letter_task_id_6. You can also manually change the assets."""
+
+    # number of classes per split: train, val, test
+    if number_of_classes_per_split is None:
+        number_of_classes_per_split = {"train": 15,
+                                       "val": 5,
+                                       "test": 6}
+    num_classes = 0
+    for key in number_of_classes_per_split:
+        num_classes += number_of_classes_per_split[key]
+    assert num_classes == 26
+
+    def make_split(num_classes, number_of_classes_per_split):
+        """get a permutation of the labels and split it according to the number of classes per split"""
+        np.random.seed(numpy_seed)
+
+        perm = np.random.permutation(num_classes)
+        class_splits = {}
+        start = 0
+        for split in ["train", "val", "test"]:
+            num_c = number_of_classes_per_split[split]
+
+            class_splits[split] = [str(i) for i in perm[start:start + num_c]]
+            start += num_c
+        return class_splits
+
+    # Split the classes according to the number of classes per split, and store the splits in the data directory.
+    class_splits = make_split(num_classes, number_of_classes_per_split)
+    print(class_splits)
+    root_path = os.path.join(os.path.expanduser(root), LetterClassDataset.folder)
+    os.makedirs(root_path, exist_ok=True)
+    for split in ["train", "val", "test"]:
+        asset_filename = os.path.join(root_path, "{0}.json".format(split))
+        with open(asset_filename, 'w') as f:
+            json.dump(class_splits[split], f)
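
For completeness, a hedged sketch of using the meta-dataset directly, without the helpers_tabular wrapper; it assumes the standard torchmeta transforms imported in helpers_tabular.py and a writable data folder:

from torchmeta.datasets import Letter
from torchmeta.transforms import Categorical, ClassSplitter
from torchmeta.transforms.tabular_transforms import NumpyToTorch

# build 3-way tasks from the meta-train classes, then split each task
# into 5 support and 15 query examples per class
dataset = Letter("data", num_classes_per_task=3, meta_train=True,
                 transform=NumpyToTorch(),
                 target_transform=Categorical(num_classes=3),
                 download=True)
dataset = ClassSplitter(dataset, shuffle=True, num_train_per_class=5, num_test_per_class=15)

task = dataset.sample_task()         # dict with "train" and "test" sub-tasks
features, target = task["train"][0]  # one feature vector and its task-level class index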
