Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 624991c

Browse filesBrowse files
committed
CategoryNorm
1 parent c6eec14 commit 624991c
Copy full SHA for 624991c

File tree

Expand file treeCollapse file tree

2 files changed

+113
-14
lines changed
Open diff view settings
Filter options
Expand file treeCollapse file tree

2 files changed

+113
-14
lines changed
Open diff view settings
Collapse file

‎lib/matplotlib/category.py‎

Copy file name to clipboardExpand all lines: lib/matplotlib/category.py
+99-14Lines changed: 99 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010
import numpy as np
1111

12+
import matplotlib.colors as mcolors
1213
import matplotlib.cbook as cbook
13-
import matplotlib.units as units
14-
import matplotlib.ticker as ticker
15-
14+
import matplotlib.units as munits
15+
import matplotlib.ticker as mticker
1616

1717
# pure hack for numpy 1.6 support
1818
from distutils.version import LooseVersion
@@ -33,11 +33,22 @@ def to_array(data, maxlen=100):
3333
return vals
3434

3535

36-
class StrCategoryConverter(units.ConversionInterface):
36+
class StrCategoryConverter(munits.ConversionInterface):
37+
"""Converts categorical (or string) data to numerical values
38+
39+
Conversion typically happens in the following order:
40+
1. default_units:
41+
creates unit_data category-integer mapping and binds to axis
42+
2. axis_info:
43+
sets ticks/locator and label/formatter
44+
3. convert:
45+
maps input category data to integers using unit_data
46+
47+
"""
3748
@staticmethod
3849
def convert(value, unit, axis):
39-
"""Uses axis.unit_data map to encode
40-
data as floats
50+
"""
51+
Encode value as floats using axis.unit_data
4152
"""
4253
vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs))
4354

@@ -52,33 +63,107 @@ def convert(value, unit, axis):
5263

5364
@staticmethod
5465
def axisinfo(unit, axis):
66+
"""
67+
Return the :class:`~matplotlib.units.AxisInfo` for *unit*.
68+
69+
*unit* is None
70+
*axis.unit_data* is used to set ticks and labels
71+
"""
5572
majloc = StrCategoryLocator(axis.unit_data.locs)
5673
majfmt = StrCategoryFormatter(axis.unit_data.seq)
57-
return units.AxisInfo(majloc=majloc, majfmt=majfmt)
74+
return munits.AxisInfo(majloc=majloc, majfmt=majfmt)
5875

5976
@staticmethod
6077
def default_units(data, axis):
61-
# the conversion call stack is:
62-
# default_units->axis_info->convert
78+
"""
79+
Create mapping between string categories in *data*
80+
and integers, then store in *axis.unit_data*
81+
"""
6382
if axis.unit_data is None:
6483
axis.unit_data = UnitData(data)
6584
else:
6685
axis.unit_data.update(data)
6786
return None
6887

6988

70-
class StrCategoryLocator(ticker.FixedLocator):
89+
class StrCategoryLocator(mticker.FixedLocator):
90+
"""
91+
Ensures that every category has a tick by subclassing
92+
:class:`~matplotlib.ticker.FixedLocator`
93+
"""
7194
def __init__(self, locs):
7295
self.locs = locs
7396
self.nbins = None
7497

7598

76-
class StrCategoryFormatter(ticker.FixedFormatter):
99+
class StrCategoryFormatter(mticker.FixedFormatter):
100+
"""
101+
Labels every category by subclassing
102+
:class:`~matplotlib.ticker.FixedFormatter`
103+
"""
77104
def __init__(self, seq):
78105
self.seq = seq
79106
self.offset_string = ''
80107

81108

109+
class CategoryNorm(mcolors.Normalize):
110+
"""
111+
Preserves ordering of discrete values
112+
"""
113+
def __init__(self, categories):
114+
"""
115+
*categories*
116+
distinct values for mapping
117+
118+
Out-of-range values are mapped to a value not in categories;
119+
these are then converted to valid indices by :meth:`Colormap.__call__`.
120+
"""
121+
self.categories = categories
122+
self.N = len(self.categories)
123+
self.vmin = 0
124+
self.vmax = self.N
125+
self._interp = False
126+
127+
def __call__(self, value, clip=None):
128+
if not cbook.iterable(value):
129+
value = [value]
130+
131+
value = np.asarray(value)
132+
ret = np.ones(value.shape) * np.nan
133+
134+
for i, c in enumerate(self.categories):
135+
ret[value == c] = i / (self.N * 1.0)
136+
137+
return np.ma.array(ret, mask=np.isnan(ret))
138+
139+
def inverse(self, value):
140+
# not quite sure what invertible means in this context
141+
return ValueError("CategoryNorm is not invertible")
142+
143+
144+
def colors_from_categories(codings):
145+
"""
146+
Helper routine to generate a cmap and a norm from a list
147+
of (color, value) pairs
148+
149+
Parameters
150+
----------
151+
codings : sequence of (key, value) pairs
152+
153+
Returns
154+
-------
155+
(cmap, norm) : tuple containing a :class:`Colormap` and a \
156+
:class:`Normalize` instance
157+
"""
158+
if isinstance(codings, dict):
159+
codings = codings.items()
160+
161+
values, colors = zip(*codings)
162+
cmap = mcolors.ListedColormap(list(colors))
163+
norm = CategoryNorm(list(values))
164+
return cmap, norm
165+
166+
82167
def convert_to_string(value):
83168
"""Helper function for numpy 1.6, can be replaced with
84169
np.array(...,dtype=unicode) for all later versions of numpy"""
@@ -132,6 +217,6 @@ def _set_seq_locs(self, data, value):
132217
value += 1
133218

134219
# Connects the convertor to matplotlib
135-
units.registry[str] = StrCategoryConverter()
136-
units.registry[bytes] = StrCategoryConverter()
137-
units.registry[six.text_type] = StrCategoryConverter()
220+
munits.registry[str] = StrCategoryConverter()
221+
munits.registry[bytes] = StrCategoryConverter()
222+
munits.registry[six.text_type] = StrCategoryConverter()
Collapse file

‎lib/matplotlib/tests/test_category.py‎

Copy file name to clipboardExpand all lines: lib/matplotlib/tests/test_category.py
+14Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,20 @@ def test_StrCategoryFormatterUnicode(self):
128128
assert labels('a', 1) == "привет"
129129

130130

131+
class TestCategoryNorm(object):
132+
testdata = [[[205, 302, 205, 101], [0, 2. / 3., 0, 1. / 3.]],
133+
[[205, np.nan, 101, 305], [0, 9999, 1. / 3., 2. / 3.]],
134+
[[205, 101, 504, 101], [0, 9999, 1. / 3., 1. / 3.]]]
135+
136+
ids = ["regular", "nan", "exclude"]
137+
138+
@pytest.mark.parametrize("data, nmap", testdata, ids=ids)
139+
def test_norm(self, data, nmap):
140+
norm = cat.CategoryNorm([205, 101, 302])
141+
test = np.ma.masked_equal(nmap, 9999)
142+
np.testing.assert_allclose(norm(data), test)
143+
144+
131145
def lt(tl):
132146
return [l.get_text() for l in tl]
133147

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.