1
1
/*
2
- Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
+ Copyright 2021 The TensorFlow Authors. All Rights Reserved.
3
3
4
- Licensed under the Apache License, Version 2.0 (the "License");
5
- you may not use this file except in compliance with the License.
6
- You may obtain a copy of the License at
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
7
8
- http://www.apache.org/licenses/LICENSE-2.0
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
9
10
- Unless required by applicable law or agreed to in writing, software
11
- distributed under the License is distributed on an "AS IS" BASIS,
12
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- See the License for the specific language governing permissions and
14
- limitations under the License.
15
- ==============================================================================
16
- */
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ ==============================================================================
16
+ */
17
17
package org .tensorflow ;
18
18
19
+ import static org .tensorflow .internal .c_api .global .tensorflow .OutputsFromTFOutputs ;
20
+ import static org .tensorflow .internal .c_api .global .tensorflow .TFOutputsFromOutputs ;
19
21
import static org .tensorflow .internal .c_api .global .tensorflow .ToOperation ;
20
22
21
23
import java .util .ArrayList ;
22
24
import java .util .List ;
23
25
import org .bytedeco .javacpp .PointerScope ;
24
- import org .tensorflow .internal .c_api .NativeOutput ;
25
26
import org .tensorflow .internal .c_api .NativeOutputVector ;
26
27
import org .tensorflow .internal .c_api .Node ;
28
+ import org .tensorflow .internal .c_api .TF_Output ;
27
29
28
30
/**
29
31
* Helpers for {@link org.tensorflow.op.TypedGradientAdapter} and {@link
@@ -34,39 +36,49 @@ public class GradientAdapterHelpers {
34
36
/**
35
37
* Convert a array of native outputs to a list of {@link Output}s.
36
38
*
37
- * @param g the graph the outputs are in
39
+ * @param g the graph the outputs are in
38
40
* @param nativeOutputs the native outputs to convert
39
41
*/
40
42
public static List <Output <?>> fromNativeOutputs (Graph g , NativeOutputVector nativeOutputs ) {
43
+ TF_Output outputs = new TF_Output (nativeOutputs .size ());
44
+ TFOutputsFromOutputs (nativeOutputs , outputs );
41
45
List <Output <?>> gradInputs = new ArrayList <>((int ) nativeOutputs .size ());
42
46
for (int i = 0 ; i < nativeOutputs .size (); i ++) {
43
- NativeOutput output = nativeOutputs .get (i );
44
- gradInputs .add (new Output <>(getGraphOp (g , output .node ()),
45
- output .index ()));
47
+ TF_Output output = outputs .position (i );
48
+ gradInputs .add (new Output <>(new GraphOperation (g , output .oper ()), output .index ()));
46
49
}
47
50
return gradInputs ;
48
51
}
49
52
50
53
/**
51
54
* Put the Java outputs into the array of native outputs, resizing it to the necessary size.
52
55
*
53
- * @param outputs the outputs to put
56
+ * @param outputs the outputs to put
54
57
* @param nativeOutputs the native array to put the outputs into
55
58
*/
56
- public static void putToNativeOutputs (List < Operand <?>> outputs ,
57
- NativeOutputVector nativeOutputs ) {
59
+ public static void putToNativeOutputs (
60
+ List < Operand <?>> outputs , NativeOutputVector nativeOutputs ) {
58
61
nativeOutputs .resize (outputs .size ());
62
+
63
+ TF_Output tempOutputs = new TF_Output (outputs .size ());
59
64
for (int i = 0 ; i < outputs .size (); i ++) {
60
65
Output <?> output = outputs .get (i ).asOutput ();
61
- Node node = ((GraphOperation ) output .op ()).getUnsafeNativeHandle ().node ();
62
- nativeOutputs .put (i , new NativeOutput (node , output .index ()));
66
+ GraphOperation graphOp = (GraphOperation ) output .op ();
67
+ tempOutputs
68
+ .position (i )
69
+ .put (new TF_Output ().oper (graphOp .getUnsafeNativeHandle ()).index (output .index ()));
70
+ }
71
+
72
+ NativeOutputVector temp = OutputsFromTFOutputs (tempOutputs , outputs .size ());
73
+ for (int i = 0 ; i < outputs .size (); i ++) {
74
+ nativeOutputs .put (i , temp .get (i ));
63
75
}
64
76
}
65
77
66
78
/**
67
79
* Make a {@link GraphOperation} from a native {@link Node}
68
80
*
69
- * @param g the graph the operation is in
81
+ * @param g the graph the operation is in
70
82
* @param node the native node
71
83
* @return a graph operation with the underlying native node
72
84
*/
0 commit comments