// python_saved_variable_hooks.h
// Source: fork of pytorch/pytorch (36 lines, 29 loc, 1.05 KB).
#pragma once
#include <ATen/ATen.h>
#include <c10/core/SafePyObject.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
namespace py = pybind11;
namespace torch::autograd {
/// SavedVariableHooks implementation whose pack/unpack hooks are Python
/// callables.
///
/// The object stores raw `PyObject*` handles for the two hooks plus a `data_`
/// slot that starts out null. The member functions are defined out of line, so
/// the exact reference-counting and GIL protocol is not visible from this
/// header.
/// NOTE(review): presumably the constructor takes references to the hooks and
/// the destructor releases them along with `data_` -- confirm against the .cpp.
struct PySavedVariableHooks : public SavedVariableHooks {
  /// Builds the hooks object from a Python pack callable and a Python unpack
  /// callable.
  PySavedVariableHooks(py::function& pack_hook, py::function& unpack_hook);

  // Non-copyable and non-movable. The class declares its own destructor and
  // holds raw PyObject* handles; an implicitly generated memberwise copy would
  // leave two objects pointing at the same handles, which is unsafe if the
  // destructor releases them. Instances are meant to be handled through
  // pointers to the SavedVariableHooks base.
  PySavedVariableHooks(const PySavedVariableHooks&) = delete;
  PySavedVariableHooks& operator=(const PySavedVariableHooks&) = delete;
  PySavedVariableHooks(PySavedVariableHooks&&) = delete;
  PySavedVariableHooks& operator=(PySavedVariableHooks&&) = delete;

  /// Overrides SavedVariableHooks::call_pack_hook for `tensor`.
  /// NOTE(review): presumably invokes the Python pack hook and keeps its
  /// result in `data_` -- confirm against the definition.
  void call_pack_hook(const at::Tensor& tensor) override;

  /// Overrides SavedVariableHooks::call_unpack_hook and returns a tensor.
  /// NOTE(review): presumably invokes the Python unpack hook on `data_` --
  /// confirm against the definition.
  at::Tensor call_unpack_hook() override;

  ~PySavedVariableHooks() override;

  /// Overrides SavedVariableHooks::retrieve_unpack_hook_data. Returns an
  /// optional pair of SafePyObject values.
  /// NOTE(review): judging by the name, the pair carries the unpack hook and
  /// its associated data -- confirm the order against the definition.
  std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
  retrieve_unpack_hook_data() const override;

 private:
  PyObject* pack_hook_; // Python pack callable (raw handle).
  PyObject* unpack_hook_; // Python unpack callable (raw handle).
  PyObject* data_ = nullptr; // Null until set by the implementation.
};
// Static-only entry points for managing default saved-variable hooks that are
// supplied as Python callables. The struct has no data members; all three
// functions are defined out of line, so their exact behavior is not visible
// from this header.
struct PyDefaultSavedVariableHooks {
// Takes a Python pack callable and a Python unpack callable.
// NOTE(review): presumably pushes the pair onto a stack of default hooks --
// confirm against the definition.
static void push_hooks(py::function& pack_hook, py::function& unpack_hook);
// NOTE(review): presumably undoes the most recent push_hooks() call --
// confirm against the definition.
static void pop_hooks();
// Returns an owning pointer to a SavedVariableHooks object.
// NOTE(review): presumably built from the currently active default hook pair,
// and possibly null when none is active -- confirm against the definition.
static std::unique_ptr<SavedVariableHooks> get_hooks();
};
} // namespace torch::autograd