Skip to content

Commit 97a461f

Browse files
committed
WIP
1 parent d76e8ee commit 97a461f

File tree

10 files changed

+424
-995
lines changed

10 files changed

+424
-995
lines changed

src/main/java/net/imglib2/algorithm/gradient/PartialDerivative.java

+36-59
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,18 @@
3434

3535
package net.imglib2.algorithm.gradient;
3636

37-
import java.util.ArrayList;
38-
import java.util.List;
39-
import java.util.concurrent.Callable;
4037
import java.util.concurrent.ExecutionException;
4138
import java.util.concurrent.ExecutorService;
42-
import java.util.concurrent.Future;
4339

4440
import net.imglib2.Cursor;
45-
import net.imglib2.FinalInterval;
4641
import net.imglib2.RandomAccessible;
4742
import net.imglib2.RandomAccessibleInterval;
4843
import net.imglib2.loops.LoopBuilder;
44+
import net.imglib2.parallel.TaskExecutor;
45+
import net.imglib2.parallel.Parallelization;
46+
import net.imglib2.parallel.TaskExecutors;
4947
import net.imglib2.type.numeric.NumericType;
5048
import net.imglib2.util.Intervals;
51-
import net.imglib2.view.IntervalView;
5249
import net.imglib2.view.Views;
5350

5451
/**
@@ -98,7 +95,7 @@ public static < T extends NumericType< T > > void gradientCentralDifference2( fi
9895
* @param source
9996
* source image, has to provide valid data in the interval of the
10097
* gradient image plus a one pixel border in dimension.
101-
* @param gradient
98+
* @param result
10299
* output image
103100
* @param dimension
104101
* along which dimension the partial derivatives are computed
@@ -110,57 +107,42 @@ public static < T extends NumericType< T > > void gradientCentralDifference2( fi
110107
*/
111108
public static < T extends NumericType< T > > void gradientCentralDifferenceParallel(
112109
final RandomAccessible< T > source,
113-
final RandomAccessibleInterval< T > gradient,
110+
final RandomAccessibleInterval< T > result,
114111
final int dimension,
115112
final int nTasks,
116113
final ExecutorService es ) throws InterruptedException, ExecutionException
117114
{
118-
final int nDim = source.numDimensions();
119-
if ( nDim < 2 )
120-
{
121-
gradientCentralDifference( source, gradient, dimension );
122-
return;
123-
}
124-
125-
long dimensionMax = Long.MIN_VALUE;
126-
int dimensionArgMax = -1;
127-
128-
for ( int d = 0; d < nDim; ++d )
129-
{
130-
final long size = gradient.dimension( d );
131-
if ( d != dimension && size > dimensionMax )
132-
{
133-
dimensionMax = size;
134-
dimensionArgMax = d;
135-
}
136-
}
137-
138-
final long stepSize = Math.max( dimensionMax / nTasks, 1 );
139-
final long stepSizeMinusOne = stepSize - 1;
140-
final long min = gradient.min( dimensionArgMax );
141-
final long max = gradient.max( dimensionArgMax );
142-
143-
final ArrayList< Callable< Void > > tasks = new ArrayList<>();
144-
for ( long currentMin = min, minZeroBase = 0; minZeroBase < dimensionMax; currentMin += stepSize, minZeroBase += stepSize )
145-
{
146-
final long currentMax = Math.min( currentMin + stepSizeMinusOne, max );
147-
final long[] mins = new long[ nDim ];
148-
final long[] maxs = new long[ nDim ];
149-
gradient.min( mins );
150-
gradient.max( maxs );
151-
mins[ dimensionArgMax ] = currentMin;
152-
maxs[ dimensionArgMax ] = currentMax;
153-
final IntervalView< T > currentInterval = Views.interval( gradient, new FinalInterval( mins, maxs ) );
154-
tasks.add( () -> {
155-
gradientCentralDifference( source, currentInterval, dimension );
156-
return null;
157-
} );
158-
}
115+
TaskExecutor taskExecutor = TaskExecutors.forExecutorServiceAndNumTasks( es, nTasks );
116+
Parallelization.runWithExecutor( taskExecutor, () -> {
117+
gradientCentralDerivativeParallel( source, result, dimension );
118+
} );
119+
}
159120

