ImageJ2 Python Scripts

Getting Started
User Guides
Tips and Tricks
All Techniques
Basics of script writing
Batch processing
Script Editor
Auto Imports
Running headlessly
Multithreading in Clojure
Multithreading in JavaScript
Chess in Jython
ImageJ Macro
Lisp (Clojure)
Python (Jython)
R (Renjin)
Ruby (JRuby)


This page is a primer on ImageJ2-only Python scripts: the examples included here avoid IJ1 as much as possible, and use it only when it is really necessary.

Note that all the scripts on this page are linked from an external source repository.


Stack Projection

# @Dataset data
# @String(label="Dimension to Project", choices={"X", "Y", "Z", "TIME", "CHANNEL"}) projected_dimension
# @String(label="Projection Type", choices={"Max","Mean","Median","Min", "StdDev", "Sum"}) projection_type
# @OUTPUT Dataset output
# @OpService ops
# @DatasetService ds

# Project a stack along one chosen axis, collapsing it with a statistic from
# Ops.Stats (Max, Mean, Median, Min, StdDev or Sum). The most common use case
# of this script is a maximum Z projection.
from net.imagej.axis import Axes
from net.imagej.ops import Ops

# Locate the axis to collapse; dimensionIndex() yields -1 when it is absent.
axis_index = data.dimensionIndex(getattr(Axes, projected_dimension))
if axis_index == -1:
    raise Exception("%s dimension not found." % projected_dimension)

# Refuse to "project" an axis that holds a single plane.
if data.dimension(axis_index) < 2:
    raise Exception("%s dimension has only one frame." % projected_dimension)

# The result keeps every dimension except the projected one.
kept_sizes = []
for d in range(data.numDimensions()):
    if d != axis_index:
        kept_sizes.append(data.dimension(d))

projected = ops.create().img(kept_sizes)

# Build the statistic op, then apply it along the chosen axis.
stat_op = ops.op(getattr(Ops.Stats, projection_type), data)
ops.transform().project(projected, data, stat_op, axis_index)

# Wrap the projection in the output Dataset
output = ds.create(projected)

Apply Threshold

# @String(label="Threshold Method", required=true, choices={'otsu', 'huang'}) method_threshold
# @Float(label="Relative threshold", required=true, value=1, stepSize=0.1) relative_threshold
# @Dataset data
# @OUTPUT Dataset output
# @OpService ops
# @DatasetService ds

# Apply an automatic threshold computed with the given method. The computed
# value 'threshold_value' is scaled by the factor 'relative_threshold'
# (a factor of 1 leaves 'threshold_value' unchanged).

from net.imglib2.type.numeric.integer import UnsignedByteType

# Get the histogram
histo = ops.run("image.histogram", data)

# Get the threshold
threshold_value = ops.run("threshold.%s" % method_threshold, histo)

# Scale 'threshold_value' by 'relative_threshold'
threshold_value = int(round(threshold_value.get() * relative_threshold))

# We should not have to do that...
# The threshold is handed to "threshold.apply" as an 8-bit unsigned value.
# Clamp it to [0, 255] so that a large relative factor saturates, instead of
# relying on how UnsignedByteType handles out-of-range input.
threshold_value = max(0, min(255, threshold_value))
threshold_value = UnsignedByteType(threshold_value)

# Apply the threshold
thresholded = ops.run("threshold.apply", data, threshold_value)

# Create output Dataset
output = ds.create(thresholded)

# @String(label="Threshold Method", required=true, choices={'otsu', 'huang'}) method_threshold
# @Dataset data
# @OUTPUT Dataset output
# @OpService ops
# @DatasetService ds

# Apply an automatic threshold computed with the given method
# (the op name is built as "threshold.<method>").
thresholded = ops.run("threshold.%s" % method_threshold, data)

# Create output
output = ds.create(thresholded)

Crop an image

# @Dataset data
# @OUTPUT Dataset output
# @DatasetService ds
# @OpService ops

from net.imagej.axis import Axes
from net.imglib2.util import Intervals

# This function helps to crop a Dataset along an arbitrary number of Axes.
# Intervals to crop are specified easily as a Python dict. 

