Skip to content

Commit a41222c

Browse files
cerisierGoogle-ML-Automation
authored andcommitted
PR #22815: [PJRT C API] Ensure C Compliance for all C headers
Imported from GitHub PR #22815 `<cstd*>` headers are C++ headers that wrap their `<std*.h>` counteparts in the std namespace and re-exports them as well. It is meant to be consumed by C++ compilers, not C compilers. Since this is a C API, this PR replaces usages of `<cstd*>` include statements by their C counterparts only for exported C api headers. This PR supersedes #22082 and fixes it across the whole C API. Copybara import of the project: -- d2a1096 by Corentin Kerisit <[email protected]>: [PJRT C API] Ensure C Compliance for all C headers <cstd*> headers are C++ headers that wrap their <std*.h> counteparts in the std namespace and re-exports them as well.. It is meant to be consumed by C++ compilers, not C compilers. Since this is a C API, this PR replaces usages of <cstd*> include statements by their C counterparts only for exported C api headers. This PR supersedes #22082 and fixes it across the whole C API. -- f1f6eb6 by Corentin Kerisit <[email protected]>: Add missing typedef when refering to structs Merging this change closes #22815 COPYBARA_INTEGRATE_REVIEW=#22815 from cerisier:cerisir/fix-c-compatibility f1f6eb6 PiperOrigin-RevId: 728682193
1 parent 6c41035 commit a41222c

7 files changed

+31
-33
lines changed

xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h

+20-20
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ limitations under the License.
1616
#ifndef XLA_PJRT_C_PJRT_C_API_CUSTOM_PARTITIONER_EXTENSION_H_
1717
#define XLA_PJRT_C_PJRT_C_API_CUSTOM_PARTITIONER_EXTENSION_H_
1818

19-
#include <cstddef>
20-
#include <cstdint>
19+
#include <stddef.h>
20+
#include <stdint.h>
2121

2222
#include "xla/pjrt/c/pjrt_c_api.h"
2323

