Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Make third-party/ExecuTorchLib's forward() accept multiple inputs #83

Merged
merged 25 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e407016
wip
chmjkb Jan 16, 2025
1f4b875
chore: remove unused function
chmjkb Jan 16, 2025
b5970ea
refactor: remove unintended for loop and data structures
chmjkb Jan 16, 2025
f24d1bb
fix: add error handling & lint
chmjkb Jan 16, 2025
bfeea95
fix: (native) update ETModule's forward to accept multiple inputs
chmjkb Jan 17, 2025
19f4fb5
refactor: get rid of InputType
chmjkb Jan 20, 2025
60828c9
replace single input forward with multiple inputs
chmjkb Jan 22, 2025
8408656
lint
chmjkb Jan 22, 2025
4d904ca
fix: make use of existing functions, return actual output
chmjkb Jan 22, 2025
4ab655f
fix: update rnexecutorchmodules.ts
chmjkb Jan 22, 2025
51c3e71
fix: update BaseModel to match new native implementation, remove Inpu…
chmjkb Jan 22, 2025
ec4d322
fix: int8_T -> char
chmjkb Jan 22, 2025
27af316
feat: make Android accept multiple inputs
chmjkb Jan 23, 2025
b8d71ae
refactor: remove unused function
chmjkb Jan 23, 2025
4daa790
chore: remove useless continue
chmjkb Jan 24, 2025
2719ca4
feat: make use of the new interface in BaseModel
chmjkb Jan 27, 2025
c55bbe8
chore: remove log import
chmjkb Jan 28, 2025
4acdd1b
fix: unsqueeze shapes if a single number is passed
chmjkb Jan 28, 2025
6974d38
chore: add a comment for a hack in forward()
chmjkb Jan 28, 2025
5934228
fix: fix types
chmjkb Jan 28, 2025
55f4b1c
chore: add comment for nsarraytovoidptr
chmjkb Jan 28, 2025
b73c8dc
fix: unsqueeze input before forwarding it to cpp
chmjkb Feb 21, 2025
5e748b3
feat: add multiple inputs to hookless api
chmjkb Feb 24, 2025
3217786
fix: minor fix
chmjkb Feb 24, 2025
2729838
fix: add --noEmit flag to lefthook
chmjkb Feb 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@ import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.utils.ArrayUtils
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.TensorUtils
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Module
import java.net.URL

