39
39
import org .tensorflow .op .CustomGradient ;
40
40
import org .tensorflow .op .RawCustomGradient ;
41
41
import org .tensorflow .op .RawOpInputs ;
42
- import org .tensorflow .op .annotation .GeneratedOpInputsMetadata ;
43
- import org .tensorflow .op .annotation .GeneratedOpMetadata ;
42
+ import org .tensorflow .op .annotation .OpInputsMetadata ;
43
+ import org .tensorflow .op .annotation .OpMetadata ;
44
44
import org .tensorflow .op .math .Add ;
45
45
import org .tensorflow .proto .framework .OpList ;
46
46
@@ -187,35 +187,34 @@ public static synchronized boolean registerCustomGradient(
187
187
}
188
188
189
189
/**
190
- * Register a custom gradient function for ops of {@code opClass} type. The actual op type is
191
- * detected from the class's {@link GeneratedOpMetadata } annotation. As such, it only works on
192
- * generated op classes.
190
+ * Register a custom gradient function for ops of {@code inputClass}'s op type. The actual op type
191
+ * is detected from the class's {@link OpInputsMetadata } annotation. As such, it only works on
192
+ * generated op classes or custom op classes with the correct annotations .
193
193
*
194
- * @param opClass the class of op to register the gradient for.
194
+ * @param inputClass the inputs class of op to register the gradient for.
195
195
* @param gradient the gradient function to use
196
196
* @return {@code true} if the gradient was registered, {@code false} if there was already a
197
197
* gradient registered for this op
198
- * @throws IllegalArgumentException if {@code opClass} does not have a {@link GeneratedOpMetadata}
199
- * field .
198
+ * @throws IllegalArgumentException if {@code inputClass} is not annotated with {@link
199
+ * OpInputsMetadata} or the op class is not annotated with {@link OpMetadata} .
200
200
*/
201
201
public static synchronized <T extends RawOpInputs <?>> boolean registerCustomGradient (
202
- Class <T > opClass , CustomGradient <T > gradient ) {
203
- GeneratedOpInputsMetadata metadata = opClass .getAnnotation (GeneratedOpInputsMetadata .class );
202
+ Class <T > inputClass , CustomGradient <T > gradient ) {
203
+ OpInputsMetadata metadata = inputClass .getAnnotation (OpInputsMetadata .class );
204
204
205
205
if (metadata == null ) {
206
206
throw new IllegalArgumentException (
207
207
"Inputs Class "
208
- + opClass
209
- + " does not have a GeneratedOpInputsMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug." );
208
+ + inputClass
209
+ + " does not have a OpInputsMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug." );
210
210
}
211
- GeneratedOpMetadata outputMetadata =
212
- metadata .outputsClass ().getAnnotation (GeneratedOpMetadata .class );
211
+ OpMetadata outputMetadata = metadata .outputsClass ().getAnnotation (OpMetadata .class );
213
212
214
213
if (outputMetadata == null ) {
215
214
throw new IllegalArgumentException (
216
215
"Op Class "
217
216
+ metadata .outputsClass ()
218
- + " does not have a GeneratedOpMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug." );
217
+ + " does not have a OpMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug." );
219
218
}
220
219
221
220
String opType = outputMetadata .opType ();
@@ -224,7 +223,7 @@ public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGrad
224
223
return false ;
225
224
}
226
225
227
- GradFunc g = CustomGradient .adapter (gradient , opClass );
226
+ GradFunc g = CustomGradient .adapter (gradient , inputClass );
228
227
GradOpRegistry .Global ().Register (opType , g );
229
228
gradientFuncs .add (g );
230
229
return true ;
0 commit comments