Source code for vsi.tools.mpl


import math

try:
  import rasterio
except:
  pass

import numpy as np
import matplotlib as mpl
from matplotlib import pylab as plt


def _identity_limits(x0, x1, y0, y1, **kwargs):
  ''' Default limit mapping: hand the source limits back unchanged. '''
  return (x0, x1, y0, y1)


def share_xy_custom(ax1, ax2,
                    forward=_identity_limits,
                    backward=_identity_limits,
                    sharex=True, sharey=True, **kwargs):
  '''
  Like ``Axes.sharex``/``Axes.sharey`` only it allows you to provide your own
  function instead of the 1:1 default.

  Parameters
  ----------
  ax1 : :class:`matplotlib.axes.Axes`
      The first Axes
  ax2 : :class:`matplotlib.axes.Axes`
      The second Axes
  forward : Callable
      A function that takes the arguments ``x0``, ``x1``, ``y0``, ``y1`` for
      ``ax1`` (plus ``**kwargs``) and returns ``(x0, x1, y0, y1)`` for
      ``ax2``. Return ``None`` (or anything falsy) to indicate that no update
      to ``ax2`` should be performed. Defaults to a 1:1 mapping.
  backward : Callable
      Same as ``forward`` except maps ``ax2`` to ``ax1``
  sharex : :class:`bool`, optional
      Should the x axis be shared
  sharey : :class:`bool`, optional
      Should the y axis be shared
  **kwargs
      Additional keyword arguments passed to ``forward`` and ``backward``
  '''

  def wrap_call(func, ax, dest_ax):
    # Unpack the current view limits of the axes that just changed
    x0, x1 = ax.viewLim.x0, ax.viewLim.x1
    y0, y1 = ax.viewLim.y0, ax.viewLim.y1

    # Call the user specified function
    bbox_new = func(x0, x1, y0, y1, **kwargs)

    # None or empty means "leave the destination axes alone"
    if not bbox_new:
      return

    x0_new, x1_new, y0_new, y1_new = bbox_new

    # The destination keeps the same orientation (flipped or not) as the
    # source axis, regardless of the order the mapping returned. emit=False
    # stops the destination from notifying its own listeners, which would
    # otherwise bounce the update back and forth forever.
    if sharex:
      x_low, x_high = min(x0_new, x1_new), max(x0_new, x1_new)
      if x0 > x1:  # Handle x flip
        dest_ax.set_xlim(x_high, x_low, emit=False)
      else:
        dest_ax.set_xlim(x_low, x_high, emit=False)
    if sharey:
      y_low, y_high = min(y0_new, y1_new), max(y0_new, y1_new)
      if y0 > y1:  # Handle y flip
        dest_ax.set_ylim(y_high, y_low, emit=False)
      else:
        dest_ax.set_ylim(y_low, y_high, emit=False)

  def _forward(ax):
    return wrap_call(forward, ax, ax2)

  def _backward(ax):
    return wrap_call(backward, ax, ax1)

  # Set up event listeners
  if sharex:
    ax1.callbacks.connect('xlim_changed', _forward)
    ax2.callbacks.connect('xlim_changed', _backward)
  if sharey:
    ax1.callbacks.connect('ylim_changed', _forward)
    ax2.callbacks.connect('ylim_changed', _backward)
def imshow_chip(axes, img, origin, size, *args, **kwargs):
  '''
  Show a rectangular chip cut out of an image, placed at the pixel
  coordinates it occupies in the full image. Does not support custom
  transforms

  Parameters
  ----------
  axes : :class:`matplotlib.axes.Axes`
      The Axes to draw the chip on
  img : :class:`numpy.ndarray`
      Image array
  origin : tuple
      The origin of the chip to load, in y, x, order
  size : tuple
      The size of the chip to load, in y, x, order
  *args :
      Additional parameters passed to :func:`matplotlib.pyplot.imshow`
  **kwargs
      Additional keyword arguments passed to :func:`matplotlib.pyplot.imshow`
  '''
  first_row = origin[0]
  first_col = origin[1]

  chip = img[first_row:(first_row + size[0]),
             first_col:(first_col + size[1])]

  # Use the chip's real shape rather than ``size``: slicing silently clips at
  # the image border, so the chip can be smaller than requested.
  chip_rows = chip.shape[0]
  chip_cols = chip.shape[1]

  # (left, right, bottom, top), with pixel centers on integer coordinates and
  # row numbers increasing downwards.
  chip_extent = (first_col - 0.5,
                 first_col + chip_cols - 0.5,
                 first_row + chip_rows - 0.5,
                 first_row - 0.5)

  axes.imshow(chip, *args, extent=chip_extent, **kwargs)