def get_axis(axis_type):
    """Translate an axis name into the matching net.imagej Axes constant.

    Recognised names are 'X', 'Y', 'Z', 'TIME' and 'CHANNEL'; any other
    name falls back to Axes.Z.
    """
    lookup = dict(X=Axes.X, Y=Axes.Y, Z=Axes.Z, TIME=Axes.TIME, CHANNEL=Axes.CHANNEL)
    try:
        return lookup[axis_type]
    except KeyError:
        return Axes.Z

def crop(ops, data, intervals):
    """Crop along one or more axes.

    Parameters
    ----------
    ops : OpService used to run "transform.crop".
    data : Dataset to crop.
    intervals : Dict mapping an axis name ('X', 'Y', 'Z', 'TIME', 'CHANNEL')
                to an inclusive [min, max] pixel interval. Axes that are not
                listed keep their full extent.
                Example :
                intervals = {'X' : [0, 50],
                             'Y' : [0, 50]}

    Returns the result of "transform.crop".
    Raises an Exception if an axis named in `intervals` is absent from `data`.
    """
    n_dims = data.numDimensions()
    mins = [data.min(d) for d in range(n_dims)]
    maxs = [data.max(d) for d in range(n_dims)]

    for axis_type, interval in intervals.items():
        index = data.dimensionIndex(get_axis(axis_type))
        # dimensionIndex() returns -1 for a missing axis. Without this check the
        # assignments below would silently overwrite the *last* dimension.
        if index == -1:
            raise Exception("%s dimension not found." % axis_type)
        mins[index] = interval[0]
        maxs[index] = interval[1]

    crop_interval = Intervals.createMinMax(*(mins + maxs))
    # The trailing True is transform.crop's dropSingleDimensions flag.
    return ops.run("transform.crop", data, crop_interval, True)

# Inclusive pixel bounds to keep along each axis being cropped
crop_bounds = dict(X=[0, 5], Y=[0, 5])

# Crop the Dataset, then wrap the cropped view in the output Dataset
cropped = crop(ops, data, crop_bounds)
output = ds.create(cropped)

Rotate all the frames of a stack

# @Float(label="Rotation angle (in degree)", required=true, value=90, stepSize=0.1) angle
# @Dataset data
# @OUTPUT Dataset output
# @OpService ops
# @DatasetService ds

# This script rotates every frame of a stack along the TIME axis by a given angle.
# I find this script overcomplicated for what it is supposed to do. I hope a simpler way to do this kind of
# transformation will be available one day; when it is, this script should be updated.

import math
from net.imagej.axis import Axes
from net.imglib2.interpolation.randomaccess import LanczosInterpolatorFactory
from net.imglib2.realtransform import AffineTransform2D
from net.imglib2.realtransform import RealViews
from net.imglib2.util import Intervals
from net.imglib2.view import Views

def get_axis(axis_type):
    """Return the Axes constant named by `axis_type`.

    Accepts 'X', 'Y', 'Z', 'TIME' or 'CHANNEL'; any other name yields Axes.Z.
    """
    name_to_axis = dict(X=Axes.X, Y=Axes.Y, Z=Axes.Z, TIME=Axes.TIME, CHANNEL=Axes.CHANNEL)
    return name_to_axis[axis_type] if axis_type in name_to_axis else Axes.Z
def crop_along_one_axis(ops, data, intervals, axis_type):
    """Crop along a single axis using the "transform.crop" op.

    Parameters
    ----------
    intervals : List with two values specifying the inclusive start and end
                of the interval along `axis_type`.
    axis_type : Along which axis to crop. Can be ["X", "Y", "Z", "TIME", "CHANNEL"]

    Every other axis keeps its full extent. Raises an Exception when `data`
    has no axis of the requested type.
    """
    axis_index = data.dimensionIndex(get_axis(axis_type))
    # Without this check a missing axis would go unnoticed, and the "crop"
    # would silently return the full, uncropped data.
    if axis_index == -1:
        raise Exception("%s dimension not found." % axis_type)

    interval_start = [data.min(d) for d in range(data.numDimensions())]
    interval_end = [data.max(d) for d in range(data.numDimensions())]
    interval_start[axis_index] = intervals[0]
    interval_end[axis_index] = intervals[1]

    interval = Intervals.createMinMax(*(interval_start + interval_end))
    # The trailing True is transform.crop's dropSingleDimensions flag.
    return ops.run("transform.crop", data, interval, True)
