Skip to content

Commit d76e8ee

Browse files
committed
Change TensorEigenValues to use multi-threaded LoopBuilder.
Simplify the code by using multi-threaded LoopBuilder. Many existing methods become multi-threaded by default.
1 parent 02028f4 commit d76e8ee

File tree

1 file changed

+17
-78
lines changed

1 file changed

+17
-78
lines changed

src/main/java/net/imglib2/algorithm/linalg/eigen/TensorEigenValues.java

+17-78
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,19 @@
3434

3535
package net.imglib2.algorithm.linalg.eigen;
3636

37-
import java.util.ArrayList;
38-
import java.util.List;
39-
import java.util.concurrent.Callable;
40-
import java.util.concurrent.ExecutionException;
4137
import java.util.concurrent.ExecutorService;
42-
import java.util.concurrent.Future;
4338

44-
import net.imglib2.Cursor;
45-
import net.imglib2.FinalInterval;
4639
import net.imglib2.RandomAccessibleInterval;
4740
import net.imglib2.img.Img;
4841
import net.imglib2.img.ImgFactory;
42+
import net.imglib2.loops.LoopBuilder;
43+
import net.imglib2.parallel.Parallelization;
44+
import net.imglib2.parallel.TaskExecutors;
4945
import net.imglib2.type.numeric.ComplexType;
5046
import net.imglib2.type.numeric.RealType;
51-
import net.imglib2.view.IntervalView;
5247
import net.imglib2.view.Views;
53-
import net.imglib2.view.composite.NumericComposite;
54-
import net.imglib2.view.composite.RealComposite;
48+
import net.imglib2.view.composite.CompositeIntervalView;
49+
import net.imglib2.view.composite.GenericComposite;
5550

5651
/**
5752
*
@@ -286,82 +281,26 @@ public static < T extends RealType< T >, U extends ComplexType< U > > RandomAcce
286281
final ExecutorService es )
287282
{
288283

289-
assert nTasks > 0: "Passed nTasks < 1";
290-
291-
final int tensorDims = tensor.numDimensions();
292-
293-
long dimensionMax = Long.MIN_VALUE;
294-
int dimensionArgMax = -1;
295-
296-
for ( int d = 0; d < tensorDims - 1; ++d )
297-
{
298-
final long size = tensor.dimension( d );
299-
if ( size > dimensionMax )
300-
{
301-
dimensionMax = size;
302-
dimensionArgMax = d;
303-
}
304-
}
305-
306-
final long stepSize = Math.max( dimensionMax / nTasks, 1 );
307-
final long stepSizeMinusOne = stepSize - 1;
308-
final long max = dimensionMax - 1;
309-
310-
final ArrayList< Callable< RandomAccessibleInterval< U > > > tasks = new ArrayList<>();
311-
for ( long currentMin = 0; currentMin < dimensionMax; currentMin += stepSize )
312-
{
313-
final long currentMax = Math.min( currentMin + stepSizeMinusOne, max );
314-
final long[] minT = new long[ tensorDims ];
315-
final long[] maxT = new long[ tensorDims ];
316-
final long[] minE = new long[ tensorDims ];
317-
final long[] maxE = new long[ tensorDims ];
318-
tensor.min( minT );
319-
tensor.max( maxT );
320-
eigenvalues.min( minE );
321-
eigenvalues.max( maxE );
322-
minE[ dimensionArgMax ] = minT[ dimensionArgMax ] = currentMin;
323-
maxE[ dimensionArgMax ] = maxT[ dimensionArgMax ] = currentMax;
324-
final IntervalView< T > currentTensor = Views.interval( tensor, new FinalInterval( minT, maxT ) );
325-
final IntervalView< U > currentEigenvalues = Views.interval( eigenvalues, new FinalInterval( minE, maxE ) );
326-
tasks.add( () -> calculateEigenValuesImpl( currentTensor, currentEigenvalues, ev.copy() ) );
327-
}
328-
329-
330-
try
331-
{
332-
final List< Future< RandomAccessibleInterval< U > > > futures = es.invokeAll( tasks );
333-
for ( final Future< RandomAccessibleInterval< U > > f : futures )
334-
try
335-
{
336-
f.get();
337-
}
338-
catch ( final ExecutionException e )
339-
{
340-
// TODO Auto-generated catch block
341-
e.printStackTrace();
342-
}
343-
}
344-
catch ( final InterruptedException e )
345-
{
346-
// TODO Auto-generated catch block
347-
e.printStackTrace();
348-
}
349-
350-
return eigenvalues;
351-
352-
284+
assert nTasks > 0 : "Passed nTasks < 1";
353285

286+
return Parallelization.runWithExecutor( TaskExecutors.forExecutorServiceAndNumTasks( es, nTasks ),
287+
() -> calculateEigenValues( tensor, eigenvalues, ev ) );
354288
}
355289

356290
private static < T extends RealType< T >, U extends ComplexType< U > > RandomAccessibleInterval< U > calculateEigenValuesImpl(
357291
final RandomAccessibleInterval< T > tensor,
358292
final RandomAccessibleInterval< U > eigenvalues,
359293
final EigenValues< T, U > ev )
360294
{
361-
final Cursor< RealComposite< T > > m = Views.iterable( Views.collapseReal( tensor ) ).cursor();
362-
final Cursor< NumericComposite< U > > e = Views.iterable( Views.collapseNumeric( eigenvalues ) ).cursor();
363-
while ( m.hasNext() )
364-
ev.compute( m.next(), e.next() );
295+
RandomAccessibleInterval< ? extends GenericComposite< T > > tensorVectors = Views.collapse( tensor );
296+
CompositeIntervalView< U, ? extends GenericComposite< U > > eigenvaluesVectors = Views.collapse( eigenvalues );
297+
LoopBuilder.setImages( tensorVectors, eigenvaluesVectors )
298+
.multiThreaded()
299+
.forEachChunk( chunk -> {
300+
EigenValues< T, U > copy = ev.copy();
301+
chunk.forEachPixel( copy::compute );
302+
return null;
303+
} );
365304
return eigenvalues;
366305
}
367306

0 commit comments

Comments
 (0)