import matplotlib.cbook as cbook

import matplotlib.pyplot as plt
import matplotlib.axes as maxes
import matplotlib.colorbar as mcolorbar
import matplotlib as mpl
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import matplotlib.ticker as ticker

from axes_divider import Size, SubplotDivider, LocatableAxes, Divider


def _tick_only(ax, bottom_on, left_on):
    """
    Toggle the tick labels and axis labels of *ax*.

    The flags are inverted with respect to label visibility: when
    *bottom_on* is True the x tick labels and the x axis label are
    hidden, and when it is False they are shown. *left_on* controls
    the y tick labels and y axis label in the same way.

    If *ax* has an ``_axislines`` attribute, the corresponding
    "bottom" / "left" axis-line artists are updated as well.
    """
    show_bottom = not bottom_on
    show_left = not left_on

    for label in ax.get_xticklabels():
        label.set_visible(show_bottom)
    for label in ax.get_yticklabels():
        label.set_visible(show_left)
    ax.xaxis.label.set_visible(show_bottom)
    ax.yaxis.label.set_visible(show_left)

    if hasattr(ax, "_axislines"):
        ax._axislines["bottom"].major_ticklabels.set_visible(show_bottom)
        ax._axislines["left"].major_ticklabels.set_visible(show_left)
        ax._axislines["bottom"].minor_ticklabels.set_visible(show_bottom)
        ax._axislines["left"].minor_ticklabels.set_visible(show_left)
        ax.axis["bottom"].label.set_visible(show_bottom)
        ax.axis["left"].label.set_visible(show_left)


def _resolve_axes_class(axes_class):
    """
    Normalize the *axes_class* argument of :class:`Grid`.

    *axes_class* may be None (use :class:`LocatableAxes`), a subclass of
    :class:`~matplotlib.axes.Axes`, or a ``(class, kwargs_dict)`` pair.

    Returns a ``(class, kwargs_dict)`` tuple.
    """
    if axes_class is None:
        return LocatableAxes, {}
    # The argument is a class object, not an instance, so it has to be
    # tested with issubclass; guard with isinstance(..., type) because
    # issubclass raises TypeError for non-class values such as a tuple.
    if isinstance(axes_class, type) and issubclass(axes_class, maxes.Axes):
        return axes_class, {}
    klass, klass_args = axes_class
    return klass, klass_args


def _make_divider(fig, rect, aspect):
    """
    Create the divider that lays out a grid inside *rect*.

    *rect* is either a subplot position code (string or number), a
    3-sequence that is unpacked as the positional arguments of
    :class:`SubplotDivider`, or a 4-sequence passed to :class:`Divider`.

    Raises ValueError for any other *rect*.
    """
    kw = dict(horizontal=[], vertical=[], aspect=aspect)
    if cbook.is_string_like(rect) or cbook.is_numlike(rect):
        return SubplotDivider(fig, rect, **kw)
    elif len(rect) == 3:
        return SubplotDivider(fig, *rect, **kw)
    elif len(rect) == 4:
        return Divider(fig, rect, **kw)
    raise ValueError("rect must be a subplot position code, a 3-sequence "
                     "or a 4-sequence, got %r" % (rect,))


def _check_grid_args(nrows, ncols, ngrids, direction):
    """
    Validate *ngrids* and *direction* and return the effective *ngrids*.

    When *ngrids* is None it defaults to ``nrows * ncols``.
    Raises ValueError if *ngrids* is out of range or *direction* is not
    "row" or "column".
    """
    if ngrids is None:
        ngrids = nrows * ncols
    elif (ngrids > nrows * ncols) or (ngrids <= 0):
        raise ValueError("ngrids must be in the range 1..%d, got %r"
                         % (nrows * ncols, ngrids))
    if direction not in ["column", "row"]:
        raise ValueError('direction must be "row" or "column", got %r'
                         % (direction,))
    return ngrids


