Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Map experimental C (actually C++) API for gradient tape #283

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensorflow-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
<javacpp.platform.macosx-x86_64.extension>macosx-x86_64${javacpp.platform.extension}</javacpp.platform.macosx-x86_64.extension>
<javacpp.platform.windows-x86.extension>windows-x86${javacpp.platform.extension}</javacpp.platform.windows-x86.extension>
<javacpp.platform.windows-x86_64.extension>windows-x86_64${javacpp.platform.extension}</javacpp.platform.windows-x86_64.extension>
<javacpp.version>1.5.4</javacpp.version>
<javacpp.version>1.5.5</javacpp.version>
</properties>

<profiles>
Expand Down
25 changes: 25 additions & 0 deletions tensorflow-core/tensorflow-core-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,19 @@
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<version>3.1.0</version>
<executions>
<execution>
<id>javacpp-parser</id>
<phase>generate-sources</phase>
<goals>
<goal>resources</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.0</version>
Expand Down Expand Up @@ -211,7 +224,15 @@
<classPath>${project.build.outputDirectory}</classPath>
<includePaths>
<includePath>${project.basedir}/</includePath>
<includePath>${project.basedir}/bazel-bin/external/llvm-project/llvm/include/</includePath>
<includePath>${project.basedir}/bazel-bin/external/org_tensorflow/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/eigen_archive/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/com_google_absl/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/com_google_protobuf/src/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/farmhash_archive/src/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/llvm-project/llvm/include/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/org_tensorflow/</includePath>
<includePath>${project.basedir}/target/classes/org/tensorflow/internal/c_api/include/</includePath>
</includePaths>
<linkPaths>
<linkPath>${project.basedir}/bazel-bin/external/llvm_openmp/</linkPath>
Expand Down Expand Up @@ -317,6 +338,10 @@
<outputDirectory>${project.build.directory}/native/org/tensorflow/internal/c_api/${native.classifier}/</outputDirectory>
<skip>${javacpp.compiler.skip}</skip>
<classOrPackageName>org.tensorflow.internal.c_api.**</classOrPackageName>
<compilerOptions>
<!-- TODO: Remove files from here as they get integrated into the Bazel build -->
<compilerOption>${project.basedir}/bazel-${project.artifactId}/external/org_tensorflow/tensorflow/c/eager/gradients.cc</compilerOption>
</compilerOptions>
<copyLibs>true</copyLibs>
<copyResources>true</copyResources>
</configuration>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Targeted by JavaCPP version 1.5.5: DO NOT EDIT THIS FILE

package org.tensorflow.internal.c_api;

import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.tensorflow.internal.c_api.global.tensorflow.*;


// Abstract interface to a context.
//
// This serves as a factory for creating `AbstractOperation`s and for
// registering traced functions.
// Operations creation within a context can only be executed in that context
// (for now at least).
// Implementations of the context may contain some state e.g. an execution
// environment, a traced representation etc.
@Namespace("tensorflow") @NoOffset @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class AbstractContext extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public AbstractContext(Pointer p) { super(p); }

public native int getKind();

// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus clients MUST call Release() in order to
// destroy an instance of this class.
public native void Release();

// Creates an operation builder and ties it to this context.
// The returned object can be used for setting operation's attributes,
// adding inputs and finally executing (immediately or lazily as in tracing)
// it in this context.
public native AbstractOperation CreateOperation();

// Registers a function with this context, after this the function is
// available to be called/referenced by its name in this context.
public native @ByVal Status RegisterFunction(AbstractFunction arg0);
// Remove a function. 'func' argument is the name of a previously added
// FunctionDef. The name is in fdef.signature.name.
public native @ByVal Status RemoveFunction(@StdString BytePointer func);
public native @ByVal Status RemoveFunction(@StdString String func);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Targeted by JavaCPP version 1.5.5: DO NOT EDIT THIS FILE

package org.tensorflow.internal.c_api;

import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.tensorflow.internal.c_api.global.tensorflow.*;

@Namespace("tensorflow::internal") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class AbstractContextDeleter extends Pointer {
static { Loader.load(); }
/** Default native constructor. */
public AbstractContextDeleter() { super((Pointer)null); allocate(); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public AbstractContextDeleter(long size) { super((Pointer)null); allocateArray(size); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public AbstractContextDeleter(Pointer p) { super(p); }
private native void allocate();
private native void allocateArray(long size);
@Override public AbstractContextDeleter position(long position) {
return (AbstractContextDeleter)super.position(position);
}
@Override public AbstractContextDeleter getPointer(long i) {
return new AbstractContextDeleter((Pointer)this).position(position + i);
}

public native @Name("operator ()") void apply(AbstractContext p);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Targeted by JavaCPP version 1.5.5: DO NOT EDIT THIS FILE

package org.tensorflow.internal.c_api;

import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.tensorflow.internal.c_api.global.tensorflow.*;


// A traced function: this hides the complexity of converting the serialized
// representation between various supported formats e.g. FunctionDef and Mlir
// function.
@Namespace("tensorflow") @NoOffset @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class AbstractFunction extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public AbstractFunction(Pointer p) { super(p); }

// Returns which subclass is this instance of.
public native int getKind();

// Returns the AbstractFunction as a FunctionDef.
public native @ByVal Status GetFunctionDef(@Cast("tensorflow::FunctionDef**") PointerPointer arg0);
public native @ByVal Status GetFunctionDef(@Cast("tensorflow::FunctionDef**") @ByPtrPtr Pointer arg0);
}
Loading