39
39
40
40
from matplotlib .axes import Axes , SubplotBase , subplot_class_factory
41
41
from matplotlib .blocking_input import BlockingMouseInput , BlockingKeyMouseInput
42
+ from matplotlib .gridspec import GridSpec
42
43
from matplotlib .legend import Legend
43
44
from matplotlib .patches import Rectangle
44
45
from matplotlib .projections import (get_projection_names ,
@@ -1001,6 +1002,142 @@ def add_subplot(self, *args, **kwargs):
1001
1002
self .stale = True
1002
1003
return a
1003
1004
1005
+ def add_subplots (self , nrows = 1 , ncols = 1 , sharex = False , sharey = False ,
1006
+ squeeze = True , subplot_kw = None , gridspec_kw = None ):
1007
+ """
1008
+ Add a set of subplots to this figure.
1009
+
1010
+ Keyword arguments:
1011
+
1012
+ *nrows* : int
1013
+ Number of rows of the subplot grid. Defaults to 1.
1014
+
1015
+ *ncols* : int
1016
+ Number of columns of the subplot grid. Defaults to 1.
1017
+
1018
+ *sharex* : string or bool
1019
+ If *True*, the X axis will be shared amongst all subplots. If
1020
+ *True* and you have multiple rows, the x tick labels on all but
1021
+ the last row of plots will have visible set to *False*
1022
+ If a string must be one of "row", "col", "all", or "none".
1023
+ "all" has the same effect as *True*, "none" has the same effect
1024
+ as *False*.
1025
+ If "row", each subplot row will share a X axis.
1026
+ If "col", each subplot column will share a X axis and the x tick
1027
+ labels on all but the last row will have visible set to *False*.
1028
+
1029
+ *sharey* : string or bool
1030
+ If *True*, the Y axis will be shared amongst all subplots. If
1031
+ *True* and you have multiple columns, the y tick labels on all but
1032
+ the first column of plots will have visible set to *False*
1033
+ If a string must be one of "row", "col", "all", or "none".
1034
+ "all" has the same effect as *True*, "none" has the same effect
1035
+ as *False*.
1036
+ If "row", each subplot row will share a Y axis and the y tick
1037
+ labels on all but the first column will have visible set to *False*.
1038
+ If "col", each subplot column will share a Y axis.
1039
+
1040
+ *squeeze* : bool
1041
+ If *True*, extra dimensions are squeezed out from the
1042
+ returned axis object:
1043
+
1044
+ - if only one subplot is constructed (nrows=ncols=1), the
1045
+ resulting single Axis object is returned as a scalar.
1046
+
1047
+ - for Nx1 or 1xN subplots, the returned object is a 1-d numpy
1048
+ object array of Axis objects are returned as numpy 1-d
1049
+ arrays.
1050
+
1051
+ - for NxM subplots with N>1 and M>1 are returned as a 2d
1052
+ array.
1053
+
1054
+ If *False*, no squeezing at all is done: the returned axis
1055
+ object is always a 2-d array containing Axis instances, even if it
1056
+ ends up being 1x1.
1057
+
1058
+ *subplot_kw* : dict
1059
+ Dict with keywords passed to the
1060
+ :meth:`~matplotlib.figure.Figure.add_subplot` call used to
1061
+ create each subplots.
1062
+
1063
+ *gridspec_kw* : dict
1064
+ Dict with keywords passed to the
1065
+ :class:`~matplotlib.gridspec.GridSpec` constructor used to create
1066
+ the grid the subplots are placed on.
1067
+
1068
+ Returns:
1069
+
1070
+ ax : single axes object or array of axes objects
1071
+ The addes axes. The dimensions of the resulting array can be
1072
+ controlled with the squeeze keyword, see above.
1073
+
1074
+ See the docstring of :func:`~pyplot.subplots' for examples
1075
+ """
1076
+
1077
+ # for backwards compatibility
1078
+ if isinstance (sharex , bool ):
1079
+ sharex = "all" if sharex else "none"
1080
+ if isinstance (sharey , bool ):
1081
+ sharey = "all" if sharey else "none"
1082
+ share_values = ["all" , "row" , "col" , "none" ]
1083
+ if sharex not in share_values :
1084
+ # This check was added because it is very easy to type
1085
+ # `subplots(1, 2, 1)` when `subplot(1, 2, 1)` was intended.
1086
+ # In most cases, no error will ever occur, but mysterious behavior
1087
+ # will result because what was intended to be the subplot index is
1088
+ # instead treated as a bool for sharex.
1089
+ if isinstance (sharex , int ):
1090
+ warnings .warn (
1091
+ "sharex argument to add_subplots() was an integer. "
1092
+ "Did you intend to use add_subplot() (without 's')?" )
1093
+
1094
+ raise ValueError ("sharex [%s] must be one of %s" %
1095
+ (sharex , share_values ))
1096
+ if sharey not in share_values :
1097
+ raise ValueError ("sharey [%s] must be one of %s" %
1098
+ (sharey , share_values ))
1099
+ if subplot_kw is None :
1100
+ subplot_kw = {}
1101
+ if gridspec_kw is None :
1102
+ gridspec_kw = {}
1103
+
1104
+ gs = GridSpec (nrows , ncols , ** gridspec_kw )
1105
+
1106
+ # Create array to hold all axes.
1107
+ axarr = np .empty ((nrows , ncols ), dtype = object )
1108
+ for row in range (nrows ):
1109
+ for col in range (ncols ):
1110
+ shared_with = {"none" : None , "all" : axarr [0 , 0 ],
1111
+ "row" : axarr [row , 0 ], "col" : axarr [0 , col ]}
1112
+ subplot_kw ["sharex" ] = shared_with [sharex ]
1113
+ subplot_kw ["sharey" ] = shared_with [sharey ]
1114
+ axarr [row , col ] = self .add_subplot (gs [row , col ], ** subplot_kw )
1115
+
1116
+ # turn off redundant tick labeling
1117
+ if sharex in ["col" , "all" ] and nrows > 1 :
1118
+ # turn off all but the bottom row
1119
+ for ax in axarr [:- 1 , :].flat :
1120
+ for label in ax .get_xticklabels ():
1121
+ label .set_visible (False )
1122
+ ax .xaxis .offsetText .set_visible (False )
1123
+
1124
+ if sharey in ["row" , "all" ] and ncols > 1 :
1125
+ # turn off all but the first column
1126
+ for ax in axarr [:, 1 :].flat :
1127
+ for label in ax .get_yticklabels ():
1128
+ label .set_visible (False )
1129
+ ax .yaxis .offsetText .set_visible (False )
1130
+
1131
+ if squeeze :
1132
+ # Reshape the array to have the final desired dimension (nrow,ncol),
1133
+ # though discarding unneeded dimensions that equal 1. If we only have
1134
+ # one subplot, just return it instead of a 1-element array.
1135
+ return axarr .item () if axarr .size == 1 else axarr .squeeze ()
1136
+ else :
1137
+ # returned axis array will be always 2-d, even if nrows=ncols=1
1138
+ return axarr
1139
+
1140
+
1004
1141
def clf (self , keep_observers = False ):
1005
1142
"""
1006
1143
Clear the figure.
0 commit comments