class Colorbar(mcolorbar.Colorbar):
    """
    Colorbar variant that draws its outline and background patch as
    plain artists on its axes.
    """

    def _config_axes(self, X, Y):
        '''
        Make an axes patch and outline.
        '''
        ax = self.ax
        ax.set_frame_on(False)
        ax.set_navigate(False)

        # Fit the view limits to the colorbar outline.
        xy = self._outline(X, Y)
        ax.update_datalim(xy)
        ax.set_xlim(*ax.dataLim.intervalx)
        ax.set_ylim(*ax.dataLim.intervaly)

        self.outline = mlines.Line2D(
            xy[:, 0], xy[:, 1],
            color=mpl.rcParams['axes.edgecolor'],
            linewidth=mpl.rcParams['axes.linewidth'])
        ax.add_artist(self.outline)
        # The outline is drawn unclipped.
        self.outline.set_clip_box(None)
        self.outline.set_clip_path(None)

        c = mpl.rcParams['axes.facecolor']
        self.patch = mpatches.Polygon(xy, edgecolor=c,
                                      facecolor=c,
                                      linewidth=0.01,
                                      zorder=-1)
        ax.add_artist(self.patch)

        ticks, ticklabels, offset_string = self._ticker()

        if self.orientation == 'vertical':
            ax.set_yticks(ticks)
            ax.set_yticklabels(ticklabels)
            ax.yaxis.get_major_formatter().set_offset_string(offset_string)
        else:
            ax.set_xticks(ticks)
            ax.set_xticklabels(ticklabels)
            ax.xaxis.get_major_formatter().set_offset_string(offset_string)


class CbarAxes(LocatableAxes):
    """
    Axes that hosts a colorbar.

    The required *orientation* keyword names the axis (e.g. "right" or
    "top") that carries the colorbar ticks.
    """

    def __init__(self, *kl, **kwargs):
        """
        Remaining positional and keyword arguments are forwarded to the
        parent class. Raises ValueError if *orientation* is missing.
        """
        orientation = kwargs.pop("orientation", None)
        if orientation is None:
            raise ValueError("orientation must be specified")
        self.orientation = orientation
        self._default_label_on = False
        self.locator = None

        super(CbarAxes, self).__init__(*kl, **kwargs)

    def colorbar(self, mappable, **kwargs):
        """
        Draw a colorbar for *mappable* in this axes and return it.

        An optional *locator* keyword sets the tick locator; it defaults
        to ``MaxNLocator(5)``. Other keywords are passed to
        :class:`Colorbar`. The colorbar is kept in sync with later
        changes of *mappable* through its 'changed' callback.
        """
        locator = kwargs.pop("locator", None)
        if locator is None:
            locator = ticker.MaxNLocator(5)
        self.locator = locator

        kwargs["ticks"] = locator

        self.hold(True)
        if self.orientation in ["top", "bottom"]:
            orientation = "horizontal"
        else:
            orientation = "vertical"

        cb = Colorbar(self, mappable, orientation=orientation, **kwargs)

        def on_changed(m):
            # Propagate colormap / clim changes of the mappable.
            cb.set_cmap(m.get_cmap())
            cb.set_clim(m.get_clim())
            cb.update_bruteforce(m)

        self.cbid = mappable.callbacksSM.connect('changed', on_changed)
        mappable.set_colorbar(cb, self)
        return cb

    def cla(self):
        """Clear the axes and re-apply the colorbar axis configuration."""
        super(CbarAxes, self).cla()
        self._config_axes()

    def _config_axes(self):
        '''
        Hide every axis decoration, then re-enable the ticks of the axis
        named by ``self.orientation``.
        '''
        ax = self
        ax.set_navigate(False)

        for axis in ax.axis.values():
            axis.major_ticks.set_visible(False)
            axis.minor_ticks.set_visible(False)
            axis.major_ticklabels.set_visible(False)
            axis.minor_ticklabels.set_visible(False)
            axis.label.set_visible(False)

        axis = ax.axis[self.orientation]
        axis.major_ticks.set_visible(True)
        axis.minor_ticks.set_visible(True)

        # Shrink the tick labels to 90% of their size (truncated to int).
        axis.major_ticklabels.set_size(
            int(axis.major_ticklabels.get_size()*.9))
        axis.major_tick_pad = 3

        b = self._default_label_on
        axis.major_ticklabels.set_visible(b)
        axis.minor_ticklabels.set_visible(b)
        axis.label.set_visible(b)

    def toggle_label(self, b):
        """
        Show (*b* true) or hide the tick labels and axis label of the
        colorbar axis, and remember the choice for later ``cla`` calls.
        """
        self._default_label_on = b
        axis = self.axis[self.orientation]
        axis.major_ticklabels.set_visible(b)
        axis.minor_ticklabels.set_visible(b)
        axis.label.set_visible(b)