160-
final List< Future< Void > > futures = es.invokeAll( tasks );
121+
/**
122+
* Compute the partial derivative (central difference approximation) of source
123+
* in a particular dimension:
124+
* {@code d_f( x ) = ( f( x + e ) - f( x - e ) ) / 2},
125+
* where {@code e} is the unit vector along that dimension.
126+
*
127+
* @param source
128+
* source image, has to provide valid data in the interval of the
129+
* gradient image plus a one pixel border in dimension.
130+
* @param result
131+
* output image
132+
* @param dimension
133+
* along which dimension the partial derivatives are computed
134+
*/
135+
private static <T extends NumericType< T >> void gradientCentralDerivativeParallel( RandomAccessible<T> source,
136+
RandomAccessibleInterval<T> result, int dimension )
137+
{
138+
final RandomAccessibleInterval<T> back = Views.interval( source, Intervals.translate( result, -1, dimension ) );
139+
final RandomAccessibleInterval<T> front = Views.interval( source, Intervals.translate( result, 1, dimension ) );
161140

162-
for ( final Future< Void > f : futures )
163-
f.get();
141+
LoopBuilder.setImages( result, back, front ).multiThreaded().forEachPixel( ( r, b, f ) -> {
142+
r.set( f );
143+
r.sub( b );
144+
r.mul( 0.5 );
145+
} );
164146
}
165147

