Skip to content

Commit

Permalink
Experimental code to wrap ImageServer
Browse files Browse the repository at this point in the history
  • Loading branch information
petebankhead committed Dec 4, 2022
1 parent 5d3bc86 commit 46bac4f
Showing 1 changed file with 115 additions and 22 deletions.
137 changes: 115 additions & 22 deletions src/main/java/qupath/ext/djl/DjlZoo.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.awt.image.WritableRaster;
import java.io.IOException;
import java.lang.reflect.Type;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand All @@ -31,6 +32,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -73,6 +75,12 @@
import qupath.lib.analysis.images.SimpleImage;
import qupath.lib.geom.Point2;
import qupath.lib.images.ImageData;
import qupath.lib.images.servers.AbstractTileableImageServer;
import qupath.lib.images.servers.GeneratingImageServer;
import qupath.lib.images.servers.ImageServer;
import qupath.lib.images.servers.ImageServerBuilder.ServerBuilder;
import qupath.lib.images.servers.ImageServerMetadata;
import qupath.lib.images.servers.TileRequest;
import qupath.lib.objects.PathObject;
import qupath.lib.objects.PathObjectTools;
import qupath.lib.objects.PathObjects;
Expand Down Expand Up @@ -263,9 +271,12 @@ public static List<Artifact> listInstanceSegmentationModels() throws ModelNotFou
.findFirst()
.orElse(null);
if (preferredTypes == null) {
logger.warn("No supported types found in " + preferredTypes + " -\n"
if (supportedTypes.size() == 1)
preferredTypes = supportedTypes.iterator().next();
logger.warn("No supported types found in " + factoryClass + " -\n"
+ "Please call .builder().setTypes(inputClass, outputClass).build() to specify these directly");
} else
}
if (preferredTypes != null)
builder = builder.setTypes((Class<?>)preferredTypes.getKey(), (Class<?>)preferredTypes.getValue());
} else {
logger.warn("No translatorFactory specified - will try to choose suitable input/output class based on the application.\n"
Expand Down Expand Up @@ -446,7 +457,7 @@ public static Optional<List<PathObject>> detect(ZooModel<Image, DetectedObjects>
double defaultThreshold = 0.5;
var threshold = tryToParseDoubleProperty(model, "threshold", defaultThreshold);
if (threshold != defaultThreshold) {
logger.info("Setting threshold to {} from model properties", threshold);
logger.debug("Setting threshold to {} from model properties", threshold);
}

var server = imageData.getServer();
Expand Down Expand Up @@ -637,7 +648,7 @@ private static Shape getInputHeightWidth(Model model) {
var w = (long)tryToParseDoubleProperty(model, "width", inputWidth);
var h = (long)tryToParseDoubleProperty(model, "height", inputHeight);
if (w != inputWidth || h != inputHeight) {
logger.info("Setting input size to {} x {}", inputWidth, inputHeight);
logger.debug("Setting input size to {} x {}", inputWidth, inputHeight);
inputWidth = w;
inputHeight = h;
}
Expand Down Expand Up @@ -848,22 +859,6 @@ public static List<PathObject> segmentObjects(Predictor<Image, CategoryMask> pre
var map = segmentROIs(predictor, img, request, roiMask, skipBackground);
return map.entrySet().stream().map(e -> createPathObject(creator, e.getValue(), e.getKey())).collect(Collectors.toList());
}


public static BufferedImage imageToImage(Criteria<Image, Image> criteria, BufferedImage img) throws TranslateException, ModelNotFoundException, MalformedModelException, IOException {
try (var model = criteria.loadModel()) {
try (var predictor = model.newPredictor()) {
return imageToImage(predictor, img);
}
}
}


public static BufferedImage imageToImage(Predictor<Image, Image> predictor, BufferedImage img) throws TranslateException {
var image = BufferedImageFactory.getInstance().fromImage(img);
var output = predictor.predict(image);
return toBufferedImage(output);
}


private static PathObject createPathObject(Function<ROI, PathObject> creator, ROI roi, String classification) {
Expand Down Expand Up @@ -991,12 +986,14 @@ public static Classifications classify(Predictor<Image, Classifications> predict

/**
* Generate images using a {@link BigGANTranslator}.
* <p>
* See <a href="https://docs.djl.ai/examples/docs/biggan.html">https://docs.djl.ai/examples/docs/biggan.html</a>
* @param model the model
* @param indices the indices defining what should be in the image
* @return the images generated from the input indices
* @throws TranslateException
*/
public static List<BufferedImage> bigGanGenerate(ZooModel<int[], Image[]> model, int... indices) throws TranslateException {
static List<BufferedImage> bigGanGenerate(ZooModel<int[], Image[]> model, int... indices) throws TranslateException {
if (!(model.getTranslator() instanceof BigGANTranslator)) {
// Log a warning - we can still try, but with lower expectations
logger.warn("Model translater is not an instance of BigGANTranslator");
Expand All @@ -1005,10 +1002,23 @@ public static List<BufferedImage> bigGanGenerate(ZooModel<int[], Image[]> model,
var output = (Image[])predictor.predict(indices);
return Arrays.stream(output).map(i -> toBufferedImage(i)).collect(Collectors.toList());
}

}


/**
* Apply a predictor that takes a {@link BufferedImage} as input and provides another {@link BufferedImage} as output.
* @param predictor
* @param img
* @return
* @throws TranslateException
* @imple
*/
public static BufferedImage imageToImage(Predictor<Image, Image> predictor, BufferedImage img) throws TranslateException {
var image = BufferedImageFactory.getInstance().fromImage(img);
var output = predictor.predict(image);
return toBufferedImage(output);
}

/**
* Try to convert a DJL {@link Image} to a Java {@link BufferedImage}.
* @param image
Expand All @@ -1022,5 +1032,88 @@ public static BufferedImage toBufferedImage(Image image) throws IllegalArgumentE
throw new IllegalArgumentException("Need a java.awt.image.BufferedImage, but found " + wrapped);
}



/**
* Experimental (read: probably-not-very-useful) code to wrap an {@link ImageServer} to apply an image-to-image
* prediction model to the tiles.
* @param model
* @param server
* @return
* @implNote The {@link ImageServer} created here cannot be serialized to JSON; it can only be used temporarily
* within a single QuPath session.
*/
static ImageServer<BufferedImage> wrapImageToImage(ZooModel<Image, Image> model, ImageServer<BufferedImage> server) {
return new DjlPredictionImageServer(server, model);
}


static class DjlPredictionImageServer extends AbstractTileableImageServer implements GeneratingImageServer<BufferedImage> {

private ImageServer<BufferedImage> server;
private ZooModel<Image, Image> model;

private ThreadLocal<Predictor<Image, Image>> predictors;

DjlPredictionImageServer(ImageServer<BufferedImage> server, ZooModel<Image, Image> model) {
this.server = server;
this.model = model;
this.predictors = ThreadLocal.withInitial(() -> model.newPredictor());
var imageHeightWidth = getInputHeightWidth(model);
long tileWidth = imageHeightWidth.size(1) <= 0 ? 512 : imageHeightWidth.size(1);
long tileHeight = imageHeightWidth.size(0) <= 0 ? tileWidth : imageHeightWidth.size(0);
setMetadata(
new ImageServerMetadata.Builder(server.getMetadata())
.preferredTileSize((int)tileWidth, (int)tileHeight)
.build()
);
}

@Override
public Collection<URI> getURIs() {
return server.getURIs();
}

@Override
public String getServerType() {
return "Deep Java Library prediction server";
}

@Override
public ImageServerMetadata getOriginalMetadata() {
return server.getOriginalMetadata();
}

@Override
protected BufferedImage readTile(TileRequest tileRequest) throws IOException {
if (server.isEmptyRegion(tileRequest.getRegionRequest()))
return getEmptyTile(tileRequest.getTileWidth(), tileRequest.getTileHeight());
var img = server.readRegion(tileRequest.getRegionRequest());
try {
return imageToImage(predictors.get(), img);
} catch (TranslateException e) {
throw new IOException(e);
}
}

@Override
protected ServerBuilder<BufferedImage> createServerBuilder() {
throw new UnsupportedOperationException("DjlPredictionImageServer cannot currently be serialized");
}

@Override
protected String createID() {
return UUID.randomUUID().toString();
}

@Override
public void close() throws Exception {
super.close();
model.close();
}


}


}

0 comments on commit 46bac4f

Please sign in to comment.