class Grid(object):
    """
    A class that creates a grid of Axes. In matplotlib, the axes
    location (and size) is specified in the normalized figure
    coordinates. This may not be ideal for images that need to be
    displayed with a given aspect ratio. For example, displaying
    images of the same size with some fixed padding between them cannot
    be easily done in matplotlib. AxesGrid is used in such a case.
    """

    def __init__(self, fig,
                 rect,
                 nrows_ncols,
                 ngrids = None,
                 direction="row",
                 axes_pad = 0.02,
                 add_all=True,
                 share_all=False,
                 share_x=True,
                 share_y=True,
                 #aspect=True,
                 label_mode="L",
                 axes_class=None,
                 ):
        """
        Build a :class:`Grid` instance with a grid nrows*ncols
        :class:`~matplotlib.axes.Axes` in
        :class:`~matplotlib.figure.Figure` *fig* with
        *rect=[left, bottom, width, height]* (in
        :class:`~matplotlib.figure.Figure` coordinates) or
        the subplot position code (e.g., "121").

        Optional keyword arguments:

          ================  ========  =========================================
          Keyword           Default   Description
          ================  ========  =========================================
          direction         "row"     [ "row" | "column" ]
          axes_pad          0.02      float| pad between axes given in inches
          add_all           True      [ True | False ]
          share_all         False     [ True | False ]
          share_x           True      [ True | False ]
          share_y           True      [ True | False ]
          label_mode        "L"       [ "L" | "1" | "all" ]
          axes_class        None      a type object which must be a subclass
                                      of :class:`~matplotlib.axes.Axes`, or a
                                      (class, keyword-dict) pair
          ================  ========  =========================================

        Raises ValueError if *ngrids*, *direction* or *rect* is invalid.
        """
        self._nrows, self._ncols = nrows_ncols
        self.ngrids = _check_grid_args(self._nrows, self._ncols,
                                       ngrids, direction)

        self._init_axes_pad(axes_pad)
        self._direction = direction

        axes_class, axes_class_args = _resolve_axes_class(axes_class)

        self.axes_all = []
        self.axes_column = [[] for i in range(self._ncols)]
        self.axes_row = [[] for i in range(self._nrows)]

        self._divider = _make_divider(fig, rect, aspect=False)
        rect = self._divider.get_position()

        # reference axes
        self._column_refax = [None for i in range(self._ncols)]
        self._row_refax = [None for i in range(self._nrows)]
        self._refax = None

        for i in range(self.ngrids):
            col, row = self._get_col_row(i)

            if share_all:
                sharex = self._refax
                sharey = self._refax
            else:
                if share_x:
                    sharex = self._column_refax[col]
                else:
                    sharex = None

                if share_y:
                    sharey = self._row_refax[row]
                else:
                    sharey = None

            ax = axes_class(fig, rect, sharex=sharex, sharey=sharey,
                            **axes_class_args)

            # The first axes created (overall, or per column / row)
            # becomes the reference that later axes share with.
            if share_all:
                if self._refax is None:
                    self._refax = ax
            else:
                if sharex is None:
                    self._column_refax[col] = ax
                if sharey is None:
                    self._row_refax[row] = ax

            self.axes_all.append(ax)
            self.axes_column[col].append(ax)
            self.axes_row[row].append(ax)

        # lower-left corner axes
        self.axes_llc = self.axes_column[0][-1]

        self._update_locators()

        if add_all:
            for ax in self.axes_all:
                fig.add_axes(ax)

        self.set_label_mode(label_mode)

    def _init_axes_pad(self, axes_pad):
        """Store *axes_pad* and create the shared fixed pad sizes."""
        self._axes_pad = axes_pad

        self._horiz_pad_size = Size.Fixed(axes_pad)
        self._vert_pad_size = Size.Fixed(axes_pad)

    def _update_locators(self):
        """
        Build the horizontal / vertical size lists (equal ``Scaled(1)``
        cells separated by the pad sizes) and assign a locator to each
        axes.
        """
        h = []
        h_ax_pos = []
        for _ in self._column_refax:
            if h:
                h.append(self._horiz_pad_size)
            h_ax_pos.append(len(h))
            h.append(Size.Scaled(1))

        v = []
        v_ax_pos = []
        for _ in self._row_refax[::-1]:
            if v:
                v.append(self._vert_pad_size)
            v_ax_pos.append(len(v))
            v.append(Size.Scaled(1))

        for i in range(self.ngrids):
            col, row = self._get_col_row(i)
            # The vertical list is built from the reversed row list,
            # hence the flipped row index.
            locator = self._divider.new_locator(
                nx=h_ax_pos[col],
                ny=v_ax_pos[self._nrows - 1 - row])
            self.axes_all[i].set_axes_locator(locator)

        self._divider.set_horizontal(h)
        self._divider.set_vertical(v)

    def _get_col_row(self, n):
        """Return the ``(col, row)`` of the *n*-th axes for the direction."""
        if self._direction == "column":
            col, row = divmod(n, self._nrows)
        else:
            row, col = divmod(n, self._ncols)

        return col, row

    def __getitem__(self, i):
        return self.axes_all[i]

    def get_geometry(self):
        """
        get geometry of the grid. Returns a tuple of two integers,
        representing number of rows and number of columns.
        """
        return self._nrows, self._ncols

    def set_axes_pad(self, axes_pad):
        "set axes_pad"
        self._axes_pad = axes_pad

        # The pad Size objects are shared with the divider, so they are
        # updated in place.
        self._horiz_pad_size.fixed_size = axes_pad
        self._vert_pad_size.fixed_size = axes_pad

    def get_axes_pad(self):
        "get axes_pad"
        return self._axes_pad

    def set_aspect(self, aspect):
        "set aspect"
        self._divider.set_aspect(aspect)

    def get_aspect(self):
        "get aspect"
        return self._divider.get_aspect()

    def set_label_mode(self, mode):
        """
        set label_mode

        "all" shows labels on every axes, "L" on the first column and
        on the last axes of each column, "1" only on the lower-left
        axes. Any other value leaves the labels untouched.
        """
        if mode == "all":
            for ax in self.axes_all:
                _tick_only(ax, False, False)
        elif mode == "L":
            # left-most axes
            for ax in self.axes_column[0][:-1]:
                _tick_only(ax, bottom_on=True, left_on=False)
            # lower-left axes
            ax = self.axes_column[0][-1]
            _tick_only(ax, bottom_on=False, left_on=False)

            for col in self.axes_column[1:]:
                # axes with no labels
                for ax in col[:-1]:
                    _tick_only(ax, bottom_on=True, left_on=True)

                # bottom
                ax = col[-1]
                _tick_only(ax, bottom_on=False, left_on=True)

        elif mode == "1":
            for ax in self.axes_all:
                _tick_only(ax, bottom_on=True, left_on=True)

            ax = self.axes_llc
            _tick_only(ax, bottom_on=False, left_on=False)


