@@ -1124,3 +1124,83 @@ def test_exact_vmin():
1124
1124
@pytest .mark .flaky
1125
1125
def test_https_imread_smoketest ():
1126
1126
v = mimage .imread ('https://matplotlib.org/1.5.0/_static/logo2.png' )
1127
+
1128
+
1129
+ # A basic ndarray subclass that implements a quantity
1130
+ # It does not implement an entire unit system or all quantity math.
1131
+ # There is just enough implemented to test handling of ndarray
1132
+ # subclasses.
1133
+ class QuantityND (np .ndarray ):
1134
+ def __new__ (cls , input_array , units ):
1135
+ obj = np .asarray (input_array ).view (cls )
1136
+ obj .units = units
1137
+ return obj
1138
+
1139
+ def __array_finalize__ (self , obj ):
1140
+ self .units = getattr (obj , "units" , None )
1141
+
1142
+ def __getitem__ (self , item ):
1143
+ units = getattr (self , "units" , None )
1144
+ ret = super (QuantityND , self ).__getitem__ (item )
1145
+ if isinstance (ret , QuantityND ) or units is not None :
1146
+ ret = QuantityND (ret , units )
1147
+ return ret
1148
+
1149
+ def __array_ufunc__ (self , ufunc , method , * inputs , ** kwargs ):
1150
+ func = getattr (ufunc , method )
1151
+ if "out" in kwargs :
1152
+ raise NotImplementedError
1153
+ if len (inputs ) == 1 :
1154
+ i0 = inputs [0 ]
1155
+ unit = getattr (i0 , "units" , "dimensionless" )
1156
+ out_arr = func (np .asarray (i0 ), ** kwargs )
1157
+ elif len (inputs ) == 2 :
1158
+ i0 = inputs [0 ]
1159
+ i1 = inputs [1 ]
1160
+ u0 = getattr (i0 , "units" , "dimensionless" )
1161
+ u1 = getattr (i1 , "units" , "dimensionless" )
1162
+ u0 = u1 if u0 is None else u0
1163
+ u1 = u0 if u1 is None else u1
1164
+ if ufunc in [np .add , np .subtract ]:
1165
+ if u0 != u1 :
1166
+ raise ValueError
1167
+ unit = u0
1168
+ elif ufunc == np .multiply :
1169
+ unit = f"{ u0 } *{ u1 } "
1170
+ elif ufunc == np .divide :
1171
+ unit = f"{ u0 } /({ u1 } )"
1172
+ else :
1173
+ raise NotImplementedError
1174
+ out_arr = func (i0 .view (np .ndarray ), i1 .view (np .ndarray ), ** kwargs )
1175
+ else :
1176
+ raise NotImplementedError
1177
+ if unit is None :
1178
+ out_arr = np .array (out_arr )
1179
+ else :
1180
+ out_arr = QuantityND (out_arr , unit )
1181
+ return out_arr
1182
+
1183
+ @property
1184
+ def v (self ):
1185
+ return self .view (np .ndarray )
1186
+
1187
+
1188
+ def test_quantitynd ():
1189
+ q = QuantityND ([1 , 2 ], "m" )
1190
+ q0 , q1 = q [:]
1191
+ assert np .all (q .v == np .asarray ([1 , 2 ]))
1192
+ assert q .units == "m"
1193
+ assert np .all ((q0 + q1 ).v == np .asarray ([3 ]))
1194
+ assert (q0 * q1 ).units == "m*m"
1195
+ assert (q1 / q0 ).units == "m/(m)"
1196
+ with pytest .raises (ValueError ):
1197
+ q0 + QuantityND (1 , "s" )
1198
+
1199
+
1200
+ def test_imshow_quantitynd ():
1201
+ # generate a dummy ndarray subclass
1202
+ arr = QuantityND (np .ones ((2 , 2 )), "m" )
1203
+ fig , ax = plt .subplots ()
1204
+ ax .imshow (arr )
1205
+ # executing the draw should not raise an exception
1206
+ fig .canvas .draw ()
0 commit comments