Skip to content

Commit f1f6e8b

Browse files
committed
Update annotation names and comments, and registerCustomGradient javadoc
Signed-off-by: Ryan Nett <[email protected]>
1 parent 2eb9342 commit f1f6e8b

File tree

5 files changed

+32
-32
lines changed

5 files changed

+32
-32
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java

+15-16
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@
3939
import org.tensorflow.op.CustomGradient;
4040
import org.tensorflow.op.RawCustomGradient;
4141
import org.tensorflow.op.RawOpInputs;
42-
import org.tensorflow.op.annotation.GeneratedOpInputsMetadata;
43-
import org.tensorflow.op.annotation.GeneratedOpMetadata;
42+
import org.tensorflow.op.annotation.OpInputsMetadata;
43+
import org.tensorflow.op.annotation.OpMetadata;
4444
import org.tensorflow.op.math.Add;
4545
import org.tensorflow.proto.framework.OpList;
4646

@@ -187,35 +187,34 @@ public static synchronized boolean registerCustomGradient(
187187
}
188188

189189
/**
190-
* Register a custom gradient function for ops of {@code opClass} type. The actual op type is
191-
* detected from the class's {@link GeneratedOpMetadata} annotation. As such, it only works on
192-
* generated op classes.
190+
* Register a custom gradient function for ops of {@code inputClass}'s op type. The actual op type
191+
* is detected from the class's {@link OpInputsMetadata} annotation. As such, it only works on
192+
* generated op classes or custom op classes with the correct annotations.
193193
*
194-
* @param opClass the class of op to register the gradient for.
194+
* @param inputClass the inputs class of op to register the gradient for.
195195
* @param gradient the gradient function to use
196196
* @return {@code true} if the gradient was registered, {@code false} if there was already a
197197
* gradient registered for this op
198-
* @throws IllegalArgumentException if {@code opClass} does not have a {@link GeneratedOpMetadata}
199-
* field.
198+
* @throws IllegalArgumentException if {@code inputClass} is not annotated with {@link
199+
* OpInputsMetadata} or the op class is not annotated with {@link OpMetadata}.
200200
*/
201201
public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGradient(
202-
Class<T> opClass, CustomGradient<T> gradient) {
203-
GeneratedOpInputsMetadata metadata = opClass.getAnnotation(GeneratedOpInputsMetadata.class);
202+
Class<T> inputClass, CustomGradient<T> gradient) {
203+
OpInputsMetadata metadata = inputClass.getAnnotation(OpInputsMetadata.class);
204204

205205
if (metadata == null) {
206206
throw new IllegalArgumentException(
207207
"Inputs Class "
208-
+ opClass
209-
+ " does not have a GeneratedOpInputsMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug.");
208+
+ inputClass
209+
+ " does not have a OpInputsMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug.");
210210
}
211-
GeneratedOpMetadata outputMetadata =
212-
metadata.outputsClass().getAnnotation(GeneratedOpMetadata.class);
211+
OpMetadata outputMetadata = metadata.outputsClass().getAnnotation(OpMetadata.class);
213212

214213
if (outputMetadata == null) {
215214
throw new IllegalArgumentException(
216215
"Op Class "
217216
+ metadata.outputsClass()
218-
+ " does not have a GeneratedOpMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug.");
217+
+ " does not have a OpMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug.");
219218
}
220219

221220
String opType = outputMetadata.opType();
@@ -224,7 +223,7 @@ public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGrad
224223
return false;
225224
}
226225

227-
GradFunc g = CustomGradient.adapter(gradient, opClass);
226+
GradFunc g = CustomGradient.adapter(gradient, inputClass);
228227
GradOpRegistry.Global().Register(opType, g);
229228
gradientFuncs.add(g);
230229
return true;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/annotation/GeneratedOpInputsMetadata.java tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/annotation/OpInputsMetadata.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
import org.tensorflow.op.RawOp;
2525