class ImageGrid(Grid):
    """
    A class that creates a grid of Axes. In matplotlib, the axes
    location (and size) is specified in the normalized figure
    coordinates. This may not be ideal for images that need to be
    displayed with a given aspect ratio. For example, displaying
    images of the same size with some fixed padding between them cannot
    be easily done in matplotlib. ImageGrid is used in such a case.
    """

    def __init__(self, fig,
                 rect,
                 nrows_ncols,
                 ngrids = None,
                 direction="row",
                 axes_pad = 0.02,
                 add_all=True,
                 share_all=False,
                 aspect=True,
                 label_mode="L",
                 cbar_mode=None,
                 cbar_location="right",
                 cbar_pad=None,
                 cbar_size="5%",
                 axes_class=None,
                 ):
        """
        Build an :class:`ImageGrid` instance with a grid nrows*ncols
        :class:`~matplotlib.axes.Axes` in
        :class:`~matplotlib.figure.Figure` *fig* with
        *rect=[left, bottom, width, height]* (in
        :class:`~matplotlib.figure.Figure` coordinates) or
        the subplot position code (e.g., "121").

        Optional keyword arguments:

          ================  ========  =========================================
          Keyword           Default   Description
          ================  ========  =========================================
          direction         "row"     [ "row" | "column" ]
          axes_pad          0.02      float| pad between axes given in inches
          add_all           True      [ True | False ]
          share_all         False     [ True | False ]
          aspect            True      [ True | False ]
          label_mode        "L"       [ "L" | "1" | "all" ]
          cbar_mode         None      [ "each" | "single" ]
          cbar_location     "right"   [ "right" | "top" ]
          cbar_pad          None      pad before the colorbar; axes_pad is
                                      used when None
          cbar_size         "5%"      size of the colorbar
          axes_class        None      a type object which must be a subclass
                                      of :class:`~matplotlib.axes.Axes`, or a
                                      (class, keyword-dict) pair
          ================  ========  =========================================

        Raises ValueError if *ngrids*, *direction* or *rect* is invalid.
        """
        self._nrows, self._ncols = nrows_ncols
        self.ngrids = _check_grid_args(self._nrows, self._ncols,
                                       ngrids, direction)

        self._colorbar_mode = cbar_mode
        self._colorbar_location = cbar_location
        if cbar_pad is None:
            self._colorbar_pad = axes_pad
        else:
            self._colorbar_pad = cbar_pad

        self._colorbar_size = cbar_size

        self._init_axes_pad(axes_pad)
        self._direction = direction

        axes_class, axes_class_args = _resolve_axes_class(axes_class)
        cbar_axes_class = CbarAxes

        self.axes_all = []
        self.axes_column = [[] for i in range(self._ncols)]
        self.axes_row = [[] for i in range(self._nrows)]

        self.cbar_axes = []

        self._divider = _make_divider(fig, rect, aspect=aspect)
        rect = self._divider.get_position()

        # reference axes
        self._column_refax = [None for i in range(self._ncols)]
        self._row_refax = [None for i in range(self._nrows)]
        self._refax = None

        for i in range(self.ngrids):
            col, row = self._get_col_row(i)

            if share_all:
                sharex = self._refax
                sharey = self._refax
            else:
                sharex = self._column_refax[col]
                sharey = self._row_refax[row]

            ax = axes_class(fig, rect, sharex=sharex, sharey=sharey,
                            **axes_class_args)

            # The first axes created (overall, or per column / row)
            # becomes the reference that later axes share with.
            if share_all:
                if self._refax is None:
                    self._refax = ax
            else:
                if sharex is None:
                    self._column_refax[col] = ax
                if sharey is None:
                    self._row_refax[row] = ax

            self.axes_all.append(ax)
            self.axes_column[col].append(ax)
            self.axes_row[row].append(ax)

            # One colorbar axes is created per image axes regardless of
            # cbar_mode; _update_locators decides which are visible.
            cax = cbar_axes_class(fig, rect,
                                  orientation=self._colorbar_location)
            self.cbar_axes.append(cax)

        # lower-left corner axes
        self.axes_llc = self.axes_column[0][-1]

        self._update_locators()

        if add_all:
            for ax in self.axes_all+self.cbar_axes:
                fig.add_axes(ax)

        self.set_label_mode(label_mode)

    def _update_locators(self):
        """
        Build the horizontal / vertical size lists from the reference
        axes (plus colorbar pad and size cells, depending on the
        colorbar mode and location) and assign locators to the image
        and colorbar axes.
        """
        each_right = (self._colorbar_mode == "each"
                      and self._colorbar_location == "right")
        each_top = (self._colorbar_mode == "each"
                    and self._colorbar_location == "top")

        h = []
        h_ax_pos = []
        h_cb_pos = []
        for ax in self._column_refax:
            if h:
                h.append(self._horiz_pad_size)

            h_ax_pos.append(len(h))

            # Fall back to the lower-left axes when the column has no
            # reference axes.
            if ax:
                sz = Size.AxesX(ax)
            else:
                sz = Size.AxesX(self.axes_llc)
            h.append(sz)

            if each_right:
                h.append(Size.from_any(self._colorbar_pad, sz))
                h_cb_pos.append(len(h))
                h.append(Size.from_any(self._colorbar_size, sz))

        v = []
        v_ax_pos = []
        v_cb_pos = []
        for ax in self._row_refax[::-1]:
            if v:
                # Vertical pad, matching Grid._update_locators (both
                # pad sizes are created from the same axes_pad).
                v.append(self._vert_pad_size)

            v_ax_pos.append(len(v))

            if ax:
                sz = Size.AxesY(ax)
            else:
                sz = Size.AxesY(self.axes_llc)
            v.append(sz)

            if each_top:
                v.append(Size.from_any(self._colorbar_pad, sz))
                v_cb_pos.append(len(v))
                v.append(Size.from_any(self._colorbar_size, sz))

        for i in range(self.ngrids):
            col, row = self._get_col_row(i)
            # The vertical list is built from the reversed row list,
            # hence the flipped row index.
            vrow = self._nrows - 1 - row
            locator = self._divider.new_locator(nx=h_ax_pos[col],
                                                ny=v_ax_pos[vrow])
            self.axes_all[i].set_axes_locator(locator)

            if self._colorbar_mode == "each":
                if self._colorbar_location == "right":
                    locator = self._divider.new_locator(nx=h_cb_pos[col],
                                                        ny=v_ax_pos[vrow])
                elif self._colorbar_location == "top":
                    locator = self._divider.new_locator(nx=h_ax_pos[col],
                                                        ny=v_cb_pos[vrow])
                self.cbar_axes[i].set_axes_locator(locator)

        if self._colorbar_mode == "single":
            if self._colorbar_location == "right":
                sz = Size.Fraction(self._nrows, Size.AxesX(self.axes_llc))
                h.append(Size.from_any(self._colorbar_pad, sz))
                h.append(Size.from_any(self._colorbar_size, sz))
                locator = self._divider.new_locator(nx=-2, ny=0, ny1=-1)
            elif self._colorbar_location == "top":
                sz = Size.Fraction(self._ncols, Size.AxesY(self.axes_llc))
                v.append(Size.from_any(self._colorbar_pad, sz))
                v.append(Size.from_any(self._colorbar_size, sz))
                locator = self._divider.new_locator(nx=0, nx1=-1, ny=-2)
            # Only the first colorbar axes is used in "single" mode.
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(False)
            self.cbar_axes[0].set_axes_locator(locator)
            self.cbar_axes[0].set_visible(True)
        elif self._colorbar_mode == "each":
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(True)
        else:
            # No colorbars: hide them and give them a tiny position.
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(False)
                self.cbar_axes[i].set_position([1., 1., 0.001, 0.001],
                                               which="active")

        self._divider.set_horizontal(h)
        self._divider.set_vertical(v)


