Source code for neuroplot.plot.single.single_plot

from pathlib import Path
from typing import Callable, Sequence

import matplotlib.pyplot as plt
import numpy as np
import torchio as tio
from matplotlib.figure import Figure


[docs] class SinglePlot: """ To plot a single neuroimage. 2D slices will be plotted via the method :py:meth:`plot`. The user can choose which anatomical axes to plot, and which slice to plot along the axes. The title of the figure can be changed between plots using :py:meth:`set_title`. Just a random reference to :footcite:t:`Harris:2020`. And another to Matplotlib\\ :footcite:`Hunter:2007`. Parameters ---------- axes : int | Sequence[int] | None, default=None The axis (or axes) to plot, among ``0`` (sagittal axis), ``1`` (coronal) or ``2`` (axial). Can be passed as a single axis, or a list of axes. If ``None``, the three axes will be plotted. slices : int | Sequence[int] | None, default=None The slice to plot for each axis. If ``None``, the middle slice will be plotted. Otherwise, the **number of slices passed must be equal to the number of plotted axes** (equal to :math:`3` if ``axes=None``). transforms : Sequence[Callable[[np.ndarray], np.ndarray]] | None, default=None Potential transforms to apply to the image before plotting. See :py:mod:`neuroplot.transforms`. .. important:: No matter the transforms passed, the image will first be reoriented to the :term:`RAS+` coordinate system. figsize : tuple[float, float] | None, default=None The size of the figure. See :py:func:`matplotlib.pyplot.figure` for more details. title : str | None, default=None A potential title for the figures that will be plotted. Raises ------ AssertionError If the number of slices passed is not equal to the number of plotted axes. Examples -------- .. code-block:: python from neuroplot.plot.single import SinglePlot from neuroplot.transforms import RescaleIntensity plotter = SinglePlot(axes=[0, 2], slices=[55, 167], transforms=[RescaleIntensity()]) .. code-block:: python >>> plotter.set_title("A first image") >>> plotter.plot("data/example_1.nii.gz") .. code-block:: python >>> plotter.set_title("Another image") >>> plotter.plot("data/example_2.nii.gz") See Also -------- :py:class:`neuroplot.plot.multiple.MultiplePlot` To plot multiple neuroimages in a grid of subplots. References ---------- .. footbibliography:: """ def __init__( self, axes: int | Sequence[int] | None = None, slices: int | Sequence[int] | None = None, transforms: Sequence[Callable[[np.ndarray], np.ndarray]] | None = None, figsize: tuple[float, float] | None = None, title: str | None = None, ) -> None: self.axes, self.slices = self._check_axes_and_slices(axes, slices) if transforms is None: transforms = [] transforms = [tio.ToCanonical()] + transforms self.transforms = tio.Compose(transforms) self.figsize = figsize self.title = title
[docs] def set_title(self, title: str | None) -> None: """ To change the title of the future plot. Parameters ---------- title : Optional[str] The new title. """ self.title = title
[docs] def plot( self, img_path: str | Path, show: bool = True, ) -> Figure: """ Builds a plot of an image. Parameters ---------- img_path : Union[str, Path] The path to the image to plot. show : bool, default=True Whether to display the figure. Returns ------- matplotlib.figure.Figure The figure with the desired 2D slices. Raises ------ IndexError If a slice passed in ``slices`` is out of bounds in this image. """ image = tio.ScalarImage(path=img_path) np_image = self.transforms(image).numpy().squeeze(0) slices = self._get_slice_indices(np_image) fig, plot_axes = plt.subplots(1, len(self.axes), figsize=self.figsize) if len(self.axes) == 1: plot_axes = [plot_axes] # turn it into an iterable for ax, slc, plot_axis in zip(self.axes, slices, plot_axes): plot_axis.set_xlabel(f"axis={ax}, slice={slc}") plot_axis.imshow(self._get_slice(np_image, ax, slc), cmap="gray") if self.title: fig.suptitle(self.title) if show: plt.show() return fig
@staticmethod def _check_axes_and_slices( axes: int | Sequence[int] | None, slices: int | Sequence[int] | None, ) -> tuple[Sequence[int], Sequence[int] | None]: """ To check that 'axes' and 'slices' are consistent. """ if axes is None: axes = [0, 1, 2] elif isinstance(axes, int): axes = [axes] if slices is None: n_slices = len(axes) elif isinstance(slices, int): n_slices = 1 slices = [slices] else: n_slices = len(slices) assert ( len(axes) == n_slices ), f"Got {len(axes)} elements for 'axes', but {n_slices} for 'slices'." return axes, slices def _get_slice_indices( self, image: np.ndarray, ) -> Sequence[int]: """ Checks that the wanted slice is not out of bounds for this image. If ``slices`` was set to ``None``, computes the slice index. """ spatial_shape = np.array(image.shape) if self.slices: slices = self.slices else: slices = spatial_shape // 2 for ax, slc in zip(self.axes, slices): if slc >= spatial_shape[ax]: raise IndexError(f"Slice {slc} is out of bounds in axis {ax}.") return slices @staticmethod def _get_slice(image: np.ndarray, ax: int, slc: int) -> np.ndarray: """ Gets the slice from the image. """ indices = [slice(None)] * len(image.shape) indices[ax] = slc return image[tuple(indices)]