import plotly.graph_objs as go from _plotly_utils.basevalidators import ColorscaleValidator from ._core import apply_default_cascade, init_figure, configure_animation_controls from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types import pandas as pd import numpy as np import itertools from plotly.utils import image_array_to_data_uri try: import xarray xarray_imported = True except ImportError: xarray_imported = False _float_types = [] def _vectorize_zvalue(z, mode="max"): alpha = 255 if mode == "max" else 0 if z is None: return z elif np.isscalar(z): return [z] * 3 + [alpha] elif len(z) == 1: return list(z) * 3 + [alpha] elif len(z) == 3: return list(z) + [alpha] elif len(z) == 4: return z else: raise ValueError( "zmax can be a scalar, or an iterable of length 1, 3 or 4. " "A value of %s was passed for zmax." % str(z) ) def _infer_zmax_from_type(img): dt = img.dtype.type rtol = 1.05 if dt in _integer_types: return _integer_ranges[dt][1] else: im_max = img[np.isfinite(img)].max() if im_max <= 1 * rtol: return 1 elif im_max <= 255 * rtol: return 255 elif im_max <= 65535 * rtol: return 65535 else: return 2 ** 32 def imshow( img, zmin=None, zmax=None, origin=None, labels={}, x=None, y=None, animation_frame=None, facet_col=None, facet_col_wrap=None, facet_col_spacing=None, facet_row_spacing=None, color_continuous_scale=None, color_continuous_midpoint=None, range_color=None, title=None, template=None, width=None, height=None, aspect=None, contrast_rescaling=None, binary_string=None, binary_backend="auto", binary_compression_level=4, binary_format="png", ): """ Display an image, i.e. data on a 2D regular raster. Parameters ---------- img: array-like image, or xarray The image data. Supported array shapes are - (M, N): an image with scalar data. The data is visualized using a colormap. - (M, N, 3): an image with RGB values. - (M, N, 4): an image with RGBA values, i.e. including transparency. zmin, zmax : scalar or iterable, optional zmin and zmax define the scalar range that the colormap covers. By default, zmin and zmax correspond to the min and max values of the datatype for integer datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For a multichannel image of floats, the max of the image is computed and zmax is the smallest power of 256 (1, 255, 65535) greater than this max value, with a 5% tolerance. For a single-channel image, the max of the image is used. Overridden by range_color. origin : str, 'upper' or 'lower' (default 'upper') position of the [0, 0] pixel of the image array, in the upper left or lower left corner. The convention 'upper' is typically used for matrices and images. labels : dict with str keys and str values (default `{}`) Sets names used in the figure for axis titles (keys ``x`` and ``y``), colorbar title and hoverlabel (key ``color``). The values should correspond to the desired label to be displayed. If ``img`` is an xarray, dimension names are used for axis titles, and long name for the colorbar title (unless overridden in ``labels``). Possible keys are: x, y, and color. x, y: list-like, optional x and y are used to label the axes of single-channel heatmap visualizations and their lengths must match the lengths of the second and first dimensions of the img argument. They are auto-populated if the input is an xarray. animation_frame: int or str, optional (default None) axis number along which the image array is sliced to create an animation plot. If `img` is an xarray, `animation_frame` can be the name of one the dimensions. facet_col: int or str, optional (default None) axis number along which the image array is sliced to create a facetted plot. If `img` is an xarray, `facet_col` can be the name of one the dimensions. facet_col_wrap: int Maximum number of facet columns. Wraps the column variable at this width, so that the column facets span multiple rows. Ignored if `facet_col` is None. facet_col_spacing: float between 0 and 1 Spacing between facet columns, in paper units. Default is 0.02. facet_row_spacing: float between 0 and 1 Spacing between facet rows created when ``facet_col_wrap`` is used, in paper units. Default is 0.0.7. color_continuous_scale : str or list of str colormap used to map scalar data to colors (for a 2D image). This parameter is not used for RGB or RGBA images. If a string is provided, it should be the name of a known color scale, and if a list is provided, it should be a list of CSS- compatible colors. color_continuous_midpoint : number If set, computes the bounds of the continuous color scale to have the desired midpoint. Overridden by range_color or zmin and zmax. range_color : list of two numbers If provided, overrides auto-scaling on the continuous color scale, including overriding `color_continuous_midpoint`. Also overrides zmin and zmax. Used only for single-channel images. title : str The figure title. template : str or dict or plotly.graph_objects.layout.Template instance The figure template name or definition. width : number The figure width in pixels. height: number The figure height in pixels. aspect: 'equal', 'auto', or None - 'equal': Ensures an aspect ratio of 1 or pixels (square pixels) - 'auto': The axes is kept fixed and the aspect ratio of pixels is adjusted so that the data fit in the axes. In general, this will result in non-square pixels. - if None, 'equal' is used for numpy arrays and 'auto' for xarrays (which have typically heterogeneous coordinates) contrast_rescaling: 'minmax', 'infer', or None how to determine data values corresponding to the bounds of the color range, when zmin or zmax are not passed. If `minmax`, the min and max values of the image are used. If `infer`, a heuristic based on the image data type is used. binary_string: bool, default None if True, the image data are first rescaled and encoded as uint8 and then passed to plotly.js as a b64 PNG string. If False, data are passed unchanged as a numerical array. Setting to True may lead to performance gains, at the cost of a loss of precision depending on the original data type. If None, use_binary_string is set to True for multichannel (eg) RGB arrays, and to False for single-channel (2D) arrays. 2D arrays are represented as grayscale and with no colorbar if use_binary_string is True. binary_backend: str, 'auto' (default), 'pil' or 'pypng' Third-party package for the transformation of numpy arrays to png b64 strings. If 'auto', Pillow is used if installed, otherwise pypng. binary_compression_level: int, between 0 and 9 (default 4) png compression level to be passed to the backend when transforming an array to a png b64 string. Increasing `binary_compression` decreases the size of the png string, but the compression step takes more time. For most images it is not worth using levels greater than 5, but it's possible to test `len(fig.data[0].source)` and to time the execution of `imshow` to tune the level of compression. 0 means no compression (not recommended). binary_format: str, 'png' (default) or 'jpg' compression format used to generate b64 string. 'png' is recommended since it uses lossless compression, but 'jpg' (lossy) compression can result if smaller binary strings for natural images. Returns ------- fig : graph_objects.Figure containing the displayed image See also -------- plotly.graph_objects.Image : image trace plotly.graph_objects.Heatmap : heatmap trace Notes ----- In order to update and customize the returned figure, use `go.Figure.update_traces` or `go.Figure.update_layout`. If an xarray is passed, dimensions names and coordinates are used for axes labels and ticks. """ args = locals() apply_default_cascade(args) labels = labels.copy() nslices_facet = 1 if facet_col is not None: if isinstance(facet_col, str): facet_col = img.dims.index(facet_col) nslices_facet = img.shape[facet_col] facet_slices = range(nslices_facet) ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet nrows = ( nslices_facet // ncols + 1 if nslices_facet % ncols else nslices_facet // ncols ) else: nrows = 1 ncols = 1 if animation_frame is not None: if isinstance(animation_frame, str): animation_frame = img.dims.index(animation_frame) nslices_animation = img.shape[animation_frame] animation_slices = range(nslices_animation) slice_dimensions = (facet_col is not None) + ( animation_frame is not None ) # 0, 1, or 2 facet_label = None animation_label = None img_is_xarray = False # ----- Define x and y, set labels if img is an xarray ------------------- if xarray_imported and isinstance(img, xarray.DataArray): dims = list(img.dims) img_is_xarray = True if facet_col is not None: facet_slices = img.coords[img.dims[facet_col]].values _ = dims.pop(facet_col) facet_label = img.dims[facet_col] if animation_frame is not None: animation_slices = img.coords[img.dims[animation_frame]].values _ = dims.pop(animation_frame) animation_label = img.dims[animation_frame] y_label, x_label = dims[0], dims[1] # np.datetime64 is not handled correctly by go.Heatmap for ax in [x_label, y_label]: if np.issubdtype(img.coords[ax].dtype, np.datetime64): img.coords[ax] = img.coords[ax].astype(str) if x is None: x = img.coords[x_label].values if y is None: y = img.coords[y_label].values if aspect is None: aspect = "auto" if labels.get("x", None) is None: labels["x"] = x_label if labels.get("y", None) is None: labels["y"] = y_label if labels.get("animation_frame", None) is None: labels["animation_frame"] = animation_label if labels.get("facet_col", None) is None: labels["facet_col"] = facet_label if labels.get("color", None) is None: labels["color"] = xarray.plot.utils.label_from_attrs(img) labels["color"] = labels["color"].replace("\n", "
") else: if hasattr(img, "columns") and hasattr(img.columns, "__len__"): if x is None: x = img.columns if labels.get("x", None) is None and hasattr(img.columns, "name"): labels["x"] = img.columns.name or "" if hasattr(img, "index") and hasattr(img.index, "__len__"): if y is None: y = img.index if labels.get("y", None) is None and hasattr(img.index, "name"): labels["y"] = img.index.name or "" if labels.get("x", None) is None: labels["x"] = "" if labels.get("y", None) is None: labels["y"] = "" if labels.get("color", None) is None: labels["color"] = "" if aspect is None: aspect = "equal" # --- Set the value of binary_string (forbidden for pandas) if isinstance(img, pd.DataFrame): if binary_string: raise ValueError("Binary strings cannot be used with pandas arrays") is_dataframe = True else: is_dataframe = False # --------------- Starting from here img is always a numpy array -------- img = np.asanyarray(img) # Reshape array so that animation dimension comes first, then facets, then images if facet_col is not None: img = np.moveaxis(img, facet_col, 0) if animation_frame is not None and animation_frame < facet_col: animation_frame += 1 facet_col = True if animation_frame is not None: img = np.moveaxis(img, animation_frame, 0) animation_frame = True args["animation_frame"] = ( "animation_frame" if labels.get("animation_frame") is None else labels["animation_frame"] ) iterables = () if animation_frame is not None: iterables += (range(nslices_animation),) if facet_col is not None: iterables += (range(nslices_facet),) # Default behaviour of binary_string: True for RGB images, False for 2D if binary_string is None: binary_string = img.ndim >= (3 + slice_dimensions) and not is_dataframe # Cast bools to uint8 (also one byte) if img.dtype == np.bool: img = 255 * img.astype(np.uint8) if range_color is not None: zmin = range_color[0] zmax = range_color[1] # -------- Contrast rescaling: either minmax or infer ------------------ if contrast_rescaling is None: contrast_rescaling = "minmax" if img.ndim == (2 + slice_dimensions) else "infer" # We try to set zmin and zmax only if necessary, because traces have good defaults if contrast_rescaling == "minmax": # When using binary_string and minmax we need to set zmin and zmax to rescale the image if (zmin is not None or binary_string) and zmax is None: zmax = img.max() if (zmax is not None or binary_string) and zmin is None: zmin = img.min() else: # For uint8 data and infer we let zmin and zmax to be None if passed as None if zmax is None and img.dtype != np.uint8: zmax = _infer_zmax_from_type(img) if zmin is None and zmax is not None: zmin = 0 # For 2d data, use Heatmap trace, unless binary_string is True if img.ndim == 2 + slice_dimensions and not binary_string: y_index = slice_dimensions if y is not None and img.shape[y_index] != len(y): raise ValueError( "The length of the y vector must match the length of the first " + "dimension of the img matrix." ) x_index = slice_dimensions + 1 if x is not None and img.shape[x_index] != len(x): raise ValueError( "The length of the x vector must match the length of the second " + "dimension of the img matrix." ) traces = [ go.Heatmap(x=x, y=y, z=img[index_tup], coloraxis="coloraxis1", name=str(i)) for i, index_tup in enumerate(itertools.product(*iterables)) ] autorange = True if origin == "lower" else "reversed" layout = dict(yaxis=dict(autorange=autorange)) if aspect == "equal": layout["xaxis"] = dict(scaleanchor="y", constrain="domain") layout["yaxis"]["constrain"] = "domain" colorscale_validator = ColorscaleValidator("colorscale", "imshow") layout["coloraxis1"] = dict( colorscale=colorscale_validator.validate_coerce( args["color_continuous_scale"] ), cmid=color_continuous_midpoint, cmin=zmin, cmax=zmax, ) if labels["color"]: layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"]) # For 2D+RGB data, use Image trace elif ( img.ndim >= 3 and (img.shape[-1] in [3, 4] or slice_dimensions and binary_string) ) or (img.ndim == 2 and binary_string): rescale_image = True # to check whether image has been modified if zmin is not None and zmax is not None: zmin, zmax = ( _vectorize_zvalue(zmin, mode="min"), _vectorize_zvalue(zmax, mode="max"), ) x0, y0, dx, dy = (None,) * 4 error_msg_xarray = ( "Non-numerical coordinates were passed with xarray `img`, but " "the Image trace cannot handle it. Please use `binary_string=False` " "for 2D data or pass instead the numpy array `img.values` to `px.imshow`." ) if x is not None: x = np.asanyarray(x) if np.issubdtype(x.dtype, np.number): x0 = x[0] dx = x[1] - x[0] else: error_msg = ( error_msg_xarray if img_is_xarray else ( "Only numerical values are accepted for the `x` parameter " "when an Image trace is used." ) ) raise ValueError(error_msg) if y is not None: y = np.asanyarray(y) if np.issubdtype(y.dtype, np.number): y0 = y[0] dy = y[1] - y[0] else: error_msg = ( error_msg_xarray if img_is_xarray else ( "Only numerical values are accepted for the `y` parameter " "when an Image trace is used." ) ) raise ValueError(error_msg) if binary_string: if zmin is None and zmax is None: # no rescaling, faster img_rescaled = img rescale_image = False elif img.ndim == 2 + slice_dimensions: # single-channel image img_rescaled = rescale_intensity( img, in_range=(zmin[0], zmax[0]), out_range=np.uint8 ) else: img_rescaled = np.stack( [ rescale_intensity( img[..., ch], in_range=(zmin[ch], zmax[ch]), out_range=np.uint8, ) for ch in range(img.shape[-1]) ], axis=-1, ) img_str = [ image_array_to_data_uri( img_rescaled[index_tup], backend=binary_backend, compression=binary_compression_level, ext=binary_format, ) for index_tup in itertools.product(*iterables) ] traces = [ go.Image(source=img_str_slice, name=str(i), x0=x0, y0=y0, dx=dx, dy=dy) for i, img_str_slice in enumerate(img_str) ] else: colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" traces = [ go.Image( z=img[index_tup], zmin=zmin, zmax=zmax, colormodel=colormodel, x0=x0, y0=y0, dx=dx, dy=dy, ) for index_tup in itertools.product(*iterables) ] layout = {} if origin == "lower" or (dy is not None and dy < 0): layout["yaxis"] = dict(autorange=True) if dx is not None and dx < 0: layout["xaxis"] = dict(autorange="reversed") else: raise ValueError( "px.imshow only accepts 2D single-channel, RGB or RGBA images. " "An image of shape %s was provided." "Alternatively, 3- or 4-D single or multichannel datasets can be" "visualized using the `facet_col` or/and `animation_frame` arguments." % str(img.shape) ) # Now build figure col_labels = [] if facet_col is not None: slice_label = ( "facet_col" if labels.get("facet_col") is None else labels["facet_col"] ) col_labels = ["%s=%d" % (slice_label, i) for i in facet_slices] fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) for attr_name in ["height", "width"]: if args[attr_name]: layout[attr_name] = args[attr_name] if args["title"]: layout["title_text"] = args["title"] elif args["template"].layout.margin.t is None: layout["margin"] = {"t": 60} frame_list = [] for index, trace in enumerate(traces): if (facet_col and index < nrows * ncols) or index == 0: fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1) if animation_frame is not None: for i, index in zip(range(nslices_animation), animation_slices): frame_list.append( dict( data=traces[nslices_facet * i : nslices_facet * (i + 1)], layout=layout, name=str(index), ) ) if animation_frame: fig.frames = frame_list fig.update_layout(layout) # Hover name, z or color if binary_string and rescale_image and not np.all(img == img_rescaled): # we rescaled the image, hence z is not displayed in hover since it does # not correspond to img values hovertemplate = "%s: %%{x}
%s: %%{y}" % ( labels["x"] or "x", labels["y"] or "y", ) else: if trace["type"] == "heatmap": hover_name = "%{z}" elif img.ndim == 2: hover_name = "%{z[0]}" elif img.ndim == 3 and img.shape[-1] == 3: hover_name = "[%{z[0]}, %{z[1]}, %{z[2]}]" else: hover_name = "%{z}" hovertemplate = "%s: %%{x}
%s: %%{y}
%s: %s" % ( labels["x"] or "x", labels["y"] or "y", labels["color"] or "color", hover_name, ) fig.update_traces(hovertemplate=hovertemplate) if labels["x"]: fig.update_xaxes(title_text=labels["x"], row=1) if labels["y"]: fig.update_yaxes(title_text=labels["y"], col=1) configure_animation_controls(args, go.Image, fig) fig.update_layout(template=args["template"], overwrite=True) return fig