def imshow_chip_from_raster(axes, raster, origin, size, *args, **kwargs):
  '''
  Read a chip out of band 1 of a raster and show it, placed at the pixel
  coordinates it occupies in the full image. Does not support custom
  transforms

  Parameters
  ----------
  axes : :class:`matplotlib.axes.Axes`
      The Axes to draw the chip on
  raster : :class:`rasterio.io.DatasetReader`
      rasterio dataset object
  origin :
      The origin of the chip to load, in y, x, order
  size :
      The size of the chip to load, in y, x, order
  *args :
      Additional parameters passed to :func:`matplotlib.pyplot.imshow`
  **kwargs
      Additional keyword arguments passed to :func:`matplotlib.pyplot.imshow`
  '''
  first_row = origin[0]
  first_col = origin[1]

  # rasterio windows are given as (col_off, row_off, width, height), i.e. in
  # x, y order, the reverse of this function's arguments.
  window = rasterio.windows.Window(first_col, first_row, size[1], size[0])
  chip = raster.read(1, window=window)

  # The extent is derived from the shape of what was actually read.
  chip_rows = chip.shape[0]
  chip_cols = chip.shape[1]

  # (left, right, bottom, top), with pixel centers on integer coordinates and
  # row numbers increasing downwards.
  chip_extent = (first_col - 0.5,
                 first_col + chip_cols - 0.5,
                 first_row + chip_rows - 0.5,
                 first_row - 0.5)

  axes.imshow(chip, *args, extent=chip_extent, **kwargs)
def surf(z, cmap='jet', ax=None, x=None, y=None, c=None, **kwargs):
  '''
  Creates an equivalent "surf" plot, like in matlab

  Parameters
  ----------
  z : :class:`numpy.ndarray`
      The z data
  cmap :
      The colormap used to color based on z height. Default: jet
  ax : :class:`matplotlib.axes.Axes`
      The axes to draw on. Default: the current axes if it is a 3D axes,
      otherwise a new 3D subplot added to the current figure
  x : :class:`numpy.ndarray`, optional
      The x coordinate used to draw. Default uses :func:`numpy.meshgrid`
  y : :class:`numpy.ndarray`, optional
      The y coordinate used to draw. Default uses :func:`numpy.meshgrid`
  c : :class:`numpy.ndarray`, optional
      A custom array that is fed into the colormap for coloring. Default uses
      ``z``
  **kwargs : dict
      Additional keyword arguments passed to
      :meth:`mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`

  Returns
  -------
  The object returned by
  :meth:`mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`

  By default, :meth:`mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface` does not
  draw the entire mesh, it downsamples it to 50 points instead (for
  efficiency). To disable downsampling, consider setting ``rstride`` and
  ``cstride`` to ``1``. Shading is also disabled by default
  '''
  # Only build the index grid when at least one coordinate is missing, and
  # only fill in the coordinate(s) the caller did not supply.
  if x is None or y is None:
    z_shape = np.shape(z)
    grid_x, grid_y = np.meshgrid(range(z_shape[1]), range(z_shape[0]))
    if x is None:
      x = grid_x
    if y is None:
      y = grid_y

  if c is None:
    c = z
  # asanyarray (not asarray) so array subclasses such as masked arrays are
  # passed through untouched, while plain sequences gain .min()/.max()
  c = np.asanyarray(c)

  # Shading is off unless the caller explicitly asks for it
  kwargs.setdefault('shade', False)

  if ax is None:
    # ``plt.gca(projection='3d')`` no longer works: passing keyword arguments
    # to gca() was deprecated in matplotlib 3.4 and later removed. Do the
    # same thing explicitly: reuse the current axes when it is already 3D,
    # otherwise add a 3D subplot to the current figure.
    fig = plt.gcf()
    current = fig.gca() if fig.axes else None
    if current is not None and getattr(current, 'name', None) == '3d':
      ax = current
    else:
      ax = fig.add_subplot(1, 1, 1, projection='3d')

  scalar_map = mpl.cm.ScalarMappable(
      norm=plt.Normalize(vmin=c.min(), vmax=c.max()), cmap=cmap)

  # outputs an array where each C value is replaced with a corresponding
  # color value
  c_colored = scalar_map.to_rgba(c)

  return ax.plot_surface(x, y, z, facecolors=c_colored, **kwargs)
