"""
Utility functions for querying and manipulating dimensional axis metadata.
"""
import logging
from typing import List, Tuple
import numpy as np
import scyjava as sj
import xarray as xr
from jpype import JException, JObject
from imagej._java import jc
from imagej.images import is_arraylike as _is_arraylike
from imagej.images import is_xarraylike as _is_xarraylike
# Module-level logger; this module uses it for deprecation notices and for
# warnings emitted while converting xarray coordinates to ImageJ axes.
_logger = logging.getLogger(__name__)
def get_axes(
    rai: "jc.RandomAccessibleInterval",
) -> List["jc.CalibratedAxis"]:
    """
    Return one CalibratedAxis per dimension of ``rai``.

    Deprecated: imagej.dims.get_axes(image) is deprecated. Use image.dim_axes
    instead.

    :param rai: A Java object exposing ``numDimensions()`` and ``axis(int)``.
    :return: List of the axes, each cast to ``CalibratedAxis``.
    """
    _logger.warning(
        "imagej.dims.get_axes(image) is deprecated. Use image.dim_axes instead."
    )
    calibrated = []
    for dim_index in range(rai.numDimensions()):
        # Cast the returned axis object to the CalibratedAxis interface.
        calibrated.append(JObject(rai.axis(dim_index), jc.CalibratedAxis))
    return calibrated
def get_axis_types(rai: "jc.RandomAccessibleInterval") -> List["jc.AxisType"]:
    """
    Return the AxisType of every dimension of ``rai``.

    Deprecated: imagej.dims.get_axis_types(image) is deprecated. Use this code
    instead::

        axis_types = [axis.type() for axis in image.dim_axes]

    Axis labels "c"/"C" and "t"/"T" are mapped to "Channel" and "Time" before
    being looked up with ``Axes.get``.

    :param rai: A Java object with an ``axis`` attribute.
    :return: List of AxisType, one per dimension.
    :raises AttributeError: If ``rai`` is not a Java object with axes.
    """
    _logger.warning(
        "imagej.dims.get_axis_types(image) is deprecated. Use this code instead:\n"
        + "\n"
        + " axis_types = [axis.type() for axis in image.dim_axes]"
    )
    if not _has_axis(rai):
        raise AttributeError(
            f"Unsupported Java type: {type(rai)} has no axis attribute."
        )
    # Short single-letter labels are normalized to the ImageJ axis names.
    label_aliases = {"c": "Channel", "t": "Time"}
    labels = [label_aliases.get(label.lower(), label) for label in get_dims(rai)]
    return [jc.Axes.get(label) for label in labels]
def get_dims(image) -> List[str]:
    """
    Return dimension information for ``image``.

    Deprecated: imagej.dims.get_dims(image) is deprecated. Use image.shape and
    image.dims instead.

    What is returned depends on the input (not always labels, despite the
    annotation):

    * xarray-like: ``image.dims``
    * array-like: ``image.shape`` (sizes)
    * object with an ``axis`` attribute: the axis labels
    * other RandomAccessibleInterval: the dimension sizes
    * ImagePlus: the dimension sizes greater than 1

    :param image: The image to inspect.
    :return: Dimension labels or sizes, as described above.
    :raises TypeError: If the image type is not supported.
    """
    _logger.warning(
        "imagej.dims.get_dims(image) is deprecated. Use image.shape and image.dims "
        "instead."
    )
    if _is_xarraylike(image):
        result = image.dims
    elif _is_arraylike(image):
        result = image.shape
    elif hasattr(image, "axis"):
        result = _get_axis_labels(get_axes(image))
    elif isinstance(image, jc.RandomAccessibleInterval):
        result = list(image.dimensionsAsLongArray())
    elif isinstance(image, jc.ImagePlus):
        result = [extent for extent in image.getDimensions() if extent > 1]
    else:
        raise TypeError(
            f"Unsupported image type: {image}\n No dimensions or shape found."
        )
    return result
def get_shape(image) -> List[int]:
    """
    Return the size of each dimension of ``image``.

    Deprecated: imagej.dims.get_shape(image) is deprecated. Use image.shape
    instead.

    :param image: An array-like object, a Java ``Dimensions``, or an ImagePlus.
    :return: List of dimension sizes (for an ImagePlus, only sizes > 1).
    :raises TypeError: If the object is neither array-like nor a supported
        Java type.
    """
    _logger.warning(
        "imagej.dims.get_shape(image) is deprecated. Use image.shape instead."
    )
    if _is_arraylike(image):
        return list(image.shape)
    if not sj.isjava(image):
        raise TypeError("Unsupported type: " + str(type(image)))

    if isinstance(image, jc.Dimensions):
        return list(map(image.dimension, range(image.numDimensions())))
    if isinstance(image, jc.ImagePlus):
        return [extent for extent in image.getDimensions() if extent > 1]

    raise TypeError(f"Unsupported Java type: {str(sj.jclass(image).getName())}")