# Get the center of the images so the rotation pivots around it
center = [int(round((data.max(d) / 2 + 1))) for d in range(2)]

# Convert the angle from degrees to radians
angle_rad = math.radians(angle)

# Build the affine transformation: move the center to the origin, rotate,
# then move the center back. A bare translation would not rotate anything.
affine = AffineTransform2D()
affine.translate([-p for p in center])
affine.rotate(angle_rad)
affine.translate(center)

# Get the interpolator
interpolator = LanczosInterpolatorFactory()

# Frames are taken along the TIME axis
t_index = data.dimensionIndex(Axes.TIME)
if t_index == -1:
    raise Exception("TIME dimension not found.")

# Rotate every frame of the stack and collect the results
rotated_frames = []
for t in range(data.dimension(t_index)):
    # Get the current frame
    frame = crop_along_one_axis(ops, data, [t, t], "TIME")

    # Get an interpolated view of the zero-padded frame
    extended = ops.run("transform.extendZeroView", frame)
    interpolant = ops.run("transform.interpolateView", extended, interpolator)

    # Apply the transformation to it
    rotated = RealViews.affine(interpolant, affine)

    # Restrict the rotated view to the frame's interval
    rotated = ops.transform().offset(rotated, frame)
    rotated_frames.append(rotated)

output = Views.stack(rotated_frames)

# Create output Dataset
output = ds.create(output)

Subtract the first image of a stack from all its frames

# @Dataset data
# @OUTPUT Dataset output
# @OpService ops
# @DatasetService ds

# Subtract the first frame of a stack (along the TIME axis) from every frame of that stack.
# This removes the static elements of the stack, which is useful when studying moving objects.
from net.imglib2.util import Intervals
from net.imagej.axis import Axes

# Convert input to 32-bit float so intermediate values are not limited to the
# input pixel type (the result is clipped back to it at the end)
converted = ops.convert().float32(data)

# Locate the TIME axis; without it there is no "first frame" to subtract
t_dim = data.dimensionIndex(Axes.TIME)
if t_dim == -1:
    raise Exception("TIME dimension not found.")

# Get the first frame (TODO: find a more convenient way !)
# Every axis keeps its full extent except TIME, which is pinned to index 0.
# createMinMax expects all the mins followed by all the maxes.
n_dims = data.numDimensions()
interval_start = [0] * n_dims
interval_end = [0 if d == t_dim else data.dimension(d) - 1 for d in range(n_dims)]
intervals = Intervals.createMinMax(*(interval_start + interval_end))

first_frame = ops.transform().crop(converted, intervals)

# Allocate output memory (wait for hybrid CF version of slice)
subtracted = ops.create().img(converted)

# Create the op
sub_op = ops.op("math.subtract", first_frame, first_frame)

# Slice over every axis except TIME so the op runs frame by frame
fixed_axis = [d for d in range(n_dims) if d != t_dim]

# Run the op
ops.slice(subtracted, converted, sub_op, fixed_axis)

# Clip image to the input type
clipped = ops.create().img(subtracted, data.getImgPlus().firstElement())
clip_op = ops.op("convert.clip", data.getImgPlus().firstElement(), subtracted.firstElement())
ops.convert().imageType(clipped, subtracted, clip_op)

# Create output Dataset
output = ds.create(clipped)

Apply a DoG (Difference of Gaussians) Filter

# @Dataset data
# @Float(label="Sigma 1 (pixel)", required=true, value=4.2, stepSize=0.1) sigma1
# @Float(label="Sigma 2 (pixel)", required=true, value=1.25, stepSize=0.1) sigma2
# @OUTPUT Dataset output
# @OpService ops
# @DatasetService ds

# Run a Difference of Gaussians (DoG) filter on every frame along the TIME axis.
# After filtering, the image is clipped back to the input pixel type.

