Skip to content

Commit dab1777

Browse files
committed
Map Rust enums with inner fields
1 parent 7d56adb commit dab1777

File tree

59 files changed

+2073
-2
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+2073
-2
lines changed

genbindings.py

+118-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,32 @@ def print_name(self):
4848
trait_structs = set()
4949
unitary_enums = set()
5050

51+
def camel_to_snake(s):
52+
# Convert camel case to snake case, in a way that appears to match cbindgen
53+
con = "_"
54+
ret = ""
55+
lastchar = ""
56+
lastund = False
57+
for char in s:
58+
if lastchar.isupper():
59+
if not char.isupper() and not lastund:
60+
ret = ret + "_"
61+
lastund = True
62+
else:
63+
lastund = False
64+
ret = ret + lastchar.lower()
65+
else:
66+
ret = ret + lastchar
67+
if char.isupper() and not lastund:
68+
ret = ret + "_"
69+
lastund = True
70+
else:
71+
lastund = False
72+
lastchar = char
73+
if char.isnumeric():
74+
lastund = True
75+
return (ret + lastchar.lower()).strip("_")
76+
5177
var_is_arr_regex = re.compile("\(\*([A-za-z_]*)\)\[([0-9]*)\]")
5278
var_ty_regex = re.compile("([A-za-z_0-9]*)(.*)")
5379
def java_c_types(fn_arg, ret_arr_len):
@@ -548,6 +574,7 @@ def map_trait(struct_name, field_var_lines, trait_fn_lines):
548574
public static native boolean deref_bool(long ptr);
549575
public static native long deref_long(long ptr);
550576
public static native void free_heap_ptr(long ptr);
577+
public static native byte[] read_bytes(long ptr, long len);
551578
public static native byte[] get_u8_slice_bytes(long slice_ptr);
552579
public static native long bytes_to_u8_vec(byte[] bytes);
553580
public static native long vec_slice_len(long vec);
@@ -576,6 +603,11 @@ def map_trait(struct_name, field_var_lines, trait_fn_lines):
576603
JNIEXPORT void JNICALL Java_org_ldk_impl_bindings_free_1heap_1ptr (JNIEnv * env, jclass _a, jlong ptr) {
577604
FREE((void*)ptr);
578605
}
606+
JNIEXPORT jbyteArray JNICALL Java_org_ldk_impl_bindings_read_1bytes (JNIEnv * _env, jclass _b, jlong ptr, jlong len) {
607+
jbyteArray ret_arr = (*_env)->NewByteArray(_env, len);
608+
(*_env)->SetByteArrayRegion(_env, ret_arr, 0, len, (unsigned char*)ptr);
609+
return ret_arr;
610+
}
579611
JNIEXPORT jbyteArray JNICALL Java_org_ldk_impl_bindings_get_1u8_1slice_1bytes (JNIEnv * _env, jclass _b, jlong slice_ptr) {
580612
LDKu8slice *slice = (LDKu8slice*)slice_ptr;
581613
jbyteArray ret_arr = (*_env)->NewByteArray(_env, slice->datalen);
@@ -648,6 +680,7 @@ def map_trait(struct_name, field_var_lines, trait_fn_lines):
648680
assert(struct_alias_regex.match("typedef LDKCResultTempl_bool__PeerHandleError LDKCResult_boolPeerHandleErrorZ;"))
649681

650682
result_templ_structs = set()
683+
union_enum_items = {}
651684
for line in in_h:
652685
if in_block_comment:
653686
#out_java.write("\t" + line)
@@ -708,6 +741,7 @@ def map_trait(struct_name, field_var_lines, trait_fn_lines):
708741
assert(not is_union or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or is_result or vec_ty is not None))
709742
assert(not is_result or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or is_union or vec_ty is not None))
710743
assert(vec_ty is None or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or is_union or is_result))
744+
711745
if is_opaque:
712746
opaque_structs.add(struct_name)
713747
out_java.write("\tpublic static native long " + struct_name + "_optional_none();\n")
@@ -724,6 +758,89 @@ def map_trait(struct_name, field_var_lines, trait_fn_lines):
724758
out_c.write("\t" + struct_name + " *vec = (" + struct_name + "*)ptr;\n")
725759
out_c.write("\treturn (*env)->NewObject(env, slicedef_cls, slicedef_meth, (long)vec->data, (long)vec->datalen, sizeof(" + vec_ty + "));\n")
726760
out_c.write("}\n")
761+
elif is_union_enum:
762+
assert(struct_name.endswith("_Tag"))
763+
struct_name = struct_name[:-4]
764+
union_enum_items[struct_name] = {"field_lines": field_lines}
765+
elif struct_name.endswith("_Body") and struct_name.split("_")[0] in union_enum_items:
766+
enum_var_name = struct_name.split("_")
767+
union_enum_items[enum_var_name[0]][enum_var_name[1]] = field_lines
768+
elif struct_name in union_enum_items:
769+
tag_field_lines = union_enum_items[struct_name]["field_lines"]
770+
init_meth_jty_strs = {}
771+
for idx, struct_line in enumerate(tag_field_lines):
772+
if idx == 0:
773+
out_java.write("\tpublic static class " + struct_name + " {\n")
774+
out_java.write("\t\tprivate " + struct_name + "() {}\n")
775+
elif idx == len(tag_field_lines) - 3:
776+
assert(struct_line.endswith("_Sentinel,"))
777+
elif idx == len(tag_field_lines) - 2:
778+
out_java.write("\t\tstatic native void init();\n")
779+
out_java.write("\t}\n")
780+
elif idx == len(tag_field_lines) - 1:
781+
assert(struct_line == "")
782+
else:
783+
var_name = struct_line.strip(' ,')[len(struct_name) + 1:]
784+
out_java.write("\t\tpublic final static class " + var_name + " extends " + struct_name + " {\n")
785+
out_c.write("jclass " + struct_name + "_" + var_name + "_class = NULL;\n")
786+
out_c.write("jmethodID " + struct_name + "_" + var_name + "_meth = NULL;\n")
787+
init_meth_jty_str = ""
788+
init_meth_params = ""
789+
init_meth_body = ""
790+
if "LDK" + var_name in union_enum_items[struct_name]:
791+
enum_var_lines = union_enum_items[struct_name]["LDK" + var_name]
792+
for idx, field in enumerate(enum_var_lines):
793+
if idx != 0 and idx < len(enum_var_lines) - 2:
794+
field_ty = java_c_types(field.strip(' ;'), None)
795+
out_java.write("\t\t\tpublic " + field_ty.java_ty + " " + field_ty.var_name + ";\n")
796+
init_meth_jty_str = init_meth_jty_str + field_ty.java_fn_ty_arg
797+
if idx > 1:
798+
init_meth_params = init_meth_params + ", "
799+
init_meth_params = init_meth_params + field_ty.java_ty + " " + field_ty.var_name
800+
init_meth_body = init_meth_body + "this." + field_ty.var_name + " = " + field_ty.var_name + "; "
801+
out_java.write("\t\t\t" + var_name + "(" + init_meth_params + ") { ")
802+
out_java.write(init_meth_body)
803+
out_java.write("}\n")
804+
out_java.write("\t\t}\n")
805+
init_meth_jty_strs[var_name] = init_meth_jty_str
806+
out_java.write("\tstatic { " + struct_name + ".init(); }\n")
807+
out_java.write("\tpublic static native " + struct_name + " " + struct_name + "_ref_from_ptr(long ptr);\n");
808+
809+
out_c.write("JNIEXPORT void JNICALL Java_org_ldk_impl_bindings_00024" + struct_name.replace("_", "_1") + "_init (JNIEnv * env, jclass _a) {\n")
810+
for idx, struct_line in enumerate(tag_field_lines):
811+
if idx != 0 and idx < len(tag_field_lines) - 3:
812+
var_name = struct_line.strip(' ,')[len(struct_name) + 1:]
813+
out_c.write("\t" + struct_name + "_" + var_name + "_class =\n")
814+
out_c.write("\t\t(*env)->NewGlobalRef(env, (*env)->FindClass(env, \"Lorg/ldk/impl/bindings$" + struct_name + "$" + var_name + ";\"));\n")
815+
out_c.write("\tDO_ASSERT(" + struct_name + "_" + var_name + "_class != NULL);\n")
816+
out_c.write("\t" + struct_name + "_" + var_name + "_meth = (*env)->GetMethodID(env, " + struct_name + "_" + var_name + "_class, \"<init>\", \"(" + init_meth_jty_strs[var_name] + ")V\");\n")
817+
out_c.write("\tDO_ASSERT(" + struct_name + "_" + var_name + "_meth != NULL);\n")
818+
out_c.write("}\n")
819+
out_c.write("JNIEXPORT jobject JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1ref_1from_1ptr (JNIEnv * env, jclass _c, jlong ptr) {\n")
820+
out_c.write("\t" + struct_name + " *obj = (" + struct_name + "*)ptr;\n")
821+
out_c.write("\tswitch(obj->tag) {\n")
822+
for idx, struct_line in enumerate(tag_field_lines):
823+
if idx != 0 and idx < len(tag_field_lines) - 3:
824+
var_name = struct_line.strip(' ,')[len(struct_name) + 1:]
825+
out_c.write("\t\tcase " + struct_name + "_" + var_name + ":\n")
826+
out_c.write("\t\t\treturn (*env)->NewObject(env, " + struct_name + "_" + var_name + "_class, " + struct_name + "_" + var_name + "_meth")
827+
if "LDK" + var_name in union_enum_items[struct_name]:
828+
enum_var_lines = union_enum_items[struct_name]["LDK" + var_name]
829+
out_c.write(",\n\t\t\t\t")
830+
for idx, field in enumerate(enum_var_lines):
831+
if idx != 0 and idx < len(enum_var_lines) - 2:
832+
field_ty = java_c_types(field.strip(' ;'), None)
833+
if idx >= 2:
834+
out_c.write(", ")
835+
if field_ty.is_ptr:
836+
out_c.write("(long)")
837+
elif field_ty.passed_as_ptr or field_ty.arr_len is not None:
838+
out_c.write("(long)&")
839+
out_c.write("obj->" + camel_to_snake(var_name) + "." + field_ty.var_name)
840+
out_c.write("\n\t\t\t")
841+
out_c.write(");\n")
842+
out_c.write("\t\tdefault: abort();\n")
843+
out_c.write("\t}\n}\n")
727844
elif is_unitary_enum:
728845
with open(sys.argv[3] + "/" + struct_name + ".java", "w") as out_java_enum:
729846
out_java_enum.write("package org.ldk.enums;\n\n")
@@ -771,7 +888,7 @@ def map_trait(struct_name, field_var_lines, trait_fn_lines):
771888
for idx, struct_line in enumerate(field_lines):
772889
if idx > 0 and idx < len(field_lines) - 3:
773890
variant = struct_line.strip().strip(",")
774-
out_c.write("\t\tcase " + variant + ": \n")
891+
out_c.write("\t\tcase " + variant + ":\n")
775892
out_c.write("\t\t\treturn (*env)->GetStaticObjectField(env, " + struct_name + "_class, " + struct_name + "_" + variant + ");\n")
776893
ord_v = ord_v + 1
777894
out_c.write("\t\tdefault: abort();\n")

0 commit comments

Comments
 (0)