@@ -27,24 +27,24 @@ extern "C" {
2727

2828
#define PJRT_API_CUSTOM_PARTITIONER_EXTENSION_VERSION 1
2929

30-
struct JAX_CustomCallPartitioner_string {
30+
typedef struct JAX_CustomCallPartitioner_string {
3131
const char* data;
3232
size_t size;
33-
};
33+
} JAX_CustomCallPartitioner_string;
3434

35-
struct JAX_CustomCallPartitioner_aval {
35+
typedef struct JAX_CustomCallPartitioner_aval {
3636
JAX_CustomCallPartitioner_string shape;
3737
bool has_sharding;
3838
JAX_CustomCallPartitioner_string sharding;
39-
};
39+
} JAX_CustomCallPartitioner_aval;
4040

4141
// General callback information containing api versions, the result error
4242
// message and the cleanup function to free any temporary memory that is backing
4343
// the results. Arguments are always owned by the caller, and results are owned
4444
// by the cleanup_fn. These should never be used directly. Args and results
4545
// should be serialized via the PopulateArgs, ReadArgs, PopulateResults,
4646
// ConsumeResults functions defined below.
47-
struct JAX_CustomCallPartitioner_version_and_error {
47+
typedef struct JAX_CustomCallPartitioner_version_and_error {
4848
int64_t api_version;
4949
void* data; // out
5050
// cleanup_fn cleans up any returned results. The caller must finish with all
@@ -53,9 +53,9 @@ struct JAX_CustomCallPartitioner_version_and_error {
5353
bool has_error;
5454
PJRT_Error_Code code; // out
5555
JAX_CustomCallPartitioner_string error_msg; // out
56-
};
56+
} JAX_CustomCallPartitioner_version_and_error;
5757

58-
struct JAX_CustomCallPartitioner_Partition_Args {
58+
typedef struct JAX_CustomCallPartitioner_Partition_Args {
5959
JAX_CustomCallPartitioner_version_and_error header;
6060

6161
size_t num_args;
@@ -67,9 +67,9 @@ struct JAX_CustomCallPartitioner_Partition_Args {
6767
JAX_CustomCallPartitioner_string mlir_module;
6868
JAX_CustomCallPartitioner_string* args_sharding;
6969
JAX_CustomCallPartitioner_string result_sharding;
70-
};
70+
} JAX_CustomCallPartitioner_Partition_Args;
7171

72-
struct JAX_CustomCallPartitioner_InferShardingFromOperands_Args {
72+
typedef struct JAX_CustomCallPartitioner_InferShardingFromOperands_Args {
7373
JAX_CustomCallPartitioner_version_and_error header;
7474

7575
size_t num_args;
@@ -79,32 +79,32 @@ struct JAX_CustomCallPartitioner_InferShardingFromOperands_Args {
7979

8080
bool has_result_sharding;
8181
JAX_CustomCallPartitioner_string result_sharding;
82-
};
82+
} JAX_CustomCallPartitioner_InferShardingFromOperands_Args;
8383

84-
struct JAX_CustomCallPartitioner_PropagateUserSharding_Args {
84+
typedef struct JAX_CustomCallPartitioner_PropagateUserSharding_Args {
8585
JAX_CustomCallPartitioner_version_and_error header;
8686

8787
JAX_CustomCallPartitioner_string backend_config;
8888

8989
JAX_CustomCallPartitioner_string result_shape;
9090

9191
JAX_CustomCallPartitioner_string result_sharding; // inout
92-
};
92+
} JAX_CustomCallPartitioner_PropagateUserSharding_Args;
9393

94-
struct JAX_CustomCallPartitioner_Callbacks {
94+
typedef struct JAX_CustomCallPartitioner_Callbacks {
9595
int64_t version;
9696
void* private_data;
97-
void (*dtor)(JAX_CustomCallPartitioner_Callbacks* data);
98-
void (*partition)(JAX_CustomCallPartitioner_Callbacks* data,
97+
void (*dtor)(struct JAX_CustomCallPartitioner_Callbacks* data);
98+
void (*partition)(struct JAX_CustomCallPartitioner_Callbacks* data,
9999
JAX_CustomCallPartitioner_Partition_Args* args);
100100
void (*infer_sharding)(
101-
JAX_CustomCallPartitioner_Callbacks* data,
101+
struct JAX_CustomCallPartitioner_Callbacks* data,
102102
JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args);
103103
void (*propagate_user_sharding)(
104-
JAX_CustomCallPartitioner_Callbacks* data,
104+
struct JAX_CustomCallPartitioner_Callbacks* data,
105105
JAX_CustomCallPartitioner_PropagateUserSharding_Args* args);
106106
bool can_side_effecting_have_replicated_sharding;
107-
};
107+
} JAX_CustomCallPartitioner_Callbacks;
108108

109109
struct PJRT_Register_Custom_Partitioner_Args {
110110
size_t struct_size;

xla/pjrt/c/pjrt_c_api_ffi_extension.h

+3-4
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ limitations under the License.
1717
#define XLA_PJRT_C_PJRT_C_API_FFI_EXTENSION_H_
1818

1919
#include <stddef.h>
20-
21-
#include <cstdint>
20+
#include <stdint.h>
2221

2322
#include "xla/pjrt/c/pjrt_c_api.h"
2423

@@ -49,11 +48,11 @@ typedef PJRT_Error* PJRT_FFI_TypeID_Register(
4948

5049
// User-data that will be forwarded to the FFI handlers. Deleter is optional,
5150
// and can be nullptr. Deleter will be called when the context is destroyed.
52-
struct PJRT_FFI_UserData {
51+
typedef struct PJRT_FFI_UserData {
5352
int64_t type_id;
5453
void* data;
5554
void (*deleter)(void* data);
56-
};
55+
} PJRT_FFI_UserData;
5756

5857
struct PJRT_FFI_UserData_Add_Args {
5958
size_t struct_size;

xla/pjrt/c/pjrt_c_api_layouts_extension.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ limitations under the License.
1616
#ifndef XLA_PJRT_C_PJRT_C_API_LAYOUTS_EXTENSION_H_
1717
#define XLA_PJRT_C_PJRT_C_API_LAYOUTS_EXTENSION_H_
1818

19-
#include <cstddef>
20-
#include <cstdint>
19+
#include <stddef.h>
20+
#include <stdint.h>
2121

2222
#include "xla/pjrt/c/pjrt_c_api.h"
2323

xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ limitations under the License.
1616
#ifndef XLA_PJRT_C_PJRT_C_API_MEMORY_DESCRIPTIONS_EXTENSION_H_
1717
#define XLA_PJRT_C_PJRT_C_API_MEMORY_DESCRIPTIONS_EXTENSION_H_
1818

19-
#include <cstddef>
19+
#include <stddef.h>
2020

2121
#include "xla/pjrt/c/pjrt_c_api.h"
2222

xla/pjrt/c/pjrt_c_api_profiler_extension.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ limitations under the License.
1616
#ifndef XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
1717
#define XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
1818

19-
#include <cstddef>
20-
#include <cstdint>
19+
#include <stddef.h>
20+
#include <stdint.h>
2121

2222
#include "xla/backends/profiler/plugin/profiler_c_api.h"
2323
#include "xla/pjrt/c/pjrt_c_api.h"

xla/pjrt/c/pjrt_c_api_stream_extension.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ limitations under the License.
1616
#define XLA_PJRT_C_PJRT_C_API_STREAM_EXTENSION_H_
1717

1818
#include <stddef.h>
19-
20-
#include <cstdint>
19+
#include <stdint.h>
2120

2221
#include "xla/pjrt/c/pjrt_c_api.h"
2322

xla/pjrt/c/pjrt_c_api_triton_extension.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ limitations under the License.
1616
#ifndef XLA_PJRT_C_PJRT_C_API_TRITON_EXTENSION_H_
1717
#define XLA_PJRT_C_PJRT_C_API_TRITON_EXTENSION_H_
1818

19-
#include <cstddef>
20-
#include <cstdint>
19+
#include <stddef.h>
20+
#include <stdint.h>
2121

2222
#include "xla/pjrt/c/pjrt_c_api.h"
2323

0 commit comments

Comments
 (0)