from net.imagej.axis import Axes

# Convert data to 32-bit float
converted = ops.convert().float32(data.getImgPlus())

# Allocate output memory (wait for hybrid CF version of slice)
dog = ops.create().img(converted)

# Create the DoG op (sigmas in pixels)
dog_op = ops.op("filter.dog", converted, sigma1, sigma2)

# Slice over every axis except TIME so the filter runs frame by frame
t_dim = data.dimensionIndex(Axes.TIME)
fixed_axis = [d for d in range(data.numDimensions()) if d != t_dim]

# Run the op
ops.slice(dog, converted, dog_op, fixed_axis)

# Clip image to the input type
clipped = ops.create().img(dog, data.getImgPlus().firstElement())
clip_op = ops.op("convert.clip", data.getImgPlus().firstElement(), dog.firstElement())
ops.convert().imageType(clipped, dog, clip_op)

# Create output Dataset
output = ds.create(clipped)

Apply a mask

# @Dataset data
# @Dataset mask
# @OUTPUT Dataset output

# Given a mask (binary image) and a raw image, remove background pixels from the raw image by
# keeping only those where the mask is different from 0.

# Note: as pointed out by @stelfrich on Gitter, the particular case where foreground pixels
# are 1 and background pixels are 0 can be written more simply as a multiplication of the
# two images.

from net.imglib2.util import Intervals

# Check that 'data' and 'mask' have the same dimensions
if not Intervals.equalDimensions(data, mask):
    raise Exception("Dimensions from input dataset does not match.")

# The output starts as a copy of the raw data, so only background pixels
# need to be written.
output = data.duplicate()
targetCursor = output.localizingCursor()
maskRA = mask.randomAccess()

# Iterate over each pixel of the output
while targetCursor.hasNext():
    # A fresh cursor sits before the first element, so advance before reading
    targetCursor.fwd()
    maskRA.setPosition(targetCursor)

    if maskRA.get().getRealDouble() == 0:
        targetCursor.get().setZero()

Retrieve objects/particles from a mask

# @ImageJ ij
# @Dataset data
# @Dataset mask

# This script identifies all the particles in a mask and creates label regions over which you can iterate.
# The second part of the script displays all the detected regions in the IJ1 RoiManager.

from ij.gui import PointRoi
from ij.plugin.frame import RoiManager
from net.imglib2.algorithm.labeling.ConnectedComponents import StructuringElement
from net.imglib2.roi.labeling import LabelRegions

def get_roi_manager(new=False):
    """Return the open IJ1 RoiManager, creating one if none exists.

    new : if True, empty the manager before returning it.
    """
    rm = RoiManager.getInstance()
    if not rm:
        rm = RoiManager()
    if new:
        # Start from an empty manager
        rm.runCommand("Reset")
    return rm

# Identify particles (connected components, 8-connectivity)
img = mask.getImgPlus()
labeled_img = ij.op().run("cca", img, StructuringElement.EIGHT_CONNECTED)

# Create label regions from particles
regions = LabelRegions(labeled_img)
region_labels = list(regions.getExistingLabels())
print("%i regions/particles detected" % len(region_labels))

# Now use the IJ1 RoiManager to display the detected regions
rm = get_roi_manager(new=True)

# A single RandomAccess on `data` is enough; it is repositioned for every pixel
dataRA = data.randomAccess()

for label in region_labels:
    region = regions.getLabelRegion(label)

    # Mark the center of mass of the region with a point ROI
    center = region.getCenterOfMass()
    x = center.getDoublePosition(0)
    y = center.getDoublePosition(1)
    roi = PointRoi(x, y)
    if center.numDimensions() >= 3:
        z = center.getDoublePosition(2)
        # IJ1 stack positions are 1-based (0 means "all slices"), whereas
        # ImgLib2 coordinates are 0-based.
        roi.setPosition(int(round(z)) + 1)
    rm.addRoi(roi)

    # You can also iterate over the `data` pixels covered by the LabelRegion
    cursor = region.localizingCursor()
    while cursor.hasNext():
        # Advance before reading; a fresh cursor sits before the first element
        cursor.fwd()
        dataRA.setPosition(cursor)
        x = cursor.getDoublePosition(0)
        y = cursor.getDoublePosition(1)
        # Pixel of `data`
        pixel = dataRA.get()
        # Do whatever you want here
        # print(x, y, pixel)

