Use setPropertyHostFunction for tensor::topk (9/n)
Summary: Migrate tensor::topk to the new style with `setPropertyHostFunction` for better code uniformity and efficiency.

Reviewed By: chrisklaiber

Differential Revision: D36997428

fbshipit-source-id: ba4f920a60f39d3e3dd13dc7c78e9e73a079114d
ta211 authored and facebook-github-bot committed Jun 10, 2022
1 parent 4bd05a9 commit 2d4cb75
Showing 2 changed files with 25 additions and 47 deletions.
File 1 of 2:

@@ -20,7 +20,6 @@ namespace torch {
 
 // TensorHostObject Method Names
 static const std::string SIZE = "size";
-static const std::string TOPK = "topk";
 static const std::string TOSTRING = "toString";
 
 // TensorHostObject Property Names
@@ -32,7 +31,7 @@ static const std::string SHAPE = "shape";
 static const std::vector<std::string> PROPERTIES = {DATA, DTYPE, SHAPE};
 
 // TensorHostObject Methods
-static const std::vector<std::string> METHODS = {SIZE, TOPK, TOSTRING};
+static const std::vector<std::string> METHODS = {SIZE, TOSTRING};
 
 using namespace facebook;
 
@@ -449,6 +448,29 @@ jsi::Value toImpl(
   return jsi::Object::createFromHostObject(runtime, tensorHostObject);
 };
 
+jsi::Value topkImpl(
+    jsi::Runtime& runtime,
+    const jsi::Value& thisValue,
+    const jsi::Value* arguments,
+    size_t count) {
+  utils::ArgumentParser args(runtime, thisValue, arguments, count);
+  args.requireNumArguments(1);
+
+  auto k = args[0].asNumber();
+  auto resultTuple = args.thisAsHostObject<TensorHostObject>()->tensor.topk(k);
+  auto values = utils::helpers::createFromHostObject<TensorHostObject>(
+      runtime, std::get<0>(resultTuple));
+  /**
+   * NOTE: We need to convert the int64 type to int32 since Hermes does not
+   * support Int64 data types yet.
+   */
+  auto indicesInt64Tensor = std::get<1>(resultTuple);
+  auto indices = utils::helpers::createFromHostObject<TensorHostObject>(
+      runtime, indicesInt64Tensor.to(c10::ScalarType::Int));
+
+  return jsi::Array::createWithElements(runtime, values, indices);
+};
+
 jsi::Value unsqueezeImpl(
     jsi::Runtime& runtime,
     const jsi::Value& thisValue,
@@ -470,7 +492,6 @@ jsi::Value unsqueezeImpl(
 TensorHostObject::TensorHostObject(jsi::Runtime& runtime, torch_::Tensor t)
     : BaseHostObject(runtime),
       size_(createSize(runtime)),
-      topk_(createTopK(runtime)),
       toString_(createToString(runtime)),
       tensor(t) {
   setPropertyHostFunction(runtime, "abs", 0, absImpl);
@@ -489,6 +510,7 @@ TensorHostObject::TensorHostObject(jsi::Runtime& runtime, torch_::Tensor t)
   setPropertyHostFunction(runtime, "stride", 0, strideImpl);
   setPropertyHostFunction(runtime, "sub", 1, subImpl);
   setPropertyHostFunction(runtime, "to", 1, toImpl);
+  setPropertyHostFunction(runtime, "topk", 1, topkImpl);
   setPropertyHostFunction(runtime, "unsqueeze", 1, unsqueezeImpl);
 }
 
@@ -520,8 +542,6 @@ jsi::Value TensorHostObject::get(
     return this->size_.call(runtime);
   } else if (name == SIZE) {
     return jsi::Value(runtime, size_);
-  } else if (name == TOPK) {
-    return jsi::Value(runtime, topk_);
   } else if (name == TOSTRING) {
     return jsi::Value(runtime, toString_);
   }
@@ -583,45 +603,5 @@ jsi::Function TensorHostObject::createSize(jsi::Runtime& runtime) {
       runtime, jsi::PropNameID::forUtf8(runtime, SIZE), 0, sizeFunc);
 }
 
-/**
- * Returns the k largest elements of the given input tensor along a given
- * dimension.
- *
- * https://pytorch.org/docs/stable/generated/torch.topk.html
- */
-jsi::Function TensorHostObject::createTopK(jsi::Runtime& runtime) {
-  auto topkFunc = [this](
-                      jsi::Runtime& runtime,
-                      const jsi::Value& thisValue,
-                      const jsi::Value* arguments,
-                      size_t count) {
-    if (count < 1) {
-      throw jsi::JSError(runtime, "This function requires at least 1 argument");
-    }
-    auto k = arguments[0].asNumber();
-    auto resultTuple = this->tensor.topk(k);
-    auto outputValuesTensorHostObject =
-        std::make_shared<torchlive::torch::TensorHostObject>(
-            runtime, std::get<0>(resultTuple));
-    auto indicesInt64Tensor = std::get<1>(resultTuple);
-    /**
-     * NOTE: We need to convert the int64 type to int32 since Hermes does not
-     * support Int64 data types yet.
-     */
-    auto outputIndicesTensorHostObject =
-        std::make_shared<torchlive::torch::TensorHostObject>(
-            runtime, indicesInt64Tensor.to(c10::ScalarType::Int));
-    return jsi::Array::createWithElements(
-        runtime,
-        jsi::Object::createFromHostObject(
-            runtime, outputValuesTensorHostObject),
-        jsi::Object::createFromHostObject(
-            runtime, outputIndicesTensorHostObject));
-  };
-
-  return jsi::Function::createFromHostFunction(
-      runtime, jsi::PropNameID::forUtf8(runtime, TOPK), 1, topkFunc);
-}
-
 } // namespace torch
 } // namespace torchlive
File 2 of 2:

@@ -24,7 +24,6 @@ namespace torch {
 
 class JSI_EXPORT TensorHostObject : public common::BaseHostObject {
   facebook::jsi::Function size_;
-  facebook::jsi::Function topk_;
   facebook::jsi::Function toString_;
 
  public:
@@ -42,7 +41,6 @@ class JSI_EXPORT TensorHostObject : public common::BaseHostObject {
  private:
  facebook::jsi::Function createSize(facebook::jsi::Runtime& runtime);
  facebook::jsi::Function createToString(facebook::jsi::Runtime& runtime);
- facebook::jsi::Function createTopK(facebook::jsi::Runtime& runtime);
 };
 
 } // namespace torch
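The JavaScript-facing behavior of `topk` is unchanged by this refactor: it still takes `k` and returns a two-element array `[values, indices]`, with the indices tensor downcast to int32 for Hermes, as noted in `topkImpl`. Below is a minimal usage sketch in TypeScript; the import path, the `torch.tensor` call, and the sample data are illustrative assumptions and are not part of this commit.

// Hypothetical usage sketch, not part of this commit.
// Assumes the torchlive JS package exposes `torch` as in react-native-pytorch-core.
import {torch} from 'react-native-pytorch-core';

const scores = torch.tensor([0.1, 0.9, 0.4, 0.7]);

// topk(k) returns [values, indices]; per the note in topkImpl, the indices
// tensor is converted to int32 because Hermes does not support int64 yet.
const [values, indices] = scores.topk(2);

console.log(values.toString()); // the 2 largest values
console.log(indices.toString()); // their positions, as an int32 tensor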
