Skip to content

Commit 306376f

Browse files
committed
improve javadoc and compilation
1 parent 1f81d2a commit 306376f

File tree

1 file changed

+12
-71
lines changed

1 file changed

+12
-71
lines changed

src/main/java/io/bioimage/modelrunner/pytorch/tensor/NDArrayBuilder.java

+12-71
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@
2727
import net.imglib2.Cursor;
2828
import net.imglib2.RandomAccessibleInterval;
2929
import net.imglib2.img.Img;
30-
import net.imglib2.type.Type;
30+
import net.imglib2.type.NativeType;
31+
import net.imglib2.type.numeric.RealType;
3132
import net.imglib2.type.numeric.integer.ByteType;
3233
import net.imglib2.type.numeric.integer.IntType;
3334
import net.imglib2.type.numeric.real.DoubleType;
3435
import net.imglib2.type.numeric.real.FloatType;
36+
import net.imglib2.util.Cast;
3537
import net.imglib2.util.Util;
3638
import net.imglib2.view.Views;
3739

@@ -60,26 +62,9 @@ public class NDArrayBuilder {
6062
* @return The {@link NDArray} built from the {@link Tensor}.
6163
* @throws IllegalArgumentException If the tensor type is not supported.
6264
*/
63-
public static NDArray build(Tensor tensor, NDManager manager)
64-
throws IllegalArgumentException
65-
{
66-
// Create an Icy sequence of the same type of the tensor
67-
if (Util.getTypeFromInterval(tensor.getData()) instanceof ByteType) {
68-
return buildFromTensorByte(tensor.getData(), manager);
69-
}
70-
else if (Util.getTypeFromInterval(tensor.getData()) instanceof IntType) {
71-
return buildFromTensorInt(tensor.getData(), manager);
72-
}
73-
else if (Util.getTypeFromInterval(tensor.getData()) instanceof FloatType) {
74-
return buildFromTensorFloat(tensor.getData(), manager);
75-
}
76-
else if (Util.getTypeFromInterval(tensor.getData()) instanceof DoubleType) {
77-
return buildFromTensorDouble(tensor.getData(), manager);
78-
}
79-
else {
80-
throw new IllegalArgumentException("Unsupported tensor type: " + tensor
81-
.getDataType());
82-
}
65+
public static <T extends RealType<T> & NativeType<T>>
66+
NDArray build(Tensor<T> tensor, NDManager manager) throws IllegalArgumentException {
67+
return build(tensor.getData(), manager);
8368
}
8469

8570
/**
@@ -94,42 +79,28 @@ else if (Util.getTypeFromInterval(tensor.getData()) instanceof DoubleType) {
9479
* @return The {@link NDArray} built from the {@link RandomAccessibleInterval}.
9580
* @throws IllegalArgumentException if the {@link RandomAccessibleInterval} is not supported
9681
*/
97-
public static <T extends Type<T>> NDArray build(
98-
RandomAccessibleInterval<T> tensor, NDManager manager)
82+
public static <T extends RealType<T> & NativeType<T>>
83+
NDArray build(RandomAccessibleInterval<T> tensor, NDManager manager)
9984
throws IllegalArgumentException
10085
{
10186
if (Util.getTypeFromInterval(tensor) instanceof ByteType) {
102-
return buildFromTensorByte((RandomAccessibleInterval<ByteType>) tensor,
103-
manager);
87+
return buildFromTensorByte(Cast.unchecked(tensor), manager);
10488
}
10589
else if (Util.getTypeFromInterval(tensor) instanceof IntType) {
106-
return buildFromTensorInt((RandomAccessibleInterval<IntType>) tensor,
107-
manager);
90+
return buildFromTensorInt(Cast.unchecked(tensor), manager);
10891
}
10992
else if (Util.getTypeFromInterval(tensor) instanceof FloatType) {
110-
return buildFromTensorFloat((RandomAccessibleInterval<FloatType>) tensor,
111-
manager);
93+
return buildFromTensorFloat(Cast.unchecked(tensor), manager);
11294
}
11395
else if (Util.getTypeFromInterval(tensor) instanceof DoubleType) {
114-
return buildFromTensorDouble(
115-
(RandomAccessibleInterval<DoubleType>) tensor, manager);
96+
return buildFromTensorDouble(Cast.unchecked(tensor), manager);
11697
}
11798
else {
11899
throw new IllegalArgumentException("Unsupported tensor type: " + Util
119100
.getTypeFromInterval(tensor).getClass().toString());
120101
}
121102
}
122103

123-
/**
124-
* Builds a {@link NDArray} from a signed byte-typed
125-
* {@link RandomAccessibleInterval}.
126-
*
127-
* @param tensor
128-
* the {@link RandomAccessibleInterval} that will be copied into an {@link NDArray}
129-
* @param manager
130-
* {@link NDManager} needed to create a {@link NDArray}
131-
* @return The {@link NDArray} built from the tensor of type {@link ByteType}.
132-
*/
133104
private static NDArray buildFromTensorByte(
134105
RandomAccessibleInterval<ByteType> tensor, NDManager manager)
135106
{
@@ -156,16 +127,6 @@ private static NDArray buildFromTensorByte(
156127
return ndarray;
157128
}
158129

159-
/**
160-
* Builds a {@link NDArray} from a signed integer-typed
161-
* {@link RandomAccessibleInterval}.
162-
*
163-
* @param tensor
164-
* the {@link RandomAccessibleInterval} that will be copied into an {@link NDArray}
165-
* @param manager
166-
* {@link NDManager} needed to create a {@link NDArray}
167-
* @return The {@link NDArray} built from the tensor of type {@link IntType}.
168-
*/
169130
private static NDArray buildFromTensorInt(
170131
RandomAccessibleInterval<IntType> tensor, NDManager manager)
171132
{
@@ -192,16 +153,6 @@ private static NDArray buildFromTensorInt(
192153
return ndarray;
193154
}
194155

195-
/**
196-
* Builds a {@link NDArray} from a signed float-typed
197-
* {@link RandomAccessibleInterval}.
198-
*
199-
* @param tensor
200-
* the {@link RandomAccessibleInterval} that will be copied into an {@link NDArray}
201-
* @param manager
202-
* {@link NDManager} needed to create a {@link NDArray}
203-
* @return The {@link NDArray} built from the tensor of type {@link FloatType}.
204-
*/
205156
private static NDArray buildFromTensorFloat(
206157
RandomAccessibleInterval<FloatType> tensor, NDManager manager)
207158
{
@@ -228,16 +179,6 @@ private static NDArray buildFromTensorFloat(
228179
return ndarray;
229180
}
230181

231-
/**
232-
* Builds a {@link NDArray} from a signed double-typed
233-
* {@link RandomAccessibleInterval}.
234-
*
235-
* @param tensor
236-
* the {@link RandomAccessibleInterval} that will be copied into an {@link NDArray}
237-
* @param manager
238-
* {@link NDManager} needed to create a {@link NDArray}
239-
* @return The {@link NDArray} built from the tensor of type {@link DoubleType}.
240-
*/
241182
private static NDArray buildFromTensorDouble(
242183
RandomAccessibleInterval<DoubleType> tensor, NDManager manager)
243184
{

0 commit comments

Comments
 (0)