Manual Simple Registration on Stack

# @Dataset ds
# @OUTPUT Dataset output
# @OpService ops
# @LogService log
# @DatasetService datasetService

# This script translates individual slices in a stack according to single
# point ROIs (defined in the IJ1 ROIManager). If slices exist in between specified ROIs,
# a linear translation from one ROI to the next is applied.
# 1. Add point ROIs to the RoiManager; either one per slice or only in slices of
# interest. (Be careful to set the correct Z/T position when adding the ROI.)
# 2. Run the script.

from net.imglib2.util import Intervals
from net.imglib2.view import Views

from ij.plugin.frame import RoiManager

# Initialize some variables
img = ds.getImgPlus()

rm = RoiManager.getInstance()
if rm is None:
    raise Exception("No RoiManager open: add point ROIs to the RoiManager first.")

# Process the ROIs in stack order, whatever order they were added in
rois = sorted(rm.getRoisAsArray(), key=lambda roi: roi.getPosition())
if len(rois) < 2:
    raise Exception("At least two point ROIs are required (%i found)." % len(rois))

width = ds.getWidth()
height = ds.getHeight()
interval2d = Intervals.createMinMax(0, 0, width - 1, height - 1)


def _roi_point(roi):
    """Return (x, y, z) of a single-point ROI, with z a 0-based frame index."""
    position = roi.getPosition()
    # IJ1 stack positions are 1-based; 0 means the ROI is not tied to a slice.
    if position == 0:
        raise Exception("ROI '%s' is not attached to a Z/T position." % roi.getName())
    point = roi.getContainedPoints()[0]
    return point.x, point.y, position - 1


def _translate_frame(z, dx, dy):
    """Return frame `z` of `img` shifted by (dx, dy), padded with zeros outside the image."""
    # Get only the frame corresponding to the current Z/T position
    frame_interval = Intervals.createMinMax(0, 0, z, width - 1, height - 1, z)
    single_frame = ops.transform().crop(img, frame_interval)

    # Pad the frame so values outside the image dimensions are set to 0
    padded_frame = ops.transform().extendZero(single_frame)

    # Do the translation, then cut back to the original 2D extent
    translated_frame = ops.transform().translate(padded_frame, [dx, dy])
    translated_frame = Views.interval(translated_frame, interval2d)
    return ops.transform().dropSingletonDimensions(translated_frame)


images = []

# Displacement accumulated since the first ROI, in pixels
total_x_shift = 0.0
total_y_shift = 0.0

# Iterate over consecutive pairs of ROIs
for j, (start_roi, end_roi) in enumerate(zip(rois[:-1], rois[1:])):
    x_start, y_start, z_start = _roi_point(start_roi)
    x_end, y_end, z_end = _roi_point(end_roi)

    if z_end == z_start:
        raise Exception("ROIs #%i and #%i share the same Z/T position." % (j, j + 1))

    # Linear translation applied per frame between the two ROIs
    x_shift = float(x_end - x_start) / (z_end - z_start)
    y_shift = float(y_end - y_start) / (z_end - z_start)

    # Iterate over each frame in Z/T
    for z in range(z_start, z_end):
        log.info("Processing frame %i/%i for ROIs #%i and #%i (%i total detected ROIs)" % (z + 1, z_end + 1, j, j + 1, len(rois)))

        # Undo the displacement accumulated so far. The first frame of a pair
        # gets exactly the displacement reached at its own ROI.
        dx = -int(round(total_x_shift))
        dy = -int(round(total_y_shift))
        images.append(_translate_frame(z, dx, dy))

        total_x_shift += x_shift
        total_y_shift += y_shift

# The frame holding the last ROI closes the final interval; keep it too
images.append(_translate_frame(z_end, -int(round(total_x_shift)), -int(round(total_y_shift))))

images = ops.run("transform.stackView", [images])
output = datasetService.create(images)