AxesGrid = ImageGrid


# Disabled demo of the plain Grid class.
#if __name__ == "__main__":
if 0:
    F = plt.figure(1, (7, 6))
    F.clf()

    F.subplots_adjust(left=0.15, right=0.9)

    grid = Grid(F, 111, # similar to subplot(111)
                nrows_ncols = (2, 2),
                direction="row",
                axes_pad = 0.05,
                add_all=True,
                label_mode = "1",
                )


if __name__ == "__main__":
#if 0:
    from axes_divider import get_demo_image
    F = plt.figure(1, (9, 3.5))
    F.clf()

    F.subplots_adjust(left=0.05, right=0.98)

    grid = ImageGrid(F, 131, # similar to subplot(131)
                     nrows_ncols = (2, 2),
                     direction="row",
                     axes_pad = 0.05,
                     add_all=True,
                     label_mode = "1",
                     )

    Z, extent = get_demo_image()
    plt.ioff()
    for i in range(4):
        im = grid[i].imshow(Z, extent=extent, interpolation="nearest")

    # This only affects axes in first column and second row as share_all = False.
    grid.axes_llc.set_xticks([-2, 0, 2])
    grid.axes_llc.set_yticks([-2, 0, 2])
    plt.ion()

    grid = ImageGrid(F, 132, # similar to subplot(132)
                     nrows_ncols = (2, 2),
                     direction="row",
                     axes_pad = 0.0,
                     add_all=True,
                     share_all=True,
                     label_mode = "1",
                     cbar_mode="single",
                     )

    Z, extent = get_demo_image()
    plt.ioff()
    for i in range(4):
        im = grid[i].imshow(Z, extent=extent, interpolation="nearest")
    plt.colorbar(im, cax = grid.cbar_axes[0])
    plt.setp(grid.cbar_axes[0].get_yticklabels(), visible=False)

    # This affects all axes as share_all = True.
    grid.axes_llc.set_xticks([-2, 0, 2])
    grid.axes_llc.set_yticks([-2, 0, 2])

    plt.ion()

    grid = ImageGrid(F, 133, # similar to subplot(133)
                     nrows_ncols = (2, 2),
                     direction="row",
                     axes_pad = 0.1,
                     add_all=True,
                     label_mode = "1",
                     share_all = True,
                     cbar_location="top",
                     cbar_mode="each",
                     cbar_size="7%",
                     cbar_pad="2%",
                     )
    plt.ioff()
    for i in range(4):
        im = grid[i].imshow(Z, extent=extent, interpolation="nearest")
        plt.colorbar(im, cax = grid.cbar_axes[i],
                     orientation="horizontal")
        grid.cbar_axes[i].xaxis.set_ticks_position("top")
        plt.setp(grid.cbar_axes[i].get_xticklabels(), visible=False)

    # This affects all axes as share_all = True.
    grid.axes_llc.set_xticks([-2, 0, 2])
    grid.axes_llc.set_yticks([-2, 0, 2])

    plt.ion()
    plt.draw()