class SimpleBubblePicker:
  '''
  Simple class to add a picker with your own text function

  Currently supports plots (Line2D), scatter (PathCollection), images
  (AxesImage), Text, Rectangle and PatchCollection

  :class:`SimpleBubblePicker` comes with a default text function that will
  display basic information about any point you click.

  Left clicking a pickable artist moves the bubble to it; holding shift while
  clicking adds another bubble instead. Pressing escape while the mouse is
  over an axes removes that axes' bubbles.
  '''

  def __init__(self, fig=None, text_function=None,
               x_offset=10, y_offset=15, offset_units='dots',
               bbox={'boxstyle': 'round', 'facecolor': 'wheat', 'alpha': 0.5},
               **kwargs):
    '''
    Initialize a :class:`SimpleBubblePicker` for a specific figure.

    Parameters
    ----------
    fig : :class:`matplotlib.figure.Figure`, optional
        Default to using the :func:`matplotlib.pyplot.gcf`
    text_function : :class:`py::function`, optional
        Called as ``text_function(picker, event)`` and returns the bubble
        text. Defaults to :func:`SimpleBubblePicker.default_text`
    x_offset : float, optional
    y_offset : float, optional
        Offset location of bubble from actual point
    offset_units : 'str', optional
        Default: ``dots``
    bbox : dict, optional
        Passed as the ``bbox`` argument when the bubble text is created
    **kwargs :
        Additional keyword arguments to be passed to text initializer
    '''
    if fig is None:
      fig = plt.gcf()

    #: :class:`py::function`: The text generation function. Should take two
    # arguments
    self.text_function = text_function or SimpleBubblePicker.default_text
    #: float: The x offset of the currently drawing bubble
    self.x_offset = x_offset
    #: float: The y offset of the currently drawing bubble
    self.y_offset = y_offset
    #: str: The units of the offset for the current bubble offset
    self.offset_units = offset_units

    kwargs.setdefault('visible', False)
    kwargs.setdefault('x', 0)
    kwargs.setdefault('y', 0)
    kwargs.setdefault('s', 'Intentionally blank')
    # Copy the dict: the default is a single object shared by every call, so
    # storing it directly would let one picker's changes leak into others.
    kwargs['bbox'] = None if bbox is None else dict(bbox)

    self.bubble_kwarg = kwargs

    fig.canvas.mpl_connect('key_press_event', self.cleanup)
    fig.canvas.mpl_connect('pick_event', self._picker)

  def new_bubble(self, axes):
    ''' Create a new bubble text artist on ``axes`` and return it. '''
    # Data coordinates shifted by a fixed (x_offset, y_offset), so the bubble
    # sits next to the picked point instead of on top of it
    transform = mpl.transforms.offset_copy(axes.transData, axes.figure,
                                           x=self.x_offset, y=self.y_offset,
                                           units=self.offset_units)
    return axes.text(transform=transform, **self.bubble_kwarg)

  def get_bubbles(self, axes):
    '''
    Return the list of bubbles stored on ``axes``, creating an empty list on
    first use. The list lives on the axes object itself, in the
    ``simplebubblepicker_bubbles`` attribute.
    '''
    try:
      return axes.simplebubblepicker_bubbles
    except AttributeError:
      axes.simplebubblepicker_bubbles = []
      return axes.simplebubblepicker_bubbles

  def get_bubble(self, axes, shift=False):
    '''
    Return the most recent bubble for ``axes``. A new bubble is created first
    when ``shift`` is true or when the axes has no bubble yet.
    '''
    bubbles = self.get_bubbles(axes)
    if shift or not bubbles:
      bubbles.append(self.new_bubble(axes))
    return bubbles[-1]

  def default_text(self, event):
    '''
    Default text function: describe the picked artist using the values that
    ``_picker`` stored on ``self`` for this pick.
    '''
    artist = event.artist
    if isinstance(artist, (mpl.lines.Line2D, mpl.collections.PathCollection)):
      return '\n'.join([f'x: {self.x}',
                        f'y: {self.y}',
                        f'ind: {self.ind}'])
    elif isinstance(artist, mpl.image.AxesImage):
      return '\n'.join([f'x: {self.x}',
                        f'y: {self.y}',
                        f'data: {self.image[self.y, self.x, ...]}'])
    elif isinstance(artist, mpl.text.Text):
      return artist.get_text()
    elif isinstance(artist, mpl.collections.PatchCollection):
      return '\n'.join([f'patch: {self.patch}'])
    elif isinstance(artist, mpl.patches.Rectangle):
      return '\n'.join([f'x: {self.bbox.x0} - {self.bbox.x1}',
                        f'y: {self.bbox.y0} - {self.bbox.y1}'])
    else:
      return 'Todo'

  def cleanup(self, event):
    '''
    Key press handler: on escape, remove every bubble of the axes under the
    mouse.
    '''
    self.event2 = event  # kept for debugging
    if event.key != 'escape':
      return
    # The key was pressed while the mouse was outside every axes, so there is
    # no axes whose bubbles could be cleared
    if event.inaxes is None:
      return

    bubbles = self.get_bubbles(event.inaxes)
    for bubble in bubbles:
      # Let the artist detach itself instead of editing ``Axes.texts``, which
      # newer matplotlib versions no longer allow to be modified directly
      bubble.remove()
    bubbles.clear()

  def _picker(self, event):
    '''
    Pick handler: on a left click, work out the picked data point, move a
    bubble to it and fill in its text.
    '''
    self.event = event  # for debugging

    if not (isinstance(event, mpl.backend_bases.PickEvent) and
            event.mouseevent.button == mpl.backend_bases.MouseButton.LEFT):
      return

    artist = event.artist

    if isinstance(artist, mpl.lines.Line2D):
      self.xdata = artist.get_xdata()
      self.ydata = artist.get_ydata()
      self.ind = event.ind
      # Several points can be within pick range; report the first one
      self.x = self.xdata[self.ind[0]]
      self.y = self.ydata[self.ind[0]]
      self.text_x = self.x
      self.text_y = self.y
    elif isinstance(artist, mpl.image.AxesImage):
      self.image = artist.get_array()
      extent = artist.get_extent()

      # Data units per pixel. A negative scale means pixel indices increase
      # in the direction of decreasing data coordinates.
      x_scale = (extent[1]-extent[0]) / self.image.shape[1]
      y_scale = (extent[3]-extent[2]) / self.image.shape[0]
      if artist.origin == 'upper':
        y_scale = -y_scale

      # Data coordinate of the outer edge of pixel index 0
      if y_scale < 0:
        y_offset = max(extent[3], extent[2])
      else:
        y_offset = min(extent[3], extent[2])
      if x_scale < 0:
        x_offset = max(extent[1], extent[0])
      else:
        x_offset = min(extent[1], extent[0])

      # Pixel indices of the click, then the data coordinates of that
      # pixel's center for anchoring the bubble
      self.x = math.floor((event.mouseevent.xdata-x_offset)/x_scale)
      self.y = math.floor((event.mouseevent.ydata-y_offset)/y_scale)
      self.text_x = (self.x + 0.5) * x_scale + x_offset
      self.text_y = (self.y + 0.5) * y_scale + y_offset
    elif isinstance(artist, mpl.collections.PathCollection):  # scatter
      self.xdata = artist.get_offsets().data[:, 0]
      self.ydata = artist.get_offsets().data[:, 1]
      self.ind = event.ind
      self.x = self.xdata[self.ind[0]]
      self.y = self.ydata[self.ind[0]]
      self.text_x = self.x
      self.text_y = self.y
    elif isinstance(artist, mpl.collections.PatchCollection):
      self.ind = event.ind
      self.patches = artist.get_paths()
      self.patch = self.patches[self.ind[0]]
      # Anchor the bubble at the center of the patch's bounding box
      self.text_x, self.text_y = \
          self.patch.get_extents().corners().mean(axis=0)
    elif isinstance(artist, mpl.text.Text):
      self.x, self.y = artist.get_position()
      self.text_x = self.x
      self.text_y = self.y
    elif isinstance(artist, mpl.patches.Rectangle):
      self.bbox = artist.get_bbox()
      self.text_x = (self.bbox.x0 + self.bbox.x1)/2
      self.text_y = (self.bbox.y0 + self.bbox.y1)/2
    else:
      # Unsupported artist type: flag it in the title and draw no bubble
      plt.gca().set_title('oops')
      return

    bubble = self.get_bubble(event.mouseevent.inaxes,
                             event.mouseevent.key == "shift")
    bubble.set_visible(True)
    bubble.set_x(self.text_x)
    bubble.set_y(self.text_y)
    bubble.set_text(self.text_function(self, event))
