diff --git a/.gitignore b/.gitignore index 3dd39b4..1f88996 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +.idea/ +*.iml /target/ .DS_Store .classpath diff --git a/pom.xml b/pom.xml index 117b789..afd035b 100644 --- a/pom.xml +++ b/pom.xml @@ -185,6 +185,75 @@ cisd jhdf5 + + ch.qos.logback + logback-classic + + + com.fasterxml.jackson.core + jackson-databind + 2.6.5 + + + commons-logging + commons-logging + 1.1.1 + + + io.netty + netty-all + 4.1.17.Final + + + io.netty + netty + 3.9.9.Final + + + org.apache.spark + spark-core_2.11 + 2.4.7 + + + netty-all + io.netty + + + + jcl-over-slf4j + org.slf4j + + + jul-to-slf4j + org.slf4j + + + slf4j-log4j12 + org.slf4j + + + aopalliance-repackaged + org.glassfish.hk2.external + + + javax.inject + org.glassfish.hk2.external + + + jersey-client + org.glassfish.jersey.core + + + lz4 + net.jpountz.lz4 + + + jets3t + net.java.dev.jets3t + + + + net.preibisch diff --git a/src/main/java/align/PairwiseSIFT.java b/src/main/java/align/PairwiseSIFT.java index 293194f..6a7dd87 100644 --- a/src/main/java/align/PairwiseSIFT.java +++ b/src/main/java/align/PairwiseSIFT.java @@ -1,23 +1,5 @@ package align; -import java.io.File; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; -import java.util.concurrent.atomic.AtomicInteger; - -import org.janelia.saalfeldlab.n5.DataType; -import org.janelia.saalfeldlab.n5.GzipCompression; -import org.janelia.saalfeldlab.n5.N5FSReader; -import org.janelia.saalfeldlab.n5.N5FSWriter; -import org.joml.Math; - import data.STData; import data.STDataStatistics; import data.STDataUtils; @@ -27,15 +9,13 @@ import imglib2.ImgLib2Util; import io.N5IO; import io.Path; +import java.util.function.Supplier; import mpicbg.ij.FeatureTransform; import mpicbg.ij.SIFT; import mpicbg.ij.util.Util; import mpicbg.imagefeatures.Feature; import mpicbg.imagefeatures.FloatArray2DSIFT; -import mpicbg.models.AbstractModel; import mpicbg.models.Affine2D; -import mpicbg.models.AffineModel2D; -import mpicbg.models.InterpolatedAffineModel2D; import mpicbg.models.Model; import mpicbg.models.NotEnoughDataPointsException; import mpicbg.models.Point; @@ -48,232 +28,455 @@ import net.imglib2.realtransform.AffineTransform2D; import net.imglib2.type.numeric.real.DoubleType; import net.imglib2.util.Intervals; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.janelia.saalfeldlab.n5.DataType; +import org.janelia.saalfeldlab.n5.GzipCompression; +import org.janelia.saalfeldlab.n5.N5FSReader; +import org.janelia.saalfeldlab.n5.N5FSWriter; +import org.joml.Math; import util.Threads; -public class PairwiseSIFT -{ - public static class SIFTParam - { - public SIFTParam() - { - this.sift.fdSize = 8; - this.sift.fdBins = 8; - this.sift.steps = 10; - this.rod = 0.90f; - this.sift.minOctaveSize = 128; - } - - public SIFTParam( final int fdSize, final int fdBins, final int steps, final float rod, final int minOctaveSize ) - { - this.sift.fdSize = fdSize; - this.sift.fdBins = fdBins; - this.sift.steps = steps; - this.rod = rod; - this.sift.minOctaveSize = minOctaveSize; - } - - final public FloatArray2DSIFT.Param sift = new FloatArray2DSIFT.Param(); - - /** - * Closest/next closest neighbour distance ratio - */ - public float rod; - } - - static public void matchFeatures( - final Collection< Feature > fs1, - final Collection< Feature > fs2, - final List< PointMatch > matches, - final float rod ) - { - for ( final Feature f1 : fs1 ) - { - Feature best = null; - double best_d = Double.MAX_VALUE; - double second_best_d = Double.MAX_VALUE; - - for ( final Feature f2 : fs2 ) - { - final double d = f1.descriptorDistance( f2 ); - if ( d < best_d ) - { - second_best_d = best_d; - best_d = d; - best = f2; - } - else if ( d < second_best_d ) - second_best_d = d; - } - - if ( best != null && second_best_d < Double.MAX_VALUE && best_d / second_best_d < rod ) - matches.add( - new PointMatch( - new Point( - new double[] { f1.location[ 0 ], f1.location[ 1 ] } ), - new Point( - new double[] { best.location[ 0 ], best.location[ 1 ] } ), - best_d / second_best_d ) ); - } - - // now remove ambiguous matches - for ( int i = 0; i < matches.size(); ) - { - boolean amb = false; - final PointMatch m = matches.get( i ); - final double[] m_p2 = m.getP2().getL(); - for ( int j = i + 1; j < matches.size(); ) - { - final PointMatch n = matches.get( j ); - final double[] n_p2 = n.getP2().getL(); - if ( m_p2[ 0 ] == n_p2[ 0 ] && m_p2[ 1 ] == n_p2[ 1 ] ) - { - amb = true; - matches.remove( j ); - } - else ++j; - } - if ( amb ) - matches.remove( i ); - else ++i; - } - } - public static List< PointMatch > extractCandidates( final ImageProcessor ip1, final ImageProcessor ip2, final String gene, final SIFTParam p ) - { - final List< Feature > fs1 = new ArrayList< Feature >(); - final List< Feature > fs2 = new ArrayList< Feature >(); - - final FloatArray2DSIFT sift = new FloatArray2DSIFT( p.sift ); - - final SIFT ijSIFT = new SIFT( sift ); - ijSIFT.extractFeatures( ip1, fs1 ); - ijSIFT.extractFeatures( ip2, fs2 ); - - final List< PointMatch > candidates = new ArrayList< PointMatch >(); - FeatureTransform.matchFeatures( fs1, fs2, candidates, p.rod ); - - final List< PointMatch > candidatesST = new ArrayList< PointMatch >(); - for ( final PointMatch pm : candidates ) - candidatesST.add( - new PointMatch( - new PointST( pm.getP1().getL(), gene ), - new PointST( pm.getP2().getL(), gene ) )); - - return candidatesST; - } - - public static ArrayList< PointMatch > consensus( final List< PointMatch > candidates, final Model< ? > model, final int minNumInliers, final double maxEpsilon ) - { - final ArrayList< PointMatch > inliers = new ArrayList< PointMatch >(); - - boolean modelFound; - - try - { - modelFound = model.filterRansac( - candidates, - inliers, - 10000, - maxEpsilon, - 0.1f, //p.minInlierRatio, - minNumInliers, - 3f ); - } - catch ( final NotEnoughDataPointsException e ) - { - modelFound = false; - } - - if ( modelFound ) - PointMatch.apply( inliers, model ); - else - inliers.clear(); - - return inliers; - } - - public static void visualizeInliers( final ImagePlus imp1, final ImagePlus imp2, final List< PointMatch > inliers ) - { - if ( inliers.size() > 0 ) - { - final ArrayList< Point > p1 = new ArrayList< Point >(); - final ArrayList< Point > p2 = new ArrayList< Point >(); - - PointMatch.sourcePoints( inliers, p1 ); - PointMatch.targetPoints( inliers, p2 ); - - imp1.setRoi( Util.pointsToPointRoi( p1 ) ); - imp2.setRoi( Util.pointsToPointRoi( p2 ) ); - } - } - - public static < M extends Affine2D & Model, N extends Affine2D & Model > void pairwiseSIFT( - final STData stDataA, - final String stDataAname, - final STData stDataB, - final String stDataBname, - final M modelPairwise, - final N modelGlobal, - final File n5File, - final List< String > genesToTest, - final SIFTParam p, - final double scale, - final double smoothnessFactor, - final double maxEpsilon, - final int minNumInliers, - final int minNumInliersPerGene, - final boolean saveResult, - final boolean visualizeResult, - final int numThreads ) throws IOException - { - final AffineTransform2D tS = new AffineTransform2D(); - tS.scale( scale ); - - final Interval interval = STDataUtils.getCommonInterval( stDataA, stDataB ); - final Interval finalInterval = Intervals.expand( ImgLib2Util.transformInterval( interval, tS ), 100 ); - - final List< PointMatch > allCandidates = new ArrayList<>(); - - final List< Callable< List< PointMatch > > > tasks = new ArrayList<>(); - final AtomicInteger nextGene = new AtomicInteger(); - - for ( int threadNum = 0; threadNum < numThreads; ++threadNum ) - { - tasks.add( () -> - { - final List< PointMatch > allPerGeneInliers = new ArrayList<>(); - - for ( int g = nextGene.getAndIncrement(); g < genesToTest.size(); g = nextGene.getAndIncrement() ) - { - final String gene = genesToTest.get( g ); - //System.out.println( "current gene: " + gene ); - - final RandomAccessibleInterval imgA = AlignTools.display( stDataA, new STDataStatistics( stDataA ), gene, finalInterval, tS, null, smoothnessFactor ); - final RandomAccessibleInterval imgB = AlignTools.display( stDataB, new STDataStatistics( stDataB ), gene, finalInterval, tS, null, smoothnessFactor ); - - final ImagePlus impA = ImageJFunctions.wrapFloat( imgA, new RealFloatConverter<>(), "A_" + gene); - final ImagePlus impB = ImageJFunctions.wrapFloat( imgB, new RealFloatConverter<>(), "B_" + gene ); - - final List< PointMatch > matchesAB = extractCandidates(impA.getProcessor(), impB.getProcessor(), gene, p ); - final List< PointMatch > matchesBA = extractCandidates(impB.getProcessor(), impA.getProcessor(), gene, p ); - - //System.out.println( gene + " = " + matchesAB.size() ); - //System.out.println( gene + " = " + matchesBA.size() ); - - if ( matchesAB.size() == 0 && matchesBA.size() == 0 ) - continue; - - final List< PointMatch > candidatesTmp = new ArrayList<>(); - - if ( matchesBA.size() > matchesAB.size() ) - PointMatch.flip( matchesBA, candidatesTmp ); - else - candidatesTmp.addAll( matchesAB ); - - //final List< PointMatch > inliersTmp = consensus( candidatesTmp, new RigidModel2D(), minNumInliersPerGene, maxEpsilon*scale ); - //System.out.println( "remaining points" ); - - //for ( final PointMatch pm:inliersTmp ) - // System.out.println( pm.getWeight() ); +import java.io.File; +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +public class PairwiseSIFT { + + public static class SIFTParam implements Serializable { + public SIFTParam() { + this.sift.fdSize = 8; + this.sift.fdBins = 8; + this.sift.steps = 10; + this.rod = 0.90f; + this.sift.minOctaveSize = 128; + } + + public SIFTParam copy() { + SIFTParam params = new SIFTParam(this.sift.fdSize, this.sift.fdBins, this.sift.steps, this.rod, this.sift.minOctaveSize); + return params; + } + + public SIFTParam(final int fdSize, final int fdBins, final int steps, final float rod, final int minOctaveSize) { + this.sift.fdSize = fdSize; + this.sift.fdBins = fdBins; + this.sift.steps = steps; + this.rod = rod; + this.sift.minOctaveSize = minOctaveSize; + } + + final public FloatArray2DSIFT.Param sift = new FloatArray2DSIFT.Param(); + + /** + * Closest/next closest neighbour distance ratio + */ + public float rod; + } + + static public void matchFeatures( + final Collection fs1, + final Collection fs2, + final List matches, + final float rod) { + for (final Feature f1 : fs1) { + Feature best = null; + double best_d = Double.MAX_VALUE; + double second_best_d = Double.MAX_VALUE; + + for (final Feature f2 : fs2) { + final double d = f1.descriptorDistance(f2); + if (d < best_d) { + second_best_d = best_d; + best_d = d; + best = f2; + } else if (d < second_best_d) + second_best_d = d; + } + + if (best != null && second_best_d < Double.MAX_VALUE && best_d / second_best_d < rod) + matches.add( + new PointMatch( + new Point( + new double[]{f1.location[0], f1.location[1]}), + new Point( + new double[]{best.location[0], best.location[1]}), + best_d / second_best_d)); + } + + // now remove ambiguous matches + for (int i = 0; i < matches.size(); ) { + boolean amb = false; + final PointMatch m = matches.get(i); + final double[] m_p2 = m.getP2().getL(); + for (int j = i + 1; j < matches.size(); ) { + final PointMatch n = matches.get(j); + final double[] n_p2 = n.getP2().getL(); + if (m_p2[0] == n_p2[0] && m_p2[1] == n_p2[1]) { + amb = true; + matches.remove(j); + } else ++j; + } + if (amb) + matches.remove(i); + else ++i; + } + } + + public static List extractCandidates(final ImageProcessor ip1, final ImageProcessor ip2, final String gene, final SIFTParam p) { + final List fs1 = new ArrayList(); + final List fs2 = new ArrayList(); + + final FloatArray2DSIFT sift = new FloatArray2DSIFT(p.sift); + + final SIFT ijSIFT = new SIFT(sift); + ijSIFT.extractFeatures(ip1, fs1); + ijSIFT.extractFeatures(ip2, fs2); + + final List candidates = new ArrayList(); + FeatureTransform.matchFeatures(fs1, fs2, candidates, p.rod); + + final List candidatesST = new ArrayList(); + for (final PointMatch pm : candidates) + candidatesST.add( + new PointMatch( + new PointST(pm.getP1().getL(), gene), + new PointST(pm.getP2().getL(), gene))); + + return candidatesST; + } + + public static ArrayList consensus(final List candidates, final Model model, final int minNumInliers, final double maxEpsilon) { + final ArrayList inliers = new ArrayList(); + + boolean modelFound; + + try { + modelFound = model.filterRansac( + candidates, + inliers, + 10000, + maxEpsilon, + 0.1f, //p.minInlierRatio, + minNumInliers, + 3f); + } catch (final NotEnoughDataPointsException e) { + modelFound = false; + } + + if (modelFound) + PointMatch.apply(inliers, model); + else + inliers.clear(); + + return inliers; + } + + public static void visualizeInliers(final ImagePlus imp1, final ImagePlus imp2, final List inliers) { + if (inliers.size() > 0) { + final ArrayList p1 = new ArrayList(); + final ArrayList p2 = new ArrayList(); + + PointMatch.sourcePoints(inliers, p1); + PointMatch.targetPoints(inliers, p2); + + imp1.setRoi(Util.pointsToPointRoi(p1)); + imp2.setRoi(Util.pointsToPointRoi(p2)); + } + } + + public static & Model, N extends Affine2D & Model, SM extends Supplier & Serializable> void sparkPairwiseSIFT( + final String inputPath, + final String stDataA_name, + final String stDataB_name, + final SM modelPairwise, + final N modelGlobal, + final List genesToTest, + final SIFTParam p, + final double scal, + final double smFactor, + final double mxEpsilon, + final int minNumInliers, + final int minInliersPerGene, + final boolean saveResult, + final boolean visualizeResult, + final JavaSparkContext sc) throws IOException { + + final AffineTransform2D tS = new AffineTransform2D(); + tS.scale(scal); + + final double[] transform = tS.getRowPackedCopy(); + + final String input = inputPath; + final String stDataAname = stDataA_name; + final String stDataBname = stDataB_name; + final double smoothnessFactor = smFactor; + final double scale = scal; + final SIFTParam params = p.copy(); + final int minNumInliersPerGene = minInliersPerGene; + final double maxEpsilon = mxEpsilon; + + final JavaRDD> rdd = sc.parallelize(genesToTest).map((Function>) gene -> { + final List allPerGeneInliers = new ArrayList<>(); + + System.out.println("current gene: " + gene); + System.out.println(stDataAname); + AffineTransform2D transform2D = new AffineTransform2D(); + transform2D.set(transform); + + final File n5File = new File(inputPath); + final N5FSWriter n5 = N5IO.openN5write(n5File); + final STData stDataA = N5IO.readN5(n5, stDataAname); + final STData stDataB = N5IO.readN5(n5, stDataBname); + System.out.println("Got file n5"); + final Interval interval = STDataUtils.getCommonInterval(stDataA, stDataB); + final Interval finalInterval = Intervals.expand(ImgLib2Util.transformInterval(interval, transform2D), 100); + + final RandomAccessibleInterval imgA = AlignTools.display(stDataA, new STDataStatistics(stDataA), gene, finalInterval, transform2D, null, smoothnessFactor); + final RandomAccessibleInterval imgB = AlignTools.display(stDataB, new STDataStatistics(stDataB), gene, finalInterval, transform2D, null, smoothnessFactor); + + final ImagePlus impA = ImageJFunctions.wrapFloat(imgA, new RealFloatConverter<>(), "A_" + gene); + final ImagePlus impB = ImageJFunctions.wrapFloat(imgB, new RealFloatConverter<>(), "B_" + gene); + + final List matchesAB = extractCandidates(impA.getProcessor(), impB.getProcessor(), gene, params); + final List matchesBA = extractCandidates(impB.getProcessor(), impA.getProcessor(), gene, params); + + + System.out.println( gene + " = " + matchesAB.size() ); + System.out.println( gene + " = " + matchesBA.size() ); + + if (matchesAB.size() == 0 && matchesBA.size() == 0) + return allPerGeneInliers; + + final List candidatesTmp = new ArrayList<>(); + + if (matchesBA.size() > matchesAB.size()) + PointMatch.flip(matchesBA, candidatesTmp); + else + candidatesTmp.addAll(matchesAB); + + //final List< PointMatch > inliersTmp = consensus( candidatesTmp, new RigidModel2D(), minNumInliersPerGene, maxEpsilon*scale ); + //System.out.println( "remaining points" ); + + //for ( final PointMatch pm:inliersTmp ) + // System.out.println( pm.getWeight() ); + /* + if ( gene.equals("Ckb")) + { + final List< PointMatch > inliersTmp = consensus( candidatesTmp, new RigidModel2D(), minNumInliersPerGene, maxEpsilon*scale ); + if ( inliersTmp.size() > minNumInliersPerGene ) + { + impA.show();impA.resetDisplayRange(); + impB.show();impB.resetDisplayRange(); + visualizeInliers( impA, impB, inliersTmp ); + } + } */ + + // adjust the locations to the global coordinate system + // and store the gene name it came from + for (final PointMatch pm : candidatesTmp) { + final Point p1 = pm.getP1(); + final Point p2 = pm.getP2(); + + for (int d = 0; d < finalInterval.numDimensions(); ++d) { + p1.getL()[d] = p1.getW()[d] = (p1.getL()[d] + finalInterval.min(d)) / scale; + p2.getL()[d] = p2.getW()[d] = (p2.getL()[d] + finalInterval.min(d)) / scale; + } + } + + // prefilter the candidates + //new InterpolatedAffineModel2D<>( new AffineModel2D(), new RigidModel2D(), 0.1 );//new RigidModel2D(); + final List inliers = consensus(candidatesTmp, modelPairwise.get(), minNumInliersPerGene, maxEpsilon); + + // reset world coordinates & compute error + double error = Double.NaN, maxError = Double.NaN, minError = Double.NaN; + if (inliers.size() > 0) { + error = 0; + minError = Double.MAX_VALUE; + maxError = -Double.MAX_VALUE; + + for (final PointMatch pm : inliers) { + final double dist = Point.distance(pm.getP1(), pm.getP2()); + error += dist; + maxError = Math.max(maxError, dist); + minError = Math.min(minError, dist); + } + + error /= (double) inliers.size(); + } + + for (final PointMatch pm : candidatesTmp) { + final Point p1 = pm.getP1(); + final Point p2 = pm.getP2(); + + for (int d = 0; d < finalInterval.numDimensions(); ++d) { + p1.getW()[d] = p1.getL()[d]; + p2.getW()[d] = p2.getL()[d]; + } + } + + if (inliers.size() > 0) { + allPerGeneInliers.addAll(inliers); + System.out.println(stDataAname + "-" + stDataBname + ": " + inliers.size() + "/" + candidatesTmp.size() + ", " + minError + "/" + error + "/" + maxError + ", " + ((PointST) inliers.get(0).getP1()).getGene()); + //System.out.println( ki + "-" + kj + ": " + inliers.size() + "/" + candidatesTmp.size() + ", " + ((PointST)inliers.get( 0 ).getP1()).getGene() + ", " ); + //GlobalOpt.visualizePair(stDataA, stDataB, new AffineTransform2D(), GlobalOpt.modelToAffineTransform2D( model ).inverse() ).setTitle( gene +"_" + inliers.size() );; + } + + + return allPerGeneInliers; + }); + + System.out.println("Got all genes"); + + final List allCandidates = rdd.collect().stream().flatMap(List::stream).collect(Collectors.toList()); + + + //final InterpolatedAffineModel2D model = new InterpolatedAffineModel2D<>( new AffineModel2D(), new RigidModel2D(), 0.1 );//new RigidModel2D(); +// final RigidModel2D modelGlobal = Serializers.deserializeRigidModel2D(serializedModelGlobal); + final ArrayList inliers = consensus(allCandidates, modelGlobal, minNumInliers, maxEpsilon); + + // the model that maps J to I + System.out.println(stDataAname + "\t" + stDataBname + "\t" + inliers.size() + "\t" + allCandidates.size() + "\t" + AlignTools.modelToAffineTransform2D(modelGlobal).inverse()); + + if (saveResult && inliers.size() >= minNumInliers) { + final HashSet genes = new HashSet(); + for (final PointMatch pm : inliers) + genes.add(((PointST) pm.getP1()).getGene()); + final File n5File = new File(input); + final N5FSWriter n5 = N5IO.openN5write(n5File); + final String pairwiseGroupName = n5.groupPath("/", "matches", stDataAname + "-" + stDataBname); + if (n5.exists(pairwiseGroupName)) + n5.remove(pairwiseGroupName); + n5.createDataset( + pairwiseGroupName, + new long[]{1}, + new int[]{1}, + DataType.OBJECT, + new GzipCompression()); + + n5.setAttribute(pairwiseGroupName, "stDataAname", stDataAname); + n5.setAttribute(pairwiseGroupName, "stDataBname", stDataBname); + n5.setAttribute(pairwiseGroupName, "inliers", inliers.size()); + n5.setAttribute(pairwiseGroupName, "candidates", allCandidates.size()); + n5.setAttribute(pairwiseGroupName, "genes", genes); + + n5.writeSerializedBlock( + inliers, + pairwiseGroupName, + n5.getDatasetAttributes(pairwiseGroupName), + new long[]{0}); + } + + if (visualizeResult && inliers.size() >= minNumInliers) { + final File n5File = new File(input); + final N5FSWriter n5 = N5IO.openN5write(n5File); + final STData stDataA = N5IO.readN5(n5, stDataAname); + final STData stDataB = N5IO.readN5(n5, stDataBname); + ImagePlus rendered = AlignTools.visualizePair( + stDataA, stDataB, + new AffineTransform2D(), + AlignTools.modelToAffineTransform2D(modelGlobal).inverse(), + smoothnessFactor); + rendered.setTitle(stDataAname + "-" + stDataBname + "-inliers-" + inliers.size() + " (" + AlignTools.defaultGene + ")"); + } + + // compute errors + // reset world coordinates & compute error + double error = Double.NaN, maxError = Double.NaN, minError = Double.NaN; + if (inliers.size() > 0) { + error = 0; + minError = Double.MAX_VALUE; + maxError = -Double.MAX_VALUE; + + for (final PointMatch pm : inliers) { + pm.apply(modelGlobal); + final double dist = Point.distance(pm.getP1(), pm.getP2()); + error += dist; + maxError = Math.max(maxError, dist); + minError = Math.min(minError, dist); + } + + error /= (double) inliers.size(); + } + + System.out.println("errors: " + minError + "/" + error + "/" + maxError); + } + + public static & Model, N extends Affine2D & Model> void pairwiseSIFT( + final STData stDataA, + final String stDataAname, + final STData stDataB, + final String stDataBname, + final M modelPairwise, + final N modelGlobal, + final File n5File, + final List genesToTest, + final SIFTParam p, + final double scale, + final double smoothnessFactor, + final double maxEpsilon, + final int minNumInliers, + final int minNumInliersPerGene, + final boolean saveResult, + final boolean visualizeResult, + final int numThreads) throws IOException { + final AffineTransform2D tS = new AffineTransform2D(); + tS.scale(scale); + + final Interval interval = STDataUtils.getCommonInterval(stDataA, stDataB); + final Interval finalInterval = Intervals.expand(ImgLib2Util.transformInterval(interval, tS), 100); + + final List allCandidates = new ArrayList<>(); + + final List>> tasks = new ArrayList<>(); + final AtomicInteger nextGene = new AtomicInteger(); + + for (int threadNum = 0; threadNum < numThreads; ++threadNum) { + tasks.add(() -> + { + final List allPerGeneInliers = new ArrayList<>(); + + for (int g = nextGene.getAndIncrement(); g < genesToTest.size(); g = nextGene.getAndIncrement()) { + final String gene = genesToTest.get(g); + //System.out.println( "current gene: " + gene ); + + final RandomAccessibleInterval imgA = AlignTools.display(stDataA, new STDataStatistics(stDataA), gene, finalInterval, tS, null, smoothnessFactor); + final RandomAccessibleInterval imgB = AlignTools.display(stDataB, new STDataStatistics(stDataB), gene, finalInterval, tS, null, smoothnessFactor); + + final ImagePlus impA = ImageJFunctions.wrapFloat(imgA, new RealFloatConverter<>(), "A_" + gene); + final ImagePlus impB = ImageJFunctions.wrapFloat(imgB, new RealFloatConverter<>(), "B_" + gene); + + final List matchesAB = extractCandidates(impA.getProcessor(), impB.getProcessor(), gene, p); + final List matchesBA = extractCandidates(impB.getProcessor(), impA.getProcessor(), gene, p); + + //System.out.println( gene + " = " + matchesAB.size() ); + //System.out.println( gene + " = " + matchesBA.size() ); + + if (matchesAB.size() == 0 && matchesBA.size() == 0) + continue; + + final List candidatesTmp = new ArrayList<>(); + + if (matchesBA.size() > matchesAB.size()) + PointMatch.flip(matchesBA, candidatesTmp); + else + candidatesTmp.addAll(matchesAB); + + //final List< PointMatch > inliersTmp = consensus( candidatesTmp, new RigidModel2D(), minNumInliersPerGene, maxEpsilon*scale ); + //System.out.println( "remaining points" ); + + //for ( final PointMatch pm:inliersTmp ) + // System.out.println( pm.getWeight() ); /* if ( gene.equals("Ckb")) { @@ -286,160 +489,145 @@ public static < M extends Affine2D & Model, N extends Affine2D & Model< } } */ - // adjust the locations to the global coordinate system - // and store the gene name it came from - for ( final PointMatch pm : candidatesTmp ) - { - final Point p1 = pm.getP1(); - final Point p2 = pm.getP2(); - - for ( int d = 0; d < finalInterval.numDimensions(); ++d ) - { - p1.getL()[ d ] = p1.getW()[ d ] = ( p1.getL()[ d ] + finalInterval.min( d ) ) / scale; - p2.getL()[ d ] = p2.getW()[ d ] = ( p2.getL()[ d ] + finalInterval.min( d ) ) / scale; - } - } - - // prefilter the candidates - final M model = modelPairwise.copy();//new InterpolatedAffineModel2D<>( new AffineModel2D(), new RigidModel2D(), 0.1 );//new RigidModel2D(); - final List< PointMatch > inliers = consensus( candidatesTmp, model, minNumInliersPerGene, maxEpsilon ); - - // reset world coordinates & compute error - double error = Double.NaN, maxError = Double.NaN, minError = Double.NaN; - if ( inliers.size() > 0 ) - { - error = 0; - minError = Double.MAX_VALUE; - maxError = -Double.MAX_VALUE; - - for ( final PointMatch pm : inliers ) - { - final double dist = Point.distance(pm.getP1(), pm.getP2()); - error += dist; - maxError = Math.max( maxError, dist ); - minError = Math.min( minError, dist ); - } - - error /= (double)inliers.size(); - } - - for ( final PointMatch pm : candidatesTmp ) - { - final Point p1 = pm.getP1(); - final Point p2 = pm.getP2(); - - for ( int d = 0; d < finalInterval.numDimensions(); ++d ) - { - p1.getW()[ d ] = p1.getL()[ d ]; - p2.getW()[ d ] = p2.getL()[ d ]; - } - } - - if ( inliers.size() > 0 ) - { - allPerGeneInliers.addAll( inliers ); - System.out.println( stDataAname + "-" + stDataBname + ": " + inliers.size() + "/" + candidatesTmp.size() + ", " + minError + "/" + error + "/" + maxError + ", " + ((PointST)inliers.get( 0 ).getP1()).getGene() ); - //System.out.println( ki + "-" + kj + ": " + inliers.size() + "/" + candidatesTmp.size() + ", " + ((PointST)inliers.get( 0 ).getP1()).getGene() + ", " ); - //GlobalOpt.visualizePair(stDataA, stDataB, new AffineTransform2D(), GlobalOpt.modelToAffineTransform2D( model ).inverse() ).setTitle( gene +"_" + inliers.size() );; - } - } - - return allPerGeneInliers; - }); - } - - final ExecutorService service = Threads.createFixedExecutorService( numThreads ); - - try - { - final List< Future< List< PointMatch > > > futures = service.invokeAll( tasks ); - for ( final Future< List< PointMatch > > future : futures ) - allCandidates.addAll( future.get() ); - } - catch ( final InterruptedException | ExecutionException e ) - { - e.printStackTrace(); - throw new RuntimeException( e ); - } - - service.shutdown(); - - //final InterpolatedAffineModel2D model = new InterpolatedAffineModel2D<>( new AffineModel2D(), new RigidModel2D(), 0.1 );//new RigidModel2D(); - //final RigidModel2D model = new RigidModel2D(); - final ArrayList< PointMatch > inliers = consensus( allCandidates, modelGlobal, minNumInliers, maxEpsilon ); - - // the model that maps J to I - System.out.println( stDataAname + "\t" + stDataBname + "\t" + inliers.size() + "\t" + allCandidates.size() + "\t" + AlignTools.modelToAffineTransform2D( modelGlobal ).inverse() ); - - if ( saveResult && inliers.size() >= minNumInliers ) - { - final HashSet< String > genes = new HashSet< String >(); - for ( final PointMatch pm : inliers ) - genes.add( ((PointST)pm.getP1()).getGene() ); - - final N5FSWriter n5 = N5IO.openN5write( n5File ); - final String pairwiseGroupName = n5.groupPath( "/", "matches", stDataAname + "-" + stDataBname ); - if (n5.exists(pairwiseGroupName)) - n5.remove( pairwiseGroupName ); - n5.createDataset( - pairwiseGroupName, - new long[] {1}, - new int[] {1}, - DataType.OBJECT, - new GzipCompression()); - - n5.setAttribute( pairwiseGroupName, "stDataAname", stDataAname ); - n5.setAttribute( pairwiseGroupName, "stDataBname", stDataBname ); - n5.setAttribute( pairwiseGroupName, "inliers", inliers.size() ); - n5.setAttribute( pairwiseGroupName, "candidates", allCandidates.size() ); - n5.setAttribute( pairwiseGroupName, "genes", genes ); - - n5.writeSerializedBlock( - inliers, - pairwiseGroupName, - n5.getDatasetAttributes( pairwiseGroupName ), - new long[]{0}); - } - - if ( visualizeResult && inliers.size() >= minNumInliers ) - { - ImagePlus rendered = AlignTools.visualizePair( - stDataA, stDataB, - new AffineTransform2D(), - AlignTools.modelToAffineTransform2D( modelGlobal ).inverse(), - smoothnessFactor ); - rendered.setTitle( stDataAname + "-" + stDataBname + "-inliers-" + inliers.size() + " (" + AlignTools.defaultGene + ")" ); - } - - // compute errors - // reset world coordinates & compute error - double error = Double.NaN, maxError = Double.NaN, minError = Double.NaN; - if ( inliers.size() > 0 ) - { - error = 0; - minError = Double.MAX_VALUE; - maxError = -Double.MAX_VALUE; - - for ( final PointMatch pm : inliers ) - { - pm.apply( modelGlobal ); - final double dist = Point.distance(pm.getP1(), pm.getP2()); - error += dist; - maxError = Math.max( maxError, dist ); - minError = Math.min( minError, dist ); - } - - error /= (double)inliers.size(); - } - - System.out.println( "errors: " + minError + "/" + error + "/" + maxError ); - } - - public static void main( String[] args ) throws IOException - { - final String path = Path.getPath(); - - new ImageJ(); + // adjust the locations to the global coordinate system + // and store the gene name it came from + for (final PointMatch pm : candidatesTmp) { + final Point p1 = pm.getP1(); + final Point p2 = pm.getP2(); + + for (int d = 0; d < finalInterval.numDimensions(); ++d) { + p1.getL()[d] = p1.getW()[d] = (p1.getL()[d] + finalInterval.min(d)) / scale; + p2.getL()[d] = p2.getW()[d] = (p2.getL()[d] + finalInterval.min(d)) / scale; + } + } + + // prefilter the candidates + final M model = modelPairwise.copy();//new InterpolatedAffineModel2D<>( new AffineModel2D(), new RigidModel2D(), 0.1 );//new RigidModel2D(); + final List inliers = consensus(candidatesTmp, model, minNumInliersPerGene, maxEpsilon); + + // reset world coordinates & compute error + double error = Double.NaN, maxError = Double.NaN, minError = Double.NaN; + if (inliers.size() > 0) { + error = 0; + minError = Double.MAX_VALUE; + maxError = -Double.MAX_VALUE; + + for (final PointMatch pm : inliers) { + final double dist = Point.distance(pm.getP1(), pm.getP2()); + error += dist; + maxError = Math.max(maxError, dist); + minError = Math.min(minError, dist); + } + + error /= (double) inliers.size(); + } + + for (final PointMatch pm : candidatesTmp) { + final Point p1 = pm.getP1(); + final Point p2 = pm.getP2(); + + for (int d = 0; d < finalInterval.numDimensions(); ++d) { + p1.getW()[d] = p1.getL()[d]; + p2.getW()[d] = p2.getL()[d]; + } + } + + if (inliers.size() > 0) { + allPerGeneInliers.addAll(inliers); + System.out.println(stDataAname + "-" + stDataBname + ": " + inliers.size() + "/" + candidatesTmp.size() + ", " + minError + "/" + error + "/" + maxError + ", " + ((PointST) inliers.get(0).getP1()).getGene()); + //System.out.println( ki + "-" + kj + ": " + inliers.size() + "/" + candidatesTmp.size() + ", " + ((PointST)inliers.get( 0 ).getP1()).getGene() + ", " ); + //GlobalOpt.visualizePair(stDataA, stDataB, new AffineTransform2D(), GlobalOpt.modelToAffineTransform2D( model ).inverse() ).setTitle( gene +"_" + inliers.size() );; + } + } + + return allPerGeneInliers; + }); + } + + final ExecutorService service = Threads.createFixedExecutorService(numThreads); + + try { + final List>> futures = service.invokeAll(tasks); + for (final Future> future : futures) + allCandidates.addAll(future.get()); + } catch (final InterruptedException | ExecutionException e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + + service.shutdown(); + + //final InterpolatedAffineModel2D model = new InterpolatedAffineModel2D<>( new AffineModel2D(), new RigidModel2D(), 0.1 );//new RigidModel2D(); + //final RigidModel2D model = new RigidModel2D(); + final ArrayList inliers = consensus(allCandidates, modelGlobal, minNumInliers, maxEpsilon); + + // the model that maps J to I + System.out.println(stDataAname + "\t" + stDataBname + "\t" + inliers.size() + "\t" + allCandidates.size() + "\t" + AlignTools.modelToAffineTransform2D(modelGlobal).inverse()); + + if (saveResult && inliers.size() >= minNumInliers) { + final HashSet genes = new HashSet(); + for (final PointMatch pm : inliers) + genes.add(((PointST) pm.getP1()).getGene()); + + final N5FSWriter n5 = N5IO.openN5write(n5File); + final String pairwiseGroupName = n5.groupPath("/", "matches", stDataAname + "-" + stDataBname); + if (n5.exists(pairwiseGroupName)) + n5.remove(pairwiseGroupName); + n5.createDataset( + pairwiseGroupName, + new long[]{1}, + new int[]{1}, + DataType.OBJECT, + new GzipCompression()); + + n5.setAttribute(pairwiseGroupName, "stDataAname", stDataAname); + n5.setAttribute(pairwiseGroupName, "stDataBname", stDataBname); + n5.setAttribute(pairwiseGroupName, "inliers", inliers.size()); + n5.setAttribute(pairwiseGroupName, "candidates", allCandidates.size()); + n5.setAttribute(pairwiseGroupName, "genes", genes); + + n5.writeSerializedBlock( + inliers, + pairwiseGroupName, + n5.getDatasetAttributes(pairwiseGroupName), + new long[]{0}); + } + + if (visualizeResult && inliers.size() >= minNumInliers) { + ImagePlus rendered = AlignTools.visualizePair( + stDataA, stDataB, + new AffineTransform2D(), + AlignTools.modelToAffineTransform2D(modelGlobal).inverse(), + smoothnessFactor); + rendered.setTitle(stDataAname + "-" + stDataBname + "-inliers-" + inliers.size() + " (" + AlignTools.defaultGene + ")"); + } + + // compute errors + // reset world coordinates & compute error + double error = Double.NaN, maxError = Double.NaN, minError = Double.NaN; + if (inliers.size() > 0) { + error = 0; + minError = Double.MAX_VALUE; + maxError = -Double.MAX_VALUE; + + for (final PointMatch pm : inliers) { + pm.apply(modelGlobal); + final double dist = Point.distance(pm.getP1(), pm.getP2()); + error += dist; + maxError = Math.max(maxError, dist); + minError = Math.min(minError, dist); + } + + error /= (double) inliers.size(); + } + + System.out.println("errors: " + minError + "/" + error + "/" + maxError); + } + + public static void main(String[] args) throws IOException { + final String path = Path.getPath(); + + new ImageJ(); /* ImagePlus imp = new ImagePlus( new File( path + "slide-seq/stack_small_.tif" ).getAbsolutePath() ); @@ -456,61 +644,59 @@ public static void main( String[] args ) throws IOException } */ - //final String[] pucks = new String[] { "Puck_180602_20", "Puck_180602_18", "Puck_180602_17", "Puck_180602_16", "Puck_180602_15", "Puck_180531_23", "Puck_180531_22", "Puck_180531_19", "Puck_180531_18", "Puck_180531_17", "Puck_180531_13", "Puck_180528_22", "Puck_180528_20" }; - //final String[] pucks = new String[] { "Puck_180531_23", "Puck_180531_22" }; - //final String[] pucks = new String[] { "Puck_180602_18", "Puck_180531_18" }; // 1-8 + //final String[] pucks = new String[] { "Puck_180602_20", "Puck_180602_18", "Puck_180602_17", "Puck_180602_16", "Puck_180602_15", "Puck_180531_23", "Puck_180531_22", "Puck_180531_19", "Puck_180531_18", "Puck_180531_17", "Puck_180531_13", "Puck_180528_22", "Puck_180528_20" }; + //final String[] pucks = new String[] { "Puck_180531_23", "Puck_180531_22" }; + //final String[] pucks = new String[] { "Puck_180602_18", "Puck_180531_18" }; // 1-8 - final File n5File = new File( path + "slide-seq-normalized.n5" ); - final N5FSReader n5 = N5IO.openN5( n5File ); - final List< String > pucks = N5IO.listAllDatasets( n5 ); + final File n5File = new File(path + "slide-seq-normalized.n5"); + final N5FSReader n5 = N5IO.openN5(n5File); + final List pucks = N5IO.listAllDatasets(n5); - final ArrayList< STData > puckData = new ArrayList(); - for ( final String puck : pucks ) - puckData.add( N5IO.readN5( n5, puck ) ); + final ArrayList puckData = new ArrayList(); + for (final String puck : pucks) + puckData.add(N5IO.readN5(n5, puck)); - // clear the alignment metadata - final String matchesGroupName = n5.groupPath( "/", "matches" ); - final N5FSWriter n5Writer = N5IO.openN5write( n5File ); - if (n5.exists(matchesGroupName)) - n5Writer.remove( matchesGroupName ); - n5Writer.createGroup( matchesGroupName ); + // clear the alignment metadata + final String matchesGroupName = n5.groupPath("/", "matches"); + final N5FSWriter n5Writer = N5IO.openN5write(n5File); + if (n5.exists(matchesGroupName)) + n5Writer.remove(matchesGroupName); + n5Writer.createGroup(matchesGroupName); - // visualize using the global transform - final double scale = 0.1; - final double maxEpsilon = 300; - final int minNumInliers = 12; - final int minNumInliersPerGene = 10; + // visualize using the global transform + final double scale = 0.1; + final double maxEpsilon = 300; + final int minNumInliers = 12; + final int minNumInliersPerGene = 10; - final double smoothnessFactor = 4.0; + final double smoothnessFactor = 4.0; - final SIFTParam p = new SIFTParam(); - final boolean saveResult = true; - final boolean visualizeResult = true; + final SIFTParam p = new SIFTParam(); + final boolean saveResult = true; + final boolean visualizeResult = true; - // multi-threading - final int numThreads = Threads.numThreads(); + // multi-threading + final int numThreads = Threads.numThreads(); - for ( int i = 0; i < pucks.size() - 1; ++i ) - { - for ( int j = i + 1; j < pucks.size(); ++j ) - { - if ( Math.abs( j - i ) > 2 ) - continue; + for (int i = 0; i < pucks.size() - 1; ++i) { + for (int j = i + 1; j < pucks.size(); ++j) { + if (Math.abs(j - i) > 2) + continue; - //final int ki = i; - //final int kj = j; + //final int ki = i; + //final int kj = j; - final STData stDataA = puckData.get(i); - final STData stDataB = puckData.get(j); + final STData stDataA = puckData.get(i); + final STData stDataB = puckData.get(j); - final String puckA = pucks.get( i ); - final String puckB = pucks.get( j ); + final String puckA = pucks.get(i); + final String puckB = pucks.get(j); - //System.out.println( new Date( System.currentTimeMillis() ) + ": Finding genes" ); + //System.out.println( new Date( System.currentTimeMillis() ) + ": Finding genes" ); - final List< String > genesToTest = Pairwise.genesToTest( stDataA, stDataB, 2000, numThreads ); - //for ( final String gene : genesToTest ) - // System.out.println( gene ); + final List genesToTest = Pairwise.genesToTest(stDataA, stDataB, 2000, numThreads); + //for ( final String gene : genesToTest ) + // System.out.println( gene ); /*final List< String > genesToTest = new ArrayList<>(); genesToTest.add( "Calm1" ); genesToTest.add( "Calm2" ); @@ -522,10 +708,10 @@ public static void main( String[] args ) throws IOException // check out ROD! */ - pairwiseSIFT(stDataA, puckA, stDataB, puckB, new RigidModel2D(), new RigidModel2D(), n5File, genesToTest, p, scale, smoothnessFactor, maxEpsilon, - minNumInliers, minNumInliersPerGene, saveResult, visualizeResult, numThreads); - } - } - System.out.println("done."); - } + pairwiseSIFT(stDataA, puckA, stDataB, puckB, new RigidModel2D(), new RigidModel2D(), n5File, genesToTest, p, scale, smoothnessFactor, maxEpsilon, + minNumInliers, minNumInliersPerGene, saveResult, visualizeResult, numThreads); + } + } + System.out.println("done."); + } } diff --git a/src/main/java/cmd/PairwiseSectionAligner.java b/src/main/java/cmd/PairwiseSectionAligner.java index 1e46cb4..034783e 100644 --- a/src/main/java/cmd/PairwiseSectionAligner.java +++ b/src/main/java/cmd/PairwiseSectionAligner.java @@ -1,12 +1,16 @@ package cmd; import java.io.File; +import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.concurrent.Callable; +import java.util.function.Supplier; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; import org.janelia.saalfeldlab.n5.N5FSWriter; import org.joml.Math; @@ -71,6 +75,9 @@ public class PairwiseSectionAligner implements Callable { @Option(names = {"--hidePairwiseRendering"}, required = false, description = "do not show pairwise renderings that apply the 2D rigid models (default: false - showing them)") private boolean hidePairwiseRendering = false; + @Option(names = {"--spark"}, required = false, description = "use spark") + private boolean sparkProcessing = true; + //-i /Users/spreibi/Documents/BIMSB/Publications/imglib2-st/slide-seq-test.n5 -d 'Puck_180602_20,Puck_180602_18,Puck_180602_17,Puck_180602_16,Puck_180602_15,Puck_180531_23,Puck_180531_22,Puck_180531_19,Puck_180531_18,Puck_180531_17,Puck_180531_13,Puck_180528_22,Puck_180528_20' -n 100 --overwrite @Override @@ -124,6 +131,13 @@ public Void call() throws Exception { stdata.add( N5IO.readN5( n5, dataset ) ); } } + JavaSparkContext sc = null; + + if(sparkProcessing){ + final SparkConf conf = new SparkConf().setAppName("PairwiseSectionAligner").setMaster("local"); + sc = new JavaSparkContext(conf); + sc.setLogLevel("ERROR"); + } // iterate once just to be sure we will not crash half way through because something exists for ( int i = 0; i < stdata.size() - 1; ++i ) @@ -273,14 +287,23 @@ public Void call() throws Exception { // hard case: -i /Users/spreibi/Documents/BIMSB/Publications/imglib2-st/slide-seq-test.n5 -d1 Puck_180602_15 -d2 Puck_180602_16 -n 30 // even harder: -i /Users/spreibi/Documents/BIMSB/Publications/imglib2-st/slide-seq-test.n5 -d1 Puck_180602_20 -d2 Puck_180602_18 -n 100 --overwrite - PairwiseSIFT.pairwiseSIFT( - stData1, dataset1, stData2, dataset2, - new RigidModel2D(), new RigidModel2D(), - n5File, new ArrayList<>( genesToTest ), - p, scale, smoothnessFactor, maxEpsilon, - minNumInliers, minNumInliersGene, - saveResult, visualizeResult, Threads.numThreads() ); - + if (sparkProcessing){ + PairwiseSIFT.sparkPairwiseSIFT( + input, dataset1, dataset2, + (Supplier & Serializable)(() -> new RigidModel2D()), new RigidModel2D(), + new ArrayList<>(genesToTest), + p, scale, smoothnessFactor, maxEpsilon, + minNumInliers, minNumInliersGene, + saveResult, visualizeResult, sc); + }else { + PairwiseSIFT.pairwiseSIFT( + stData1, dataset1, stData2, dataset2, + new RigidModel2D(), new RigidModel2D(), + n5File, new ArrayList<>(genesToTest), + p, scale, smoothnessFactor, maxEpsilon, + minNumInliers, minNumInliersGene, + saveResult, visualizeResult, Threads.numThreads()); + } System.out.println( "Took " + (System.currentTimeMillis() - time)/1000 + " sec." ); } diff --git a/src/main/java/util/spark/Serializers.java b/src/main/java/util/spark/Serializers.java new file mode 100644 index 0000000..16a13f8 --- /dev/null +++ b/src/main/java/util/spark/Serializers.java @@ -0,0 +1,18 @@ +package util.spark; + +import mpicbg.models.RigidModel2D; + +public class Serializers { + + public static double[] serializeRigidModel2D(RigidModel2D model) { + double[] result = new double[6]; + model.toArray(result); + return result; + } + + public static RigidModel2D deserializeRigidModel2D(double[] data) { + RigidModel2D result = new RigidModel2D(); + result.set(data[0], data[1], data[4], data[5]); + return result; + } +}