Scripting the Trainable Weka Segmentation

Revision as of 02:36, 13 October 2015 by Iarganda

Scripting is one of the reasons Fiji is so powerful, and the Trainable Segmentation library (which includes the methods of the Trainable Weka Segmentation plugin) is one of the best examples of a scriptable Fiji component.

Getting started

To start scripting the Trainable Segmentation library, you first need to know which methods you can use. For that, please have a look at the API of the Trainable Segmentation library, which is available here.

Let's go through the basic commands with examples written in Beanshell:

Initialization

The easiest (though not the most elegant) way to make all the library methods available is to import the whole library, together with the ImageJ classes used in the examples below:

import ij.IJ;
import ij.gui.Roi;
import trainableSegmentation.*;

Now we are ready to play. We can open our input image and use it to create a WekaSegmentation object (the segmentator):

// input train image
input  = IJ.openImage( "input-grayscale-or-color-image.tif" );
// create Weka Segmentation object
segmentator = new WekaSegmentation( input );

As it is now, the segmentator has default parameters and the default classifier. That means it will use the same features that are enabled by default in the Trainable Weka Segmentation plugin, 2 classes (named "class 1" and "class 2") and a random forest classifier with 200 trees and 2 random features per node. If we are fine with that, we can now add some labels for our training data and train the classifier on them.

Adding training samples

There are different ways of adding labels to our data:

1) We can add any type of ROI to any of the existing classes using "addExample":

// add pixels to first class (0) from ROI in slice # 1
segmentator.addExample( 0, new Roi( 10, 10, 50, 50 ), 1 );
// add pixels to second class (1) from ROI in slice # 1
segmentator.addExample( 1, new Roi( 400, 400, 30, 30 ), 1 );

2) We can add the labels from a binary image, where white pixels belong to one class and black pixels belong to the other. There are a few methods to do this, for example:

// open binary label image
labels  = IJ.openImage( "binary-labels.tif" );
// for the first slice, add white pixels as labels for class 2 and 
// black pixels as labels for class 1
segmentator.addBinaryData( labels, 0, "class 2", "class 1" );

3) You can also add samples from a new input image and its corresponding labels:

// open new input image
input2  = IJ.openImage( "input-image-2.tif" );
// open corresponding binary label image
labels2  = IJ.openImage( "binary-labels-2.tif" );
// for all slices in input2, add white pixels as labels for class 2 and 
// black pixels as labels for class 1
segmentator.addBinaryData( input2, labels2, "class 2", "class 1" );

4) To balance the number of samples of each class, you can use this similar method instead:

numSamples = 1000;
// for all slices in input2, add 1000 white pixels as labels for class 2 and 
// 1000 black pixels as labels for class 1
segmentator.addRandomBalancedBinaryData( input2, labels2, "class 2", "class 1" , numSamples);

5) You can use any of the methods available in the API to add labels from a binary image in many different ways. Please have a look at them and decide which one best fits your needs.
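To make the idea behind methods such as addBinaryData and addRandomBalancedBinaryData more concrete, here is a plain-Java sketch: split the pixels of a binary label image into a "white" and a "black" group, then draw the same number of random pixels from each. This is not the library's implementation. The class name BalancedSampler, the rule that any non-zero label counts as white, and sampling without replacement are all assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical sketch: balanced random sampling of pixels from a binary label image. */
final class BalancedSampler {

    /** Indices (y * width + x) of pixels labeled white (assumption: any non-zero value). */
    static List<Integer> whitePixels(int[] labels) {
        List<Integer> result = new ArrayList<>();
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] != 0) result.add(i);
        }
        return result;
    }

    /** Indices of pixels labeled black (value 0). */
    static List<Integer> blackPixels(int[] labels) {
        List<Integer> result = new ArrayList<>();
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] == 0) result.add(i);
        }
        return result;
    }

    /** Draws up to numSamples distinct indices at random (sampling without replacement). */
    static List<Integer> sample(List<Integer> pixels, int numSamples, Random random) {
        List<Integer> shuffled = new ArrayList<>(pixels);
        Collections.shuffle(shuffled, random);
        return new ArrayList<>(shuffled.subList(0, Math.min(numSamples, shuffled.size())));
    }
}
```

Calling sample once on the white list and once on the black list with the same numSamples gives both classes the same number of training pixels, unless one class has fewer pixels than requested, in which case this sketch simply returns all of them.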

Training classifier

Once we have training samples for both classes, we are ready to train the classifier of our segmentator:

segmentator.trainClassifier();

Applying classifier (getting results)

Once the classifier is trained (the trained classifier is printed in the Log window), we can apply it to the entire training image and obtain the result either as a labeled image or as a probability map for each class:

// apply classifier to current training image and get label result 
// (set parameter to true to get probabilities)
segmentator.applyClassifier( false );
// get result (float image)
result = segmentator.getClassifiedImage();
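The two kinds of result are related: a probability map stores one value per class for every pixel (one slice per class), while the labeled image stores a single class index per pixel. Assuming the label is simply the most probable class (the usual convention for Weka classifiers, but an assumption here), this plain-Java sketch turns per-class probability slices into a labeled float image. LabelFromProbabilities is a hypothetical name, not part of the library.

```java
/** Hypothetical sketch: derive a labeled image from per-class probability maps. */
final class LabelFromProbabilities {

    /**
     * probabilityMaps[c][p] is the probability of class c at pixel p.
     * Returns, for each pixel, the index of the most probable class
     * (0 = first class, 1 = second, ...); ties go to the lower class index.
     */
    static float[] labelImage(float[][] probabilityMaps) {
        int numPixels = probabilityMaps[0].length;
        float[] labels = new float[numPixels];
        for (int p = 0; p < numPixels; p++) {
            int best = 0;
            for (int c = 1; c < probabilityMaps.length; c++) {
                if (probabilityMaps[c][p] > probabilityMaps[best][p]) best = c;
            }
            labels[p] = best;
        }
        return labels;
    }
}
```

The output is a float array to mirror the float label image described above; the class indices start at 0, as in the addExample calls.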

Of course, we might be interested in applying the trained classifier to a completely new 2D image or stack. In that case we use:

// open test image
testImage = IJ.openImage( "test-image.tif" );
// get result (labels float image)
result = segmentator.applyClassifier( testImage );

Save/Load operations

If the classifier you trained is good enough for your purposes, you may want to save it into a file:

// save classifier into a file (.model)
segmentator.saveClassifier( "my-cool-trained-classifier.model" );

... and load it later in another script to apply it on new images:

// load classifier from file
segmentator.loadClassifier( "my-cool-trained-classifier.model" );

You may also want to save the training data into a file you can open later in WEKA:

// save data into an ARFF file
segmentator.saveData( "my-traces-data.arff" );

... or load a file with previously saved training data (traces) into the segmentator to use it as part of the training:

// load training data from ARFF file
segmentator.loadTrainingData( "my-traces-data.arff" );

Setting the classifier

By default, the classifier is a multi-threaded implementation of a random forest. You can change it to any other classifier available in the WEKA API. For example, we can use SMO:

import weka.classifiers.functions.SMO;
// create new SMO classifier (default parameters)
classifier = new SMO();
// assign classifier to segmentator
segmentator.setClassifier( classifier );

We might also want to use the default random forest but tune its parameters. In that case, we can write something like this:

import hr.irb.fastRandomForest.FastRandomForest;
// create random forest classifier
rf = new FastRandomForest();
// set number of trees in the forest
rf.setNumTrees( 100 );        
// set number of random features per node (0 for automatic selection)
rf.setNumFeatures( 0 );
// set random seed
rf.setSeed( (new java.util.Random()).nextInt() );
 
// set classifier
segmentator.setClassifier( rf );
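Passing 0 to setNumFeatures lets the forest choose the number of random features per node itself. Weka's own RandomForest documents this case as int(log2(#predictors) + 1); assuming FastRandomForest follows the same rule (not confirmed here), this plain-Java sketch computes the value it would use. AutoFeatures is a hypothetical helper name.

```java
/** Hypothetical sketch of the "0 = automatic" rule for random features per node. */
final class AutoFeatures {

    /**
     * Returns int(log2(numPredictors) + 1) for numPredictors >= 1.
     * floor(log2(n)) is computed exactly with bit arithmetic, avoiding
     * floating-point rounding for exact powers of two.
     */
    static int featuresPerNode(int numPredictors) {
        if (numPredictors < 1) {
            throw new IllegalArgumentException("need at least one predictor");
        }
        int floorLog2 = 31 - Integer.numberOfLeadingZeros(numPredictors);
        return floorLog2 + 1;
    }
}
```

For example, with 80 features per pixel this rule gives 7 random features per node, compared with the default of 2 mentioned at the beginning of this page.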

Example: define your own features

Here is a short JavaScript script that creates two custom features (a Gaussian-blurred and a median-filtered copy) from the Clown sample image and uses them to train a classifier (see the inline comments for more information):

importClass(Packages.ij.IJ);
importClass(Packages.ij.ImagePlus);
importClass(Packages.ij.ImageStack);
importClass(Packages.ij.gui.PolygonRoi);
importClass(Packages.ij.plugin.Duplicator);
importClass(Packages.ij.process.FloatPolygon);
importClass(Packages.ij.process.StackConverter);
importClass(Packages.trainableSegmentation.FeatureStack);
importClass(Packages.trainableSegmentation.FeatureStackArray);
importClass(Packages.trainableSegmentation.WekaSegmentation);

var image = IJ.openImage(java.lang.System.getProperty("ij.dir") + "/samples/clown.jpg");
if (image.getStackSize() > 1)
        new StackConverter(image).convertToGray32();
else
        image.setProcessor(image.getProcessor().convertToFloat());

var duplicator = new Duplicator();

// process the image into different stacks, one per feature:

var smoothed = duplicator.run(image);
IJ.run(smoothed, "Gaussian Blur...", "radius=20");

var medianed = duplicator.run(image);
IJ.run(medianed, "Median...", "radius=10");

// add new feature here (1/2)

// the FeatureStackArray contains a FeatureStack for every slice in our original image
var featuresArray = new FeatureStackArray(image.getStackSize(), 1, 16, false,
        1, 19, null);

// turn the list of stacks into FeatureStack instances, one per original
// slice. Each FeatureStack contains exactly one slice per feature.
for (var slice = 1; slice <= image.getStackSize(); slice++) {
        var stack = new ImageStack(image.getWidth(), image.getHeight());
        stack.addSlice("smoothed", smoothed.getStack().getProcessor(slice));
        stack.addSlice("medianed", medianed.getStack().getProcessor(slice));

        // add new feature here (2/2) and do not forget to add it with a
        // unique slice label!

        // create empty feature stack
        var features = new FeatureStack( stack.getWidth(), stack.getHeight(), false );
        // set my features to the feature stack
        features.setStack( stack );
        // put my feature stack into the array
        featuresArray.set(features, slice - 1);
        featuresArray.setEnabledFeatures(features.getEnabledFeatures());
}

var wekaSegmentation = new WekaSegmentation(image);
wekaSegmentation.setFeatureStackArray(featuresArray);

// set examples for class 1 (= foreground) and 0 (= background)
function addExample(classNum, slice, xArray, yArray) {
        var polygon = new FloatPolygon(xArray, yArray);
        var roi = new PolygonRoi(polygon, PolygonRoi.FREELINE);
        IJ.log("roi: " + roi);
        wekaSegmentation.addExample(classNum, roi, slice);
}

/*
 * generate these with the macro:

        getSelectionCoordinates(x, y);

        print("["); Array.print(x); print("],");
        print("["); Array.print(y); print("]");
 */
addExample(1, 1,
        [ 82,85,85,86,87,87,87,88,88,88,88,88,88,88,88,86,86,84,83,82,81,
          80,80,78,76,75,74,74,73,72,71,70,70,68,65,63,62,60,58,57,55,55,
          54,53,51,50,49,49,49,51,52,53,54,55,55,56,56 ],
        [ 141,137,136,134,133,132,130,129,128,127,126,125,124,123,122,121,
          120,119,118,118,116,116,115,115,114,114,113,112,111,111,111,111,
          110,110,110,110,111,112,113,114,114,115,116,117,118,119,119,120,
          121,123,125,126,128,128,129,129,130 ]
);
addExample(0, 1,
        [ 167,165,163,161,158,157,157,157,157,157,157,157,158 ],
        [ 30,29,29,29,29,29,28,26,25,24,23,22,21 ]
);

// train classifier
if (!wekaSegmentation.trainClassifier())
        throw new Error("Uh oh! No training today.");

var output = wekaSegmentation.applyClassifier(image);
output.show();
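In the script above, each FeatureStack holds one slice per feature ("smoothed" and "medianed"), so the classifier sees every pixel as a vector with one value per feature. This plain-Java sketch shows that layout, assuming row-by-row pixel storage (index = y * width + x) as in ImageJ processors; FeatureVectors is a hypothetical name, not a library class.

```java
/** Hypothetical sketch: read one pixel's feature vector from a stack of feature slices. */
final class FeatureVectors {

    /**
     * featureSlices[f] holds feature f for every pixel, stored row by row
     * (index = y * width + x), like one slice of an ImageStack.
     * Returns the values of all features at pixel (x, y), in slice order.
     */
    static float[] featureVector(float[][] featureSlices, int width, int x, int y) {
        float[] vector = new float[featureSlices.length];
        for (int f = 0; f < featureSlices.length; f++) {
            vector[f] = featureSlices[f][y * width + x];
        }
        return vector;
    }
}
```

Adding a third feature to the script (the "add new feature here" comments) would simply make every such vector one value longer.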

Example: define binary labels programmatically

Here is a simple Beanshell script that does the following:

  1. Takes one image (2D or stack) as the training input and a binary image as the corresponding labels.
  2. Trains a classifier (here a random forest, though it can be replaced) on randomly selected pixels of the training image. The number of samples (training pixels) is also a parameter, and the same number is used for each class.
  3. Applies the trained classifier to a test image (2D or stack).

import ij.*;
import ij.process.*;
import trainableSegmentation.*;
import hr.irb.fastRandomForest.*;

// training input image (it could be a stack or a single 2D image)
image = IJ.openImage("train-volume.tif");
// corresponding binary labels
labels = IJ.openImage("train-labels.tif");

// create Weka segmentator
seg = new WekaSegmentation(image);

// number of samples to use per slice
nSamplesToUse = 2000;


// Classifier
// In this case we use a Fast Random Forest
rf = new FastRandomForest();
// Number of trees in the forest
rf.setNumTrees(100);		
// Number of random features per node (0 for automatic selection)
rf.setNumFeatures(0);
// Seed
rf.setSeed( (new java.util.Random()).nextInt() );

// set classifier
seg.setClassifier(rf);

// Parameters 
// membrane patch size
seg.setMembranePatchSize(11);
// maximum radius of the filters
seg.setMaximumSigma(16.0f);

// Selected attributes
enableFeatures = new boolean[]{
			true, 	/* Gaussian_blur */
			true, 	/* Sobel_filter */
			true, 	/* Hessian */
			true, 	/* Difference_of_gaussians */
			true, 	/* Membrane_projections */
			false, 	/* Variance */
			false, 	/* Mean */
			false, 	/* Minimum */
			false, 	/* Maximum */
			false, 	/* Median */
			false,	/* Anisotropic_diffusion */
			false, 	/* Bilateral */
			false, 	/* Lipschitz */
			false, 	/* Kuwahara */
			false,	/* Gabor */
			false, 	/* Derivatives */
			false, 	/* Laplacian */
			false,	/* Structure */
			false,	/* Entropy */
			false	/* Neighbors */
};

// Enable features in the segmentator
seg.setEnabledFeatures( enableFeatures );

// Add balanced labels
seg.addRandomBalancedBinaryData(image, labels, "class 2", "class 1", nSamplesToUse);

// Train classifier
seg.trainClassifier();


// Un-comment the following block to segment and display the training image
/*
// Apply classifier to current image
seg.applyClassifier( true );

// Display classified image
prob = seg.getClassifiedImage();
prob.setTitle( "Probability maps of train image" );
prob.show();
*/

// Open test image
image = IJ.openImage("test-volume.tif");

// Apply trained classifier to test image and get probabilities
// (0 threads = detect the number of available processors automatically)
prob = seg.applyClassifier(image, 0, true );

// Display results
prob.setTitle( "Probability maps of test image" );
prob.show();

image.show();

IJ.log("---");
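The enableFeatures array in this script is positional: entry i switches on the i-th feature named in the comments, so the order matters. This plain-Java sketch, which is hypothetical and not part of the library, makes the mapping explicit. It assumes the order given in the comments matches the order the library expects.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: list the feature names switched on by a positional boolean array. */
final class EnabledFeatureNames {

    // Same order as the comments next to the enableFeatures array above.
    static final String[] NAMES = {
        "Gaussian_blur", "Sobel_filter", "Hessian", "Difference_of_gaussians",
        "Membrane_projections", "Variance", "Mean", "Minimum", "Maximum", "Median",
        "Anisotropic_diffusion", "Bilateral", "Lipschitz", "Kuwahara", "Gabor",
        "Derivatives", "Laplacian", "Structure", "Entropy", "Neighbors"
    };

    /** Returns the names whose flag is true; requires exactly one flag per name. */
    static List<String> enabled(boolean[] flags) {
        if (flags.length != NAMES.length) {
            throw new IllegalArgumentException(
                "expected " + NAMES.length + " flags, got " + flags.length);
        }
        List<String> result = new ArrayList<>();
        for (int i = 0; i < flags.length; i++) {
            if (flags[i]) result.add(NAMES[i]);
        }
        return result;
    }
}
```

With the array used in the script (the first five entries set to true), this returns Gaussian_blur, Sobel_filter, Hessian, Difference_of_gaussians and Membrane_projections.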