def auto_fit_fontsize(text, width=None, height=None, step_size=1,
                      min_font_size=1, max_font_size=100):
  '''
  Auto-fit the fontsize of a text object.

  Picks a font size between ``min_font_size`` and ``max_font_size`` so that
  the rendered text stays within the allowed width and height, and sets it on
  ``text``. If the text already fits at ``max_font_size`` that size is used;
  if it does not fit even at ``min_font_size``, ``min_font_size`` is used.

  Parameters
  ----------
  text : :class:`matplotlib.text.Text`
      The text object to resize. It must already be attached to an axes
  width : float, optional
      Allowed width in data coordinates. Falsy means unconstrained
  height : float, optional
      Allowed height in data coordinates. Falsy means unconstrained
  step_size : float, optional
      Tolerance of the search, in font size units
  min_font_size : float, optional
      Smallest font size that may be chosen
  max_font_size : float, optional
      Largest font size that may be chosen

  Raises
  ------
  ValueError
      If ``min_font_size`` is less than 1
  '''
  # https://stackoverflow.com/a/61233097/4166604

  if min_font_size < 1:
    raise ValueError('The minimum font size cannot be less than 1.0')

  def fit(size):
    ''' Set the font size and return how far the text overflows (> 0). '''
    text.set_fontsize(size)
    renderer = text.axes.figure.canvas.get_renderer()
    bbox_text = text.get_window_extent(renderer=renderer)

    # transform bounding box to data coordinates
    bbox_text = mpl.transforms.Bbox(
        text.axes.transData.inverted().transform(bbox_text))

    # On an inverted axis the transformed extent comes out negative, so the
    # magnitude is compared. The text fits only when every constrained
    # dimension fits, hence the worst (largest) overflow is the one reported.
    overflows = []
    if width:
      overflows.append(abs(bbox_text.width) - width)
    if height:
      overflows.append(abs(bbox_text.height) - height)
    return max(overflows, default=0.0)

  # Check the endpoints first. bisect needs a sign change between them and
  # raises ValueError otherwise, which cannot tell "fits everywhere" apart
  # from "fits nowhere". fit() leaves the tested size set on the text.
  if fit(max_font_size) <= 0:
    return  # the largest size already fits
  if fit(min_font_size) >= 0:
    return  # nothing smaller is allowed, stay at the minimum

  try:
    from scipy.optimize import bisect
  except ImportError:
    bisect = None

  if bisect is not None:
    size = bisect(fit, a=min_font_size, b=max_font_size, xtol=step_size)
  else:
    # Brute force: step down from the top until the text fits
    size = max_font_size
    while size > min_font_size and fit(size) > 0:
      size -= step_size

  # The text keeps whatever size fit() tested last, which is not necessarily
  # the result. Apply the result explicitly; if it lands just on the
  # overflowing side of the tolerance, back off by one step.
  size = max(size, min_font_size)
  if fit(size) > 0:
    fit(max(size - step_size, min_font_size))