|
34 | 34 |
|
35 | 35 | package net.imglib2.algorithm.linalg.eigen;
|
36 | 36 |
|
37 |
| -import java.util.ArrayList; |
38 |
| -import java.util.List; |
39 |
| -import java.util.concurrent.Callable; |
40 |
| -import java.util.concurrent.ExecutionException; |
41 | 37 | import java.util.concurrent.ExecutorService;
|
42 |
| -import java.util.concurrent.Future; |
43 | 38 |
|
44 |
| -import net.imglib2.Cursor; |
45 |
| -import net.imglib2.FinalInterval; |
46 | 39 | import net.imglib2.RandomAccessibleInterval;
|
47 | 40 | import net.imglib2.img.Img;
|
48 | 41 | import net.imglib2.img.ImgFactory;
|
| 42 | +import net.imglib2.loops.LoopBuilder; |
| 43 | +import net.imglib2.parallel.Parallelization; |
| 44 | +import net.imglib2.parallel.TaskExecutors; |
49 | 45 | import net.imglib2.type.numeric.ComplexType;
|
50 | 46 | import net.imglib2.type.numeric.RealType;
|
51 |
| -import net.imglib2.view.IntervalView; |
52 | 47 | import net.imglib2.view.Views;
|
53 |
| -import net.imglib2.view.composite.NumericComposite; |
54 |
| -import net.imglib2.view.composite.RealComposite; |
| 48 | +import net.imglib2.view.composite.CompositeIntervalView; |
| 49 | +import net.imglib2.view.composite.GenericComposite; |
55 | 50 |
|
56 | 51 | /**
|
57 | 52 | *
|
@@ -286,82 +281,26 @@ public static < T extends RealType< T >, U extends ComplexType< U > > RandomAcce
|
286 | 281 | final ExecutorService es )
|
287 | 282 | {
|
288 | 283 |
|
289 |
| - assert nTasks > 0: "Passed nTasks < 1"; |
290 |
| - |
291 |
| - final int tensorDims = tensor.numDimensions(); |
292 |
| - |
293 |
| - long dimensionMax = Long.MIN_VALUE; |
294 |
| - int dimensionArgMax = -1; |
295 |
| - |
296 |
| - for ( int d = 0; d < tensorDims - 1; ++d ) |
297 |
| - { |
298 |
| - final long size = tensor.dimension( d ); |
299 |
| - if ( size > dimensionMax ) |
300 |
| - { |
301 |
| - dimensionMax = size; |
302 |
| - dimensionArgMax = d; |
303 |
| - } |
304 |
| - } |
305 |
| - |
306 |
| - final long stepSize = Math.max( dimensionMax / nTasks, 1 ); |
307 |
| - final long stepSizeMinusOne = stepSize - 1; |
308 |
| - final long max = dimensionMax - 1; |
309 |
| - |
310 |
| - final ArrayList< Callable< RandomAccessibleInterval< U > > > tasks = new ArrayList<>(); |
311 |
| - for ( long currentMin = 0; currentMin < dimensionMax; currentMin += stepSize ) |
312 |
| - { |
313 |
| - final long currentMax = Math.min( currentMin + stepSizeMinusOne, max ); |
314 |
| - final long[] minT = new long[ tensorDims ]; |
315 |
| - final long[] maxT = new long[ tensorDims ]; |
316 |
| - final long[] minE = new long[ tensorDims ]; |
317 |
| - final long[] maxE = new long[ tensorDims ]; |
318 |
| - tensor.min( minT ); |
319 |
| - tensor.max( maxT ); |
320 |
| - eigenvalues.min( minE ); |
321 |
| - eigenvalues.max( maxE ); |
322 |
| - minE[ dimensionArgMax ] = minT[ dimensionArgMax ] = currentMin; |
323 |
| - maxE[ dimensionArgMax ] = maxT[ dimensionArgMax ] = currentMax; |
324 |
| - final IntervalView< T > currentTensor = Views.interval( tensor, new FinalInterval( minT, maxT ) ); |
325 |
| - final IntervalView< U > currentEigenvalues = Views.interval( eigenvalues, new FinalInterval( minE, maxE ) ); |
326 |
| - tasks.add( () -> calculateEigenValuesImpl( currentTensor, currentEigenvalues, ev.copy() ) ); |
327 |
| - } |
328 |
| - |
329 |
| - |
330 |
| - try |
331 |
| - { |
332 |
| - final List< Future< RandomAccessibleInterval< U > > > futures = es.invokeAll( tasks ); |
333 |
| - for ( final Future< RandomAccessibleInterval< U > > f : futures ) |
334 |
| - try |
335 |
| - { |
336 |
| - f.get(); |
337 |
| - } |
338 |
| - catch ( final ExecutionException e ) |
339 |
| - { |
340 |
| - // TODO Auto-generated catch block |
341 |
| - e.printStackTrace(); |
342 |
| - } |
343 |
| - } |
344 |
| - catch ( final InterruptedException e ) |
345 |
| - { |
346 |
| - // TODO Auto-generated catch block |
347 |
| - e.printStackTrace(); |
348 |
| - } |
349 |
| - |
350 |
| - return eigenvalues; |
351 |
| - |
352 |
| - |
| 284 | + assert nTasks > 0 : "Passed nTasks < 1"; |
353 | 285 |
|
| 286 | + return Parallelization.runWithExecutor( TaskExecutors.forExecutorServiceAndNumTasks( es, nTasks ), |
| 287 | + () -> calculateEigenValues( tensor, eigenvalues, ev ) ); |
354 | 288 | }
|
355 | 289 |
|
356 | 290 | private static < T extends RealType< T >, U extends ComplexType< U > > RandomAccessibleInterval< U > calculateEigenValuesImpl(
|
357 | 291 | final RandomAccessibleInterval< T > tensor,
|
358 | 292 | final RandomAccessibleInterval< U > eigenvalues,
|
359 | 293 | final EigenValues< T, U > ev )
|
360 | 294 | {
|
361 |
| - final Cursor< RealComposite< T > > m = Views.iterable( Views.collapseReal( tensor ) ).cursor(); |
362 |
| - final Cursor< NumericComposite< U > > e = Views.iterable( Views.collapseNumeric( eigenvalues ) ).cursor(); |
363 |
| - while ( m.hasNext() ) |
364 |
| - ev.compute( m.next(), e.next() ); |
| 295 | + RandomAccessibleInterval< ? extends GenericComposite< T > > tensorVectors = Views.collapse( tensor ); |
| 296 | + CompositeIntervalView< U, ? extends GenericComposite< U > > eigenvaluesVectors = Views.collapse( eigenvalues ); |
| 297 | + LoopBuilder.setImages( tensorVectors, eigenvaluesVectors ) |
| 298 | + .multiThreaded() |
| 299 | + .forEachChunk( chunk -> { |
| 300 | + EigenValues< T, U > copy = ev.copy(); |
| 301 | + chunk.forEachPixel( copy::compute ); |
| 302 | + return null; |
| 303 | + } ); |
365 | 304 | return eigenvalues;
|
366 | 305 | }
|
367 | 306 |
|
|
0 commit comments