def reorganize(
    rai: "jc.RandomAccessibleInterval", permute_order: List[int]
) -> "jc.ImgPlus":
    """Reorganize the dimension order of a RandomAccessibleInterval.

    Permute the dimension order of an input RandomAccessibleInterval using
    a List of ints (i.e. permute_order) to determine the shape of the output
    ImgPlus: output dimension ``i`` is taken from input dimension
    ``permute_order[i]``. The caller's ``permute_order`` is not modified.

    :param rai: A Dataset (unwrapped to its ImgPlus) or an object providing
        ``numDimensions``, ``axis``, ``getImg`` and ``getName``.
    :param permute_order: Permutation of ``range(rai.numDimensions())``.
    :return: A permuted ImgPlus.
    :raises ValueError: If permute_order has the wrong length or is not a
        permutation of the dimension indices.
    """
    img = _dataset_to_imgplus(rai)

    # check for dimension count mismatch
    dim_num = rai.numDimensions()
    if len(permute_order) != dim_num:
        raise ValueError(
            f"Mismatched dimension count: {len(permute_order)} != {dim_num}"
        )

    # Work on a private copy: the swap bookkeeping below rewrites the list.
    order = list(permute_order)
    if sorted(order) != list(range(dim_num)):
        raise ValueError(
            f"Invalid permute order: {order} is not a permutation of "
            f"range({dim_num})"
        )

    # get ImageJ resources
    ImgView = sj.jimport("net.imglib2.img.ImgView")

    # copy the dimensional axes into their new order
    axes = [img.axis(old_dim) for old_dim in order]

    # repeatedly swap dimensions until position i holds input dim order[i]
    view = img.getImg()
    for i in range(dim_num):
        old_dim = order[i]
        if old_dim == i:
            continue
        view = jc.Views.permute(view, old_dim, i)
        # The swap moved whatever sat at position i to position old_dim, so
        # the entry that was waiting for dim i must now look at old_dim.
        for j in range(dim_num):
            if order[j] == i:
                order[j] = old_dim
                break
        order[i] = i

    return jc.ImgPlus(ImgView.wrap(view), img.getName(), axes)
def prioritize_rai_axes_order(
    axis_types: List["jc.AxisType"], ref_order: List["jc.AxisType"]
) -> List[int]:
    """Prioritize the axes order to match a reference order.

    Indices of ``axis_types`` entries that appear in ``ref_order`` come first,
    in reference order; indices of entries absent from ``ref_order`` follow in
    their original order.

    :param axis_types: List of 'net.imagej.axis.AxisType' from image.
    :param ref_order: List of 'net.imagej.axis.AxisType' from reference order
        (e.g. _python_rai_ref_order).
    :return: List of int for permuting a image (e.g. [0, 4, 3, 1, 2])
    """
    matched = [
        index
        for ref_axis in ref_order
        for index, axis_type in enumerate(axis_types)
        if ref_axis == axis_type
    ]
    leftovers = [
        index
        for index, axis_type in enumerate(axis_types)
        if axis_type not in ref_order
    ]
    return matched + leftovers
def _assign_axes(xarr: xr.DataArray):
    """
    Build one ImageJ axis per xarray dimension from the xarray coordinates.

    Each dimension name is converted to an ImageJ label with
    ``_convert_dim(..., direction="java")`` and the resulting axis is stored at
    the Java index given by ``_get_axis_num``. Coordinates are converted to
    ``java.lang.Double``; non-numeric coordinates (``_get_scale`` returns None)
    are replaced by the linear index 0..n-1 and a warning is logged. An
    EnumeratedAxis is built when that class can be imported; otherwise
    ``_get_linear_axis`` is used.

    :param xarr: xarray that holds the units
    :return: A list of ImageJ Axis with the specified origin and scale
    """
    Double = sj.jimport("java.lang.Double")
    java_axes = [""] * len(xarr.dims)

    # EnumeratedAxis only exists in ImageJ releases newer than March 2020;
    # when it cannot be imported, fall back to a linear axis per dimension.
    try:
        EnumeratedAxis = _get_enumerated_axis()
    except (JException, TypeError):
        EnumeratedAxis = None
    make_axis = _get_linear_axis if EnumeratedAxis is None else EnumeratedAxis

    for dim in xarr.dims:
        ax_type = jc.Axes.get(_convert_dim(dim, direction="java"))
        ax_num = _get_axis_num(xarr, dim)
        dim_coords = xarr.coords[dim]
        if _get_scale(dim_coords) is None:
            _logger.warning(
                f"The {ax_type.label} axis is non-numeric and is translated "
                "to a linear index."
            )
            dim_coords = np.arange(len(dim_coords))
        java_coords = sj.to_java([Double(np.double(c)) for c in dim_coords])
        java_axes[ax_num] = make_axis(ax_type, java_coords)

    return java_axes
