Skip to content

Commit

Permalink
make symbol generation optional (onnx#3599)
Browse files Browse the repository at this point in the history
* make symbol generation optional

Signed-off-by: Ashwini Khade <[email protected]>

* plus more updates

Signed-off-by: Ashwini Khade <[email protected]>
  • Loading branch information
askhade authored Jul 21, 2021
1 parent 41afdc0 commit a57bc99
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 31 deletions.
51 changes: 34 additions & 17 deletions onnx/shape_inference/implementation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ static void InferShapesImpl(
const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name,
const std::unordered_map<std::string, int>& opset_imports,
const ShapeInferenceOptions& options,
SymbolTable& symbolTable,
SymbolTable* symbolTable,
const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
const int ir_version = IR_VERSION // default the latest one
) {
Expand Down Expand Up @@ -356,7 +356,7 @@ static void InferShapesImpl(
auto domain_version = dit->second;
const auto schema = schema_registry->GetSchema(n.op_type(), domain_version, n.domain());
InferenceContextImpl ctx(
n, valueTypesByName, inputDataByName, inputSparseDataByName, generatedShapeDataByName, &graphInferenceContext);
n, valueTypesByName, inputDataByName, inputSparseDataByName, &generatedShapeDataByName, &graphInferenceContext);
if (!schema) {
std::cerr << "Warning: Unsupported operator " << n.op_type() << ". No schema registered for this operator."
<< std::endl;
Expand Down Expand Up @@ -385,7 +385,8 @@ static void InferShapesImpl(
function_opset_imports[opset_import.domain()] = static_cast<int>(opset_import.version());
}

InferShapeForFunctionNode(func_proto, function_opset_imports, schema_registry, ctx, symbolTable, generatedShapeDataByName, options);
InferShapeForFunctionNode(
func_proto, function_opset_imports, schema_registry, ctx, options, symbolTable, &generatedShapeDataByName);
}
ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) {
ONNX_HANDLE_EXCEPTION([&]() {
Expand Down Expand Up @@ -433,7 +434,11 @@ static void InferShapesImpl(
*iter->second = *inferredType;
}
}
materializeSymbolicShape(inferredType, symbolTable);

if (symbolTable) {
materializeSymbolicShape(inferredType, *symbolTable);
}

// Now we can merge pre-existing and inferred info
mergeShapesAndTypes(*inferredType, existingType);
if (options.enable_data_propagation && schema->has_data_propagation_function()) {
Expand Down Expand Up @@ -475,7 +480,7 @@ void InferShapes(
SymbolTableImpl symbolTable;
traverseGraphsToAddExistingSymbols(*g, symbolTable);
InferShapesImpl(
g, std::unordered_map<std::string, TypeProto*>(0), opset_imports, options, symbolTable, schema_registry);
g, std::unordered_map<std::string, TypeProto*>(0), opset_imports, options, &symbolTable, schema_registry);
}

void InferShapes(
Expand All @@ -494,7 +499,7 @@ void InferShapes(
std::unordered_map<std::string, TypeProto*>(0),
opset_imports,
options,
symbolTable,
&symbolTable,
schema_registry,
m.ir_version());
}
Expand Down Expand Up @@ -533,9 +538,13 @@ void InferShapeForFunctionNode(
const std::unordered_map<std::string, int>& func_opset_imports,
const ISchemaRegistry* schema_registry,
InferenceContext& ctx,
SymbolTable& symbolTable,
std::unordered_map<std::string, TensorShapeProto>& generatedShapeDataByName,
const ShapeInferenceOptions& options) {
const ShapeInferenceOptions& options,
SymbolTable* symbolTable,
std::unordered_map<std::string, TensorShapeProto>* generatedShapeDataByName) {
if (options.enable_data_propagation && generatedShapeDataByName == nullptr) {
fail_shape_inference("Container for generated shape data cannot be nullptr when enable_data_propagation option is set.");
}

GraphProto g;
// Get a temporary tensor-shape map
const auto num_func_inputs = func->input_size();
Expand Down Expand Up @@ -625,11 +634,14 @@ void InferShapeForFunctionNode(
vi->set_name(copy_n.output(i));
existingType = vi->mutable_type();
}
materializeSymbolicShape(inferred_output_type, symbolTable);

if (symbolTable) {
materializeSymbolicShape(inferred_output_type, *symbolTable);
}
mergeShapesAndTypes(*inferred_output_type, existingType);
if (options.enable_data_propagation && schema->has_data_propagation_function()) {
DataPropagationContextImpl temp_dataPropagationCtx(
copy_n, temp_valueTypesByName, temp_initializersByName, generatedShapeDataByName);
copy_n, temp_valueTypesByName, temp_initializersByName, *generatedShapeDataByName);
schema->GetDataPropagationFunction()(temp_dataPropagationCtx);
}
// Make merged info available to further inference.
Expand Down Expand Up @@ -658,20 +670,22 @@ void InferShapeForFunctionNode(
const FunctionProto* func,
const ISchemaRegistry* schema_registry,
InferenceContext& ctx,
SymbolTable& symbolTable,
std::unordered_map<std::string, TensorShapeProto>& generatedShapeDataByName,
const ShapeInferenceOptions& options) {
const ShapeInferenceOptions& options,
SymbolTable* symbolTable,
std::unordered_map<std::string, TensorShapeProto>* generatedShapeDataByName) {

std::unordered_map<std::string, int> opset_imports;
for (const auto& opset_import : func->opset_import()) {
opset_imports[opset_import.domain()] = static_cast<int>(opset_import.version());
}
InferShapeForFunctionNode(func, opset_imports, schema_registry, ctx, symbolTable, generatedShapeDataByName, options);

InferShapeForFunctionNode(func, opset_imports, schema_registry, ctx, options, symbolTable, generatedShapeDataByName);
}

std::vector<const TypeProto*> GraphInferencerImpl::doInferencing(
const std::vector<const TypeProto*>& inputTypes,
const std::vector<const TensorProto*>& inputData) {
SymbolTable& symbolTable = getSymbolTable();
SymbolTable* symbolTable = getSymbolTable();
int numInputs = int(inputTypes.size());

if (g_->input_size() != numInputs) {
Expand Down Expand Up @@ -703,7 +717,10 @@ std::vector<const TypeProto*> GraphInferencerImpl::doInferencing(
}
// Even if graphInput doesn't have defined type, it will assign inferredType to it
mergeShapesAndTypes(*inferredInput, graphInput);
materializeSymbolicShape(graphInput, symbolTable);

if (symbolTable) {
materializeSymbolicShape(graphInput, *symbolTable);
}
}

// future: pass inputData into InferShapes either directly, or indirectly by
Expand Down
35 changes: 22 additions & 13 deletions onnx/shape_inference/implementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ struct GraphInferenceContext {
const std::unordered_map<std::string, TypeProto*>&
outer_scope_value_types_by_name_in,
const std::unordered_map<std::string, int> opset_imports_in,
SymbolTable& symbolTable_in,
SymbolTable* symbolTable_in,
const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance())
: outer_scope_value_types_by_name{&outer_scope_value_types_by_name_in},
opset_imports{opset_imports_in},
Expand All @@ -89,7 +89,7 @@ struct GraphInferenceContext {
outer_scope_value_types_by_name;
const std::unordered_map<std::string, int> opset_imports;
const ISchemaRegistry* schema_registry;
SymbolTable& symbolTable;
SymbolTable* symbolTable;
};

class GraphInferencerImpl : public GraphInferencer {
Expand All @@ -101,7 +101,7 @@ class GraphInferencerImpl : public GraphInferencer {
const std::vector<const TypeProto*>& inputTypes,
const std::vector<const TensorProto*>& inputData) override;

SymbolTable& getSymbolTable() {
SymbolTable* getSymbolTable() {
return context_->symbolTable;
}

Expand All @@ -118,7 +118,7 @@ struct InferenceContextImpl : public InferenceContext {
inputDataByName,
const std::unordered_map<std::string, const SparseTensorProto*>&
inputSparseDataByName,
const std::unordered_map<std::string, TensorShapeProto>& generatedShapeData,
const std::unordered_map<std::string, TensorShapeProto>* generatedShapeData = nullptr,
GraphInferenceContext* graphInferenceContext = nullptr)
: graphInferenceContext_{graphInferenceContext} {
for (auto& attr : *n.mutable_attribute()) {
Expand All @@ -137,6 +137,10 @@ struct InferenceContextImpl : public InferenceContext {
allInputTypes_.push_back(nullptr);
}

// input data can be in one of the 3 containers:
// inputDataByName - this is when the input is a TensorProto
// inputSparseDataByName - this is when the input is a SparseTensorProto
// generatedShapeData - this is when the input was generated as part of partial data propagation
const auto inputDataIter = inputDataByName.find(input);
if (inputDataIter != inputDataByName.cend()) {
allInputData_.push_back(inputDataIter->second);
Expand All @@ -150,9 +154,13 @@ struct InferenceContextImpl : public InferenceContext {
allShapeInputData_.push_back(nullptr);
} else {
allInputSparseData_.push_back(nullptr);
const auto inputShapeDataIter = generatedShapeData.find(input);
if (inputShapeDataIter != generatedShapeData.cend()) {
allShapeInputData_.push_back(&inputShapeDataIter->second);
if (generatedShapeData != nullptr) {
const auto inputShapeDataIter = generatedShapeData->find(input);
if (inputShapeDataIter != generatedShapeData->cend()) {
allShapeInputData_.push_back(&inputShapeDataIter->second);
} else {
allShapeInputData_.push_back(nullptr);
}
} else {
allShapeInputData_.push_back(nullptr);
}
Expand Down Expand Up @@ -446,18 +454,19 @@ void InferShapeForFunctionNode(
const FunctionProto* func,
const ISchemaRegistry* schema_registry,
InferenceContext& ctx,
SymbolTable& symbolTable,
std::unordered_map<std::string, TensorShapeProto>& generatedShapeDataByName,
const ShapeInferenceOptions& options);
const ShapeInferenceOptions& options = {},
SymbolTable* symbolTable = nullptr,
std::unordered_map<std::string, TensorShapeProto>* generatedShapeDataByName = nullptr);

void InferShapeForFunctionNode(
const FunctionProto* func,
const std::unordered_map<std::string, int>& func_opset_imports,
const ISchemaRegistry* schema_registry,
InferenceContext& ctx,
SymbolTable& symbolTable,
std::unordered_map<std::string, TensorShapeProto>& generatedShapeDataByName,
const ShapeInferenceOptions& options);
const ShapeInferenceOptions& options = {},
SymbolTable* symbolTable = nullptr,
std::unordered_map<std::string, TensorShapeProto>* generatedShapeDataByName = nullptr);


std::string getErrorWithNodeInfo(NodeProto n, std::runtime_error err);

Expand Down
2 changes: 1 addition & 1 deletion onnx/test/cpp/shape_inference_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ static void doInferencingTest(bool use_scan_opset8) {
const std::unordered_map<std::string, TypeProto*> outer_scope_value_types;
SymbolTableImpl symbolTable;
symbolTable.addFromGraph(subgraph);
GraphInferenceContext graphInfCtx(outer_scope_value_types, opset_imports, symbolTable);
GraphInferenceContext graphInfCtx(outer_scope_value_types, opset_imports, &symbolTable);
GraphInferencerImpl graphInferencer(subgraph, graphInfCtx);

// loop_state_in and scan_in are the two inputs.
Expand Down

0 comments on commit a57bc99

Please sign in to comment.