Skip to content

Commit b39bd2f

Browse files
committed
Use better conversion functions
Signed-off-by: Ryan Nett <[email protected]>
1 parent 59750dc commit b39bd2f

File tree

1 file changed

+35
-23
lines changed

1 file changed

+35
-23
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GradientAdapterHelpers.java

+35-23
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
11
/*
2-
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
33
4-
Licensed under the Apache License, Version 2.0 (the "License");
5-
you may not use this file except in compliance with the License.
6-
You may obtain a copy of the License at
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
77
8-
http://www.apache.org/licenses/LICENSE-2.0
8+
http://www.apache.org/licenses/LICENSE-2.0
99
10-
Unless required by applicable law or agreed to in writing, software
11-
distributed under the License is distributed on an "AS IS" BASIS,
12-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
See the License for the specific language governing permissions and
14-
limitations under the License.
15-
==============================================================================
16-
*/
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
1717
package org.tensorflow;
1818

19+
import static org.tensorflow.internal.c_api.global.tensorflow.OutputsFromTFOutputs;
20+
import static org.tensorflow.internal.c_api.global.tensorflow.TFOutputsFromOutputs;
1921
import static org.tensorflow.internal.c_api.global.tensorflow.ToOperation;
2022

2123
import java.util.ArrayList;
2224
import java.util.List;
2325
import org.bytedeco.javacpp.PointerScope;
24-
import org.tensorflow.internal.c_api.NativeOutput;
2526
import org.tensorflow.internal.c_api.NativeOutputVector;
2627
import org.tensorflow.internal.c_api.Node;
28+
import org.tensorflow.internal.c_api.TF_Output;
2729

2830
/**
2931
* Helpers for {@link org.tensorflow.op.TypedGradientAdapter} and {@link
@@ -34,39 +36,49 @@ public class GradientAdapterHelpers {
3436
/**
3537
* Convert a array of native outputs to a list of {@link Output}s.
3638
*
37-
* @param g the graph the outputs are in
39+
* @param g the graph the outputs are in
3840
* @param nativeOutputs the native outputs to convert
3941
*/
4042
public static List<Output<?>> fromNativeOutputs(Graph g, NativeOutputVector nativeOutputs) {
43+
TF_Output outputs = new TF_Output(nativeOutputs.size());
44+
TFOutputsFromOutputs(nativeOutputs, outputs);
4145
List<Output<?>> gradInputs = new ArrayList<>((int) nativeOutputs.size());
4246
for (int i = 0; i < nativeOutputs.size(); i++) {
43-
NativeOutput output = nativeOutputs.get(i);
44-
gradInputs.add(new Output<>(getGraphOp(g, output.node()),
45-
output.index()));
47+
TF_Output output = outputs.position(i);
48+
gradInputs.add(new Output<>(new GraphOperation(g, output.oper()), output.index()));
4649
}
4750
return gradInputs;
4851
}
4952

5053
/**
5154
* Put the Java outputs into the array of native outputs, resizing it to the necessary size.
5255
*
53-
* @param outputs the outputs to put
56+
* @param outputs the outputs to put
5457
* @param nativeOutputs the native array to put the outputs into
5558
*/
56-
public static void putToNativeOutputs(List<Operand<?>> outputs,
57-
NativeOutputVector nativeOutputs) {
59+
public static void putToNativeOutputs(
60+
List<Operand<?>> outputs, NativeOutputVector nativeOutputs) {
5861
nativeOutputs.resize(outputs.size());
62+
63+
TF_Output tempOutputs = new TF_Output(outputs.size());
5964
for (int i = 0; i < outputs.size(); i++) {
6065
Output<?> output = outputs.get(i).asOutput();
61-
Node node = ((GraphOperation) output.op()).getUnsafeNativeHandle().node();
62-
nativeOutputs.put(i, new NativeOutput(node, output.index()));
66+
GraphOperation graphOp = (GraphOperation) output.op();
67+
tempOutputs
68+
.position(i)
69+
.put(new TF_Output().oper(graphOp.getUnsafeNativeHandle()).index(output.index()));
70+
}
71+
72+
NativeOutputVector temp = OutputsFromTFOutputs(tempOutputs, outputs.size());
73+
for (int i = 0; i < outputs.size(); i++) {
74+
nativeOutputs.put(i, temp.get(i));
6375
}
6476
}
6577

6678
/**
6779
* Make a {@link GraphOperation} from a native {@link Node}
6880
*
69-
* @param g the graph the operation is in
81+
* @param g the graph the operation is in
7082
* @param node the native node
7183
* @return a graph operation with the underlying native node
7284
*/

0 commit comments

Comments
 (0)