def _ends_with_channel_axis(xarr: xr.DataArray) -> bool:
    """Report whether the last dimension of an xarray.DataArray is a channel.

    The last dimension name is compared case-insensitively against "c", "ch"
    and "channel".

    :param xarr: xarray.DataArray to check.
    :return: True if the final dimension is a channel dimension.
    """
    last_dim_name = xarr.dims[-1]
    return last_dim_name.lower() in ("c", "ch", "channel")
def _get_axis_num(xarr: xr.DataArray, axis):
    """
    Get the Java axis index for an xarray dimension.

    Java dimension order is the inverse of a C-ordered (default) NumPy array,
    except that a trailing channel dimension stays last. Fortran-ordered
    arrays keep their Python axis index.

    :param xarr: Xarray to convert
    :param axis: Name of the xarray dimension to convert
    :return: Axis idx in java (always in ``range(len(xarr.dims))``)
    """
    py_axnum = xarr.get_axis_num(axis)
    if np.isfortran(xarr.values):
        return py_axnum

    last = len(xarr.dims) - 1
    if _ends_with_channel_axis(xarr):
        # Compare the axis *index* (not the dimension name) with the last
        # position: the channel axis stays last, the others are reversed.
        if py_axnum == last:
            return py_axnum
        return last - py_axnum - 1
    return last - py_axnum
def _get_axes_coords(
    axes: List["jc.CalibratedAxis"], dims: List[str], shape: Tuple[int]
) -> dict:
    """
    Build an xarray-style coordinate dictionary from ImageJ axes.

    For each index ``i``, ``dims[i]`` maps to the calibrated value of
    ``axes[i]`` at every integer position ``0 .. shape[i] - 1``.

    :param axes: List of ImageJ axes
    :param dims: List of axes labels for each dataset axis
    :param shape: F-style, or reversed C-style, shape of axes numpy array.
    :return: Dictionary of coordinates for each axis.
    """
    coords = {}
    for i, label in enumerate(dims):
        coords[label] = [axes[i].calibratedValue(pos) for pos in range(shape[i])]
    return coords
def _get_scale(axis):
"""
Get the scale of an axis, assuming it is linear and so the scale is simply
second - first coordinate.
:param axis: A 1D list like entry accessible with indexing, which contains the
axis coordinates
:return: The scale for this axis or None if it is a non-numeric scale.
"""
try:
# HACK: This axis length check is a work around for singleton dimensions.
# You can't calculate the slope of a singleton dimension.
# This section will be removed when axis-scale-logic is merged.
if len(axis) <= 1:
return 1
else:
return axis.values[1] - axis.values[0]
except TypeError:
return None
def _get_enumerated_axis():
    """Import and return the Java class ``net.imagej.axis.EnumeratedAxis``.

    EnumeratedAxis is only in ImageJ releases later than March 2020. With an
    older ImageJ, use _get_linear_axis() instead.

    :return: The EnumeratedAxis Java class.
    """
    class_name = "net.imagej.axis.EnumeratedAxis"
    return sj.jimport(class_name)
def _get_linear_axis(axis_type: "jc.AxisType", values):
    """Get linear axis.

    This is used if no EnumeratedAxis is found. If EnumeratedAxis is
    available, use _get_enumerated_axis() instead.

    The origin is the first coordinate and the scale is the difference between
    the first two coordinates. A single coordinate yields scale 1 (the same
    convention _get_scale uses for singleton dimensions), and an empty
    coordinate list yields origin 0 and scale 1 instead of an index error.

    :param axis_type: AxisType of the new axis.
    :param values: Sized, indexable sequence of numeric coordinates.
    :return: A DefaultLinearAxis.
    """
    DefaultLinearAxis = sj.jimport("net.imagej.axis.DefaultLinearAxis")
    n_values = len(values)
    origin = values[0] if n_values > 0 else 0.0
    scale = values[1] - values[0] if n_values > 1 else 1.0
    return DefaultLinearAxis(axis_type, scale, origin)