2626
/**
27-
* An annotation that should only be used by codegeneration. Used to provide some metadata about the
28-
* op.
29-
*
30-
* <p><b>DO NOT USE MANUALLY</b>
27+
* An annotation to provide metadata about an op inputs accessor class. Should only be used by users
28+
* on custom ops, will be generated for non-custom ops.
3129
*/
3230
@Target(ElementType.TYPE)
3331
@Retention(RetentionPolicy.RUNTIME)
34-
public @interface GeneratedOpInputsMetadata {
32+
public @interface OpInputsMetadata {
33+
34+
/** The main op class. */
3535
Class<? extends RawOp> outputsClass();
3636
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/annotation/GeneratedOpMetadata.java tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/annotation/OpMetadata.java

+6-5
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,17 @@
2424
import org.tensorflow.op.RawOpInputs;
2525

2626
/**
27-
* An annotation that should only be used by codegeneration. Used to provide some metadata about the
28-
* op.
29-
*
30-
* <p><b>DO NOT USE MANUALLY</b>
27+
* An annotation to provide metadata about an op. Should only be used by users on custom ops, will
28+
* be generated for non-custom ops.
3129
*/
3230
@Target(ElementType.TYPE)
3331
@Retention(RetentionPolicy.RUNTIME)
34-
public @interface GeneratedOpMetadata {
32+
public @interface OpMetadata {
33+
34+
/** The type of the op in the TF runtime. */
3535
String opType();
3636

37+
/** The typesafe inputs class (which should be annotated with {@link OpInputsMetadata}). */
3738
@SuppressWarnings("rawtypes")
3839
Class<? extends RawOpInputs> inputsClass();
3940
}

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ public class Names {
2929

3030
public static final ClassName Operator = ClassName.get(OpPackage + ".annotation", "Operator");
3131
public static final ClassName Endpoint = ClassName.get(OpPackage + ".annotation", "Endpoint");
32-
public static final ClassName GeneratedOpMetadata =
33-
ClassName.get(OpPackage + ".annotation", "GeneratedOpMetadata");
34-
public static final ClassName GeneratedOpInputsMetadata =
35-
ClassName.get(OpPackage + ".annotation", "GeneratedOpInputsMetadata");
32+
public static final ClassName OpMetadata =
33+
ClassName.get(OpPackage + ".annotation", "OpMetadata");
34+
public static final ClassName OpInputsMetadata =
35+
ClassName.get(OpPackage + ".annotation", "OpInputsMetadata");
3636

3737
public static final ClassName TType = ClassName.get(TypesPackage + ".family", "TType");
3838
public static final ClassName TString = ClassName.get(TypesPackage, "TString");

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -1057,7 +1057,7 @@ private Set<TypeVariableName> buildInputsClass() {
10571057
/** Adds the GeneratedOpMetadata annotation to the op class. */
10581058
private void addInputsMetadataAnnotation() {
10591059
builder.addAnnotation(
1060-
AnnotationSpec.builder(Names.GeneratedOpMetadata)
1060+
AnnotationSpec.builder(Names.OpMetadata)
10611061
.addMember("opType", "$L", className + ".OP_NAME")
10621062
.addMember("inputsClass", "$T.class", inputsClassName())
10631063
.build());
@@ -1066,7 +1066,7 @@ private void addInputsMetadataAnnotation() {
10661066
/** Adds the GeneratedOpInputsMetadata annotation to the op input class. */
10671067
private void addInputsMetadataAnnotation(TypeSpec.Builder inputsBuilder) {
10681068
inputsBuilder.addAnnotation(
1069-
AnnotationSpec.builder(Names.GeneratedOpInputsMetadata)
1069+
AnnotationSpec.builder(Names.OpInputsMetadata)
10701070
.addMember("outputsClass", "$T.class", className())
10711071
.build());
10721072
}

0 commit comments

Comments
 (0)