class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(reactContext) {
private lateinit var module: Module

private var reactApplicationContext = reactContext;
override fun getName(): String {
return NAME
}
Expand All @@ -33,26 +34,40 @@ class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(react
}

override fun forward(
input: ReadableArray,
shape: ReadableArray,
inputType: Double,
inputs: ReadableArray,
shapes: ReadableArray,
inputTypes: ReadableArray,
promise: Promise
) {
val inputEValues = ArrayList<EValue>()
try {
val executorchInput =
TensorUtils.getExecutorchInput(input, ArrayUtils.createLongArray(shape), inputType.toInt())
for (i in 0 until inputs.size()) {
val currentInput = inputs.getArray(i)
?: throw Exception(ETError.InvalidArgument.code.toString())
val currentShape = shapes.getArray(i)
?: throw Exception(ETError.InvalidArgument.code.toString())
val currentInputType = inputTypes.getInt(i)

val result = module.forward(executorchInput)
val resultArray = Arguments.createArray()
val currentEValue = TensorUtils.getExecutorchInput(
currentInput,
ArrayUtils.createLongArray(currentShape),
currentInputType
)

for (evalue in result) {
resultArray.pushArray(ArrayUtils.createReadableArray(evalue.toTensor()))
inputEValues.add(currentEValue)
}

promise.resolve(resultArray)
return
val forwardOutputs = module.forward(*inputEValues.toTypedArray());
val outputArray = Arguments.createArray()

for (output in forwardOutputs) {
val arr = ArrayUtils.createReadableArrayFromTensor(output.toTensor())
outputArray.pushArray(arr)
}
promise.resolve(outputArray)

} catch (e: IllegalArgumentException) {
//The error is thrown when transformation to Tensor fails
// The error is thrown when transformation to Tensor fails
promise.reject("Forward Failed Execution", ETError.InvalidArgument.code.toString())
return
} catch (e: Exception) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,82 +7,52 @@ import org.pytorch.executorch.Tensor

class ArrayUtils {
companion object {
fun createByteArray(input: ReadableArray): ByteArray {
val byteArray = ByteArray(input.size())
for (i in 0 until input.size()) {
byteArray[i] = input.getInt(i).toByte()
}
return byteArray
private inline fun <reified T> createTypedArrayFromReadableArray(input: ReadableArray, transform: (ReadableArray, Int) -> T): Array<T> {
return Array(input.size()) { index -> transform(input, index) }
}

fun createByteArray(input: ReadableArray): ByteArray {
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toByte() }.toByteArray()
}
fun createIntArray(input: ReadableArray): IntArray {
val intArray = IntArray(input.size())
for (i in 0 until input.size()) {
intArray[i] = input.getInt(i)
}
return intArray
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index) }.toIntArray()
}

fun createFloatArray(input: ReadableArray): FloatArray {
val floatArray = FloatArray(input.size())
for (i in 0 until input.size()) {
floatArray[i] = input.getDouble(i).toFloat()
}
return floatArray
return createTypedArrayFromReadableArray(input) { array, index -> array.getDouble(index).toFloat() }.toFloatArray()
}

fun createLongArray(input: ReadableArray): LongArray {
val longArray = LongArray(input.size())
for (i in 0 until input.size()) {
longArray[i] = input.getInt(i).toLong()
}
return longArray
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toLong() }.toLongArray()
}

fun createDoubleArray(input: ReadableArray): DoubleArray {
val doubleArray = DoubleArray(input.size())
for (i in 0 until input.size()) {
doubleArray[i] = input.getDouble(i)
}
return doubleArray
return createTypedArrayFromReadableArray(input) { array, index -> array.getDouble(index) }.toDoubleArray()
}

fun createReadableArray(result: Tensor): ReadableArray {
fun createReadableArrayFromTensor(result: Tensor): ReadableArray {
val resultArray = Arguments.createArray()

when (result.dtype()) {
DType.UINT8 -> {
val byteArray = result.dataAsByteArray
for (i in byteArray) {
resultArray.pushInt(i.toInt())
}
result.dataAsByteArray.forEach { resultArray.pushInt(it.toInt()) }
}

DType.INT32 -> {
val intArray = result.dataAsIntArray
for (i in intArray) {
resultArray.pushInt(i)
}
result.dataAsIntArray.forEach { resultArray.pushInt(it) }
}

DType.FLOAT -> {
val longArray = result.dataAsFloatArray
for (i in longArray) {
resultArray.pushDouble(i.toDouble())
}
result.dataAsFloatArray.forEach { resultArray.pushDouble(it.toDouble()) }
}

DType.DOUBLE -> {
val floatArray = result.dataAsDoubleArray
for (i in floatArray) {
resultArray.pushDouble(i)
}
result.dataAsDoubleArray.forEach { resultArray.pushDouble(it) }
}

DType.INT64 -> {
val doubleArray = result.dataAsLongArray
for (i in doubleArray) {
resultArray.pushLong(i)
}
// TODO: Do something to handle or deprecate long dtype
// https://github.com/facebook/react-native/issues/12506
result.dataAsLongArray.forEach { resultArray.pushInt(it.toInt()) }
}

else -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,23 @@ class TensorUtils {
fun getExecutorchInput(input: ReadableArray, shape: LongArray, type: Int): EValue {
try {
when (type) {
0 -> {
1 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createByteArray(input), shape)
return EValue.from(inputTensor)
}

1 -> {
3 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createIntArray(input), shape)
return EValue.from(inputTensor)
}

2 -> {
4 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createLongArray(input), shape)
return EValue.from(inputTensor)
}

3 -> {
6 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createFloatArray(input), shape)
return EValue.from(inputTensor)
}

4 -> {
7 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createDoubleArray(input), shape)
return EValue.from(inputTensor)
}
Expand Down
22 changes: 13 additions & 9 deletions ios/RnExecutorch/ETModule.mm
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#import "ETModule.h"
#import <ExecutorchLib/ETModel.h>
#include <Foundation/Foundation.h>
#import <React/RCTBridgeModule.h>
#include <string>

Expand Down Expand Up @@ -36,20 +37,23 @@ - (void)loadModule:(NSString *)modelSource
resolve(result);
}

- (void)forward:(NSArray *)input
shape:(NSArray *)shape
inputType:(double)inputType
- (void)forward:(NSArray *)inputs
shapes:(NSArray *)shapes
inputTypes:(NSArray *)inputTypes
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {
@try {
NSArray *result = [module forward:input
shape:shape
inputType:[NSNumber numberWithInt:inputType]];
NSArray *result = [module forward:inputs
shapes:shapes
inputTypes:inputTypes];
resolve(result);
} @catch (NSException *exception) {
NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
reject(@"result_error", [NSString stringWithFormat:@"%@", exception.reason],
nil);
NSLog(@"An exception occurred in forward: %@, %@", exception.name,
exception.reason);
reject(
@"forward_error",
[NSString stringWithFormat:@"An error occurred: %@", exception.reason],
nil);
}
}

Expand Down
5 changes: 5 additions & 0 deletions ios/RnExecutorch/models/BaseModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
}

- (NSArray *)forward:(NSArray *)input;

- (NSArray *)forward:(NSArray *)inputs
shapes:(NSArray *)shapes
inputTypes:(NSArray *)inputTypes;

- (void)loadModel:(NSURL *)modelURL
completion:(void (^)(BOOL success, NSNumber *code))completion;

Expand Down
21 changes: 17 additions & 4 deletions ios/RnExecutorch/models/BaseModel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,29 @@
@implementation BaseModel

- (NSArray *)forward:(NSArray *)input {
NSArray *result = [module forward:input
shape:[module getInputShape:@0]
inputType:[module getInputType:@0]];
NSMutableArray *shapes = [NSMutableArray new];
NSMutableArray *inputTypes = [NSMutableArray new];
NSNumber *numberOfInputs = [module getNumberOfInputs];

for (NSUInteger i = 0; i < [numberOfInputs intValue]; i++) {
[shapes addObject:[module getInputShape:[NSNumber numberWithInt:i]]];
[inputTypes addObject:[module getInputType:[NSNumber numberWithInt:i]]];
}

NSArray *result = [module forward:@[input] shapes:shapes inputTypes:inputTypes];
return result;
}

- (NSArray *)forward:(NSArray *)inputs
shapes:(NSArray *)shapes
inputTypes:(NSArray *)inputTypes {
NSArray *result = [module forward:inputs shapes:shapes inputTypes:inputTypes];
return result;
}

- (void)loadModel:(NSURL *)modelURL
completion:(void (^)(BOOL success, NSNumber *code))completion {
module = [[ETModel alloc] init];

NSNumber *result = [self->module loadModel:modelURL.path];
if ([result intValue] != 0) {
completion(NO, result);
Expand Down
2 changes: 1 addition & 1 deletion lefthook.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ pre-commit:
run: npx eslint {staged_files}
types:
glob: '*.{js,ts, jsx, tsx}'
run: npx tsc
run: npx tsc --noEmit
5 changes: 4 additions & 1 deletion src/hooks/general/useExecutorchModule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ export const useExecutorchModule = ({
isReady: boolean;
isGenerating: boolean;
downloadProgress: number;
forward: (input: ETInput, shape: number[]) => Promise<number[][]>;
forward: (
input: ETInput | ETInput[],
shape: number[] | number[][]
) => Promise<number[][]>;
loadMethod: (methodName: string) => Promise<void>;
loadForward: () => Promise<void>;
} => {
Expand Down
53 changes: 45 additions & 8 deletions src/hooks/useModule.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import { useEffect, useState } from 'react';
import { fetchResource } from '../utils/fetchResource';
import { ETError, getError } from '../Error';
import { ETInput, Module, getTypeIdentifier } from '../types/common';
import { ETInput, Module } from '../types/common';
import { _ETModule } from '../native/RnExecutorchModules';

export const getTypeIdentifier = (input: ETInput): number => {
if (input instanceof Int8Array) return 1;
if (input instanceof Int32Array) return 3;
if (input instanceof BigInt64Array) return 4;
if (input instanceof Float32Array) return 6;
if (input instanceof Float64Array) return 7;
return -1;
};

interface Props {
modelSource: string | number;
Expand All @@ -13,7 +23,10 @@ interface _Module {
isReady: boolean;
isGenerating: boolean;
downloadProgress: number;
forwardETInput: (input: ETInput, shape: number[]) => Promise<any>;
forwardETInput: (
input: ETInput[] | ETInput,
shape: number[][] | number[]
) => ReturnType<_ETModule['forward']>;
forwardImage: (input: string) => Promise<any>;
}

Expand Down Expand Up @@ -59,23 +72,47 @@ export const useModule = ({ modelSource, module }: Props): _Module => {
}
};

const forwardETInput = async (input: ETInput, shape: number[]) => {
const forwardETInput = async (
input: ETInput[] | ETInput,
shape: number[][] | number[]
) => {
if (!isReady) {
throw new Error(getError(ETError.ModuleNotLoaded));
}
if (isGenerating) {
throw new Error(getError(ETError.ModelGenerating));
}

const inputType = getTypeIdentifier(input);
if (inputType === -1) {
throw new Error(getError(ETError.InvalidArgument));
// Since the native module expects an array of inputs and an array of shapes,
// if the user provides a single ETInput, we want to "unsqueeze" the array so
// the data is properly processed on the native side
if (!Array.isArray(input)) {
input = [input];
}

if (!Array.isArray(shape[0])) {
shape = [shape] as number[][];
}

let inputTypeIdentifiers: any[] = [];
let modelInputs: any[] = [];

for (let idx = 0; idx < input.length; idx++) {
let currentInputTypeIdentifier = getTypeIdentifier(input[idx] as ETInput);
if (currentInputTypeIdentifier === -1) {
throw new Error(getError(ETError.InvalidArgument));
}
inputTypeIdentifiers.push(currentInputTypeIdentifier);
modelInputs.push([...(input[idx] as ETInput)]);
}

try {
const numberArray = [...input];
setIsGenerating(true);
const output = await module.forward(numberArray, shape, inputType);
const output = await module.forward(
modelInputs,
shape,
inputTypeIdentifiers
);
setIsGenerating(false);
return output;
} catch (e) {
Expand Down
Loading