def _dataset_to_imgplus(rai: "jc.RandomAccessibleInterval") -> "jc.ImgPlus":
    """Unwrap a Dataset to its ImgPlus.

    Any input that is not a Dataset is returned unchanged.

    :param rai: A RandomAccessibleInterval.
    :return: ``rai.getImgPlus()`` for a Dataset, otherwise ``rai`` itself.
    """
    is_dataset = isinstance(rai, jc.Dataset)
    return rai.getImgPlus() if is_dataset else rai
def _get_axis_labels(axes: List["jc.CalibratedAxis"]) -> List[str]:
    """Return the label of each axis' type as a Python string.

    :param axes: A List of 'CalibratedAxis'.
    :return: ``str(axis.type().getLabel())`` for every axis, in order.
    """
    return [str(axis.type().getLabel()) for axis in axes]
def _python_rai_ref_order() -> List["jc.AxisType"]:
    """Return the reference axis order used to prioritize RAI permutation.

    The order is CHANNEL, X, Y, Z, TIME. Read backwards (TIME, Z, Y, X,
    CHANNEL), this is the Python/NumPy order t, pln, row, col, ch, which is
    why the reference order is described as reversed.

    :return: List of AxisType in the reversed NumPy-preferred order.
    """
    axes = jc.Axes
    return [axes.CHANNEL, axes.X, axes.Y, axes.Z, axes.TIME]
def _convert_dim(dim: str, direction: str) -> str:
    """Convert a single dimension name between ImageJ and Python conventions.

    :param dim: A dimension to be converted.
    :param direction: Case-insensitive target convention.
        'python': convert from ImageJ to Python/NumPy convention.
        'java': convert from Python/NumPy to ImageJ convention.
        Any other value returns ``dim`` unchanged.
    :return: A single converted dimension.
    """
    converters = {"python": _to_pydim, "java": _to_ijdim}
    converter = converters.get(direction.lower())
    if converter is None:
        return dim
    return converter(dim)
def _convert_dims(dimensions: List[str], direction: str) -> List[str]:
    """Convert a List of dimension names between ImageJ and Python conventions.

    :param dimensions: List of dimensions (e.g. X, Y, Channel, Z, Time)
    :param direction: Case-insensitive target convention.
        'python': convert from ImageJ to Python/NumPy conventions.
        'java': convert from Python/NumPy to ImageJ conventions.
        Any other value returns ``dimensions`` itself, unconverted.
    :return: A new list of converted dimensions (or the input, see above).
    """
    target = direction.lower()
    if target == "python":
        return [_to_pydim(name) for name in dimensions]
    if target == "java":
        return [_to_ijdim(name) for name in dimensions]
    return dimensions
def _validate_dim_order(dim_order: List[str], shape: tuple) -> List[str]:
"""
Validate a List of dimensions. If the dimension list is smaller
fill the rest of the list with "dim_n" (following xarrray convention).
:param dim_order: List of dimensions (e.g. X, Y, Channel, Z, Time)
:param shape: Shape image for the dimension order.
:return: List with "dim_n" dimensions added to match shape length.
"""
dim_len = len(dim_order)
shape_len = len(shape)
if dim_len < shape_len:
d = shape_len - dim_len
for i in range(d):
dim_order.append(f"dim_{i}")
return dim_order
if dim_len > shape_len:
raise ValueError(f"Expected {shape_len} dimensions but got {dim_len}.")
return dim_order
def _has_axis(rai: "jc.RandomAccessibleInterval") -> bool:
    """Check if a RandomAccessibleInterval has axes.

    :param rai: Object to inspect.
    :return: True if ``rai`` is a Java object with an ``axis`` attribute;
        False otherwise (including for non-Java objects, which previously
        fell through and returned None).
    """
    return sj.isjava(rai) and hasattr(rai, "axis")
def _to_pydim(key: str) -> str:
"""Convert ImageJ dimension convention to Python/NumPy."""
pydims = {
"Time": "t",
"slice": "pln",
"Z": "pln",
"Y": "row",
"X": "col",
"Channel": "ch",
}
if key in pydims:
return pydims[key]
else:
return key
def _to_ijdim(key: str) -> str:
"""Convert Python/NumPy dimension convention to ImageJ."""
ijdims = {
"col": "X",
"x": "X",
"row": "Y",
"y": "Y",
"ch": "Channel",
"c": "Channel",
"pln": "Z",
"z": "Z",
"t": "Time",
}
if key in ijdims:
return ijdims[key]
else:
return key