166148
// fast version
@@ -181,13 +163,8 @@ public static < T extends NumericType< T > > void gradientCentralDifferenceParal
181163
public static < T extends NumericType< T > > void gradientCentralDifference( final RandomAccessible< T > source,
182164
final RandomAccessibleInterval< T > result, final int dimension )
183165
{
184-
final RandomAccessibleInterval< T > back = Views.interval( source, Intervals.translate( result, -1, dimension ) );
185-
final RandomAccessibleInterval< T > front = Views.interval( source, Intervals.translate( result, 1, dimension ) );
186-
187-
LoopBuilder.setImages( result, back, front ).forEachPixel( ( r, b, f ) -> {
188-
r.set( f );
189-
r.sub( b );
190-
r.mul( 0.5 );
166+
Parallelization.runSingleThreaded( () -> {
167+
gradientCentralDerivativeParallel( source, result, dimension );
191168
} );
192169
}
193170

src/main/java/net/imglib2/algorithm/localextrema/LocalExtrema.java

+38-60
Original file line numberDiff line numberDiff line change
@@ -33,33 +33,35 @@
3333
*/
3434
package net.imglib2.algorithm.localextrema;
3535

36-
import java.util.ArrayList;
37-
import java.util.Arrays;
38-
import java.util.List;
39-
import java.util.concurrent.Callable;
40-
import java.util.concurrent.ExecutionException;
41-
import java.util.concurrent.ExecutorService;
42-
import java.util.concurrent.Future;
43-
import java.util.stream.IntStream;
44-
import java.util.stream.LongStream;
45-
46-
import net.imglib2.Cursor;
47-
import net.imglib2.FinalInterval;
4836
import net.imglib2.Interval;
4937
import net.imglib2.Localizable;
5038
import net.imglib2.Point;
39+
import net.imglib2.RandomAccess;
5140
import net.imglib2.RandomAccessible;
5241
import net.imglib2.RandomAccessibleInterval;
5342
import net.imglib2.Sampler;
5443
import net.imglib2.algorithm.neighborhood.Neighborhood;
5544
import net.imglib2.algorithm.neighborhood.RectangleShape;
5645
import net.imglib2.algorithm.neighborhood.Shape;
46+
import net.imglib2.converter.readwrite.WriteConvertedRandomAccessible;
47+
import net.imglib2.loops.LoopBuilder;
48+
import net.imglib2.parallel.Parallelization;
49+
import net.imglib2.parallel.TaskExecutor;
50+
import net.imglib2.parallel.TaskExecutors;
5751
import net.imglib2.util.ConstantUtils;
58-
import net.imglib2.util.Intervals;
5952
import net.imglib2.util.ValuePair;
6053
import net.imglib2.view.IntervalView;
6154
import net.imglib2.view.Views;
6255

56+
import java.util.ArrayList;
57+
import java.util.Arrays;
58+
import java.util.Collection;
59+
import java.util.List;
60+
import java.util.concurrent.ExecutionException;
61+
import java.util.concurrent.ExecutorService;
62+
import java.util.stream.IntStream;
63+
import java.util.stream.LongStream;
64+
6365
/**
6466
* Provides {@link #findLocalExtrema} to find pixels that are extrema in their
6567
* local neighborhood.
@@ -320,38 +322,8 @@ public static < P, T > List< P > findLocalExtrema(
320322
final int numTasks,
321323
final int splitDim ) throws InterruptedException, ExecutionException
322324
{
323-
324-
final long[] min = Intervals.minAsLongArray( interval );
325-
final long[] max = Intervals.maxAsLongArray( interval );
326-
327-
final long splitDimSize = interval.dimension( splitDim );
328-
final long splitDimMax = max[ splitDim ];
329-
final long splitDimMin = min[ splitDim ];
330-
final long taskSize = Math.max( splitDimSize / numTasks, 1 );
331-
332-
final ArrayList< Callable< List< P > > > tasks = new ArrayList<>();
333-
334-
for ( long start = splitDimMin, stop = splitDimMin + taskSize - 1; start <= splitDimMax; start += taskSize, stop += taskSize )
335-
{
336-
final long s = start;
337-
// need max here instead of dimension for constructor of
338-
// FinalInterval
339-
final long S = Math.min( stop, splitDimMax );
340-
tasks.add( () -> {
341-
final long[] localMin = min.clone();
342-
final long[] localMax = max.clone();
343-
localMin[ splitDim ] = s;
344-
localMax[ splitDim ] = S;
345-
return findLocalExtrema( source, new FinalInterval( localMin, localMax ), localNeighborhoodCheck, shape );
346-
} );
347-
}
348-
349-
final ArrayList< P > extrema = new ArrayList<>();
350-
final List< Future< List< P > > > futures = service.invokeAll( tasks );
351-
for ( final Future< List< P > > f : futures )
352-
extrema.addAll( f.get() );
353-
return extrema;
354-
325+
TaskExecutor taskExecutor = TaskExecutors.forExecutorServiceAndNumTasks( service, numTasks );
326+
return Parallelization.runWithExecutor( taskExecutor, () -> findLocalExtrema( source, interval, localNeighborhoodCheck, shape ) );
355327
}
356328

357329
/**
@@ -470,22 +442,28 @@ public static < P, T > List< P > findLocalExtrema(
470442
final LocalNeighborhoodCheck< P, T > localNeighborhoodCheck,
471443
final Shape shape )
472444
{
445+
WriteConvertedRandomAccessible< T, RandomAccess< T > > randomAccessible = new WriteConvertedRandomAccessible<>( source, sampler -> (RandomAccess< T >) sampler );
446+
RandomAccessibleInterval< RandomAccess< T > > centers = Views.interval( randomAccessible, interval);
447+
RandomAccessibleInterval< Neighborhood< T > > neighborhoods = Views.interval( shape.neighborhoodsRandomAccessible( source ), interval );
448+
List< List< P > > extremas = LoopBuilder.setImages( centers, neighborhoods ).multiThreaded().forEachChunk( chunk -> {
449+
List< P > extrema = new ArrayList<>();
450+
chunk.forEachPixel( ( center, neighborhood ) -> {
451+
P p = localNeighborhoodCheck.check( center, neighborhood );
452+
if ( p != null )
453+
extrema.add( p );
454+
} );
455+
return extrema;
456+
} );
457+
return concatenate( extremas );
458+
}
473459

474-
final IntervalView< T > sourceInterval = Views.interval( source, interval );
475-
476-
final ArrayList< P > extrema = new ArrayList<>();
477-
478-
final Cursor< T > center = Views.flatIterable( sourceInterval ).cursor();
479-
for ( final Neighborhood< T > neighborhood : shape.neighborhoods( sourceInterval ) )
480-
{
481-
center.fwd();
482-
final P p = localNeighborhoodCheck.check( center, neighborhood );
483-
if ( p != null )
484-
extrema.add( p );
485-
}
486-
487-
return extrema;
488-
460+
private static < P > List<P> concatenate( Collection<List<P>> lists )
461+
{
462+
int size = lists.stream().mapToInt( List::size ).sum();
463+
List< P > result = new ArrayList<>( size );
464+
for ( List< P > list : lists )
465+
result.addAll( list );
466+
return result;
489467
}
490468

491469
/**

0 commit comments

Comments
 (0)