Skip to content

Commit 665f4c2

Browse files
committed
fix: Refactor Stateful Code
1 parent 1a34426 commit 665f4c2

File tree

4 files changed

+387
-323
lines changed

4 files changed

+387
-323
lines changed

onnxruntime/core/providers/openvino/backends/basic_backend.cc

-2
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
3131
std::string& hw_target = session_context_.device_type;
3232
auto enable_causallm = session_context_.enable_causallm;
3333

34-
std::cout << "CausalLM enabled: " << enable_causallm << std::endl;
35-
3634
if (ValidateSubgraph(const_outputs_map_))
3735
return;
3836

onnxruntime/core/providers/openvino/ov_interface.cc

+26-17
Original file line numberDiff line numberDiff line change
@@ -88,38 +88,47 @@ OVExeNetwork OVCore::CompileModel(std::shared_ptr<const OVNetwork>& ie_cnn_netwo
8888
// Note! With this default path, the model runs but produces garbage (for NPUW). For CPU it's fine.
8989
auto mutable_model = ie_cnn_network->clone();
9090

91-
std::cout << "stateless model" << std::endl;
92-
logBasicModelInfo(mutable_model);
91+
if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
92+
std::cout << "Stateless OV Model Statistic" << std::endl;
93+
LogBasicModelInfo(mutable_model);
94+
}
95+
LogBasicModelInfo(mutable_model);
9396

94-
std::cout << "making stateful..." << std::endl;
95-
patch_stateful_decoder(mutable_model);
97+
LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl;
98+
PatchStatefulDecoder(mutable_model);
9699

97-
std::cout << "after stateful transition:" << std::endl;
98-
logBasicModelInfo(mutable_model);
100+
if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
101+
std::cout << "Stateful OV Model Statistic" << std::endl;
102+
LogBasicModelInfo(mutable_model);
103+
}
99104

100105
// This patches the model so that it only produces the logits required for sampling.
101106
// Actually either way that happens within NPUW::LLMCompiledModel creation, but this is
102107
// here mostly to align this behavior for other devices (CPU, GPU).
103-
apply_slice_before_matmul_transformation(mutable_model);
108+
ApplySliceBeforeMatmulTransformation(mutable_model);
104109

105-
auto kv_pos = get_kv_axes_pos(mutable_model);
106-
std::cout << "kv_pos.batch = " << kv_pos.batch << std::endl;
107-
std::cout << "kv_pos.seq_len = " << kv_pos.seq_len << std::endl;
110+
auto kv_pos = GetKVAxesPos(mutable_model);
111+
if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
112+
std::cout << "kv_pos.batch = " << kv_pos.batch << std::endl;
113+
std::cout << "kv_pos.seq_len = " << kv_pos.seq_len << std::endl;
114+
}
108115

109116
if (hw_target.find("NPU") != std::string::npos) {
110117
KVDesc kv_desc;
111-
kv_desc.max_prompt_len = pop_int_and_cast(device_config, "MAX_PROMPT_LEN").value_or(1024u);
112-
kv_desc.min_response_len = pop_int_and_cast(device_config, "MIN_RESPONSE_LEN").value_or(128u);
118+
kv_desc.max_prompt_len = PopIntAndCast(device_config, "MAX_PROMPT_LEN").value_or(1024u);
119+
kv_desc.min_response_len = PopIntAndCast(device_config, "MIN_RESPONSE_LEN").value_or(128u);
113120

114-
std::cout << "kv_desc.max_prompt_len = " << kv_desc.max_prompt_len << std::endl;
115-
std::cout << "kv_desc.min_response_len = " << kv_desc.min_response_len << std::endl;
121+
if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
122+
std::cout << "kv_desc.max_prompt_len = " << kv_desc.max_prompt_len << std::endl;
123+
std::cout << "kv_desc.min_response_len = " << kv_desc.min_response_len << std::endl;
124+
}
116125

117-
update_npu_config(config, mutable_model, kv_pos, kv_desc);
126+
UpdateNPUConfig(config, kv_pos, kv_desc);
118127
}
119128

120-
std::cout << "calling compile on stateful model..." << std::endl;
129+
std::cout << "Compiling Stateful OV Model..." << std::endl;
121130
obj = core.compile_model(mutable_model, hw_target, config);
122-
std::cout << "done calling compile on stateful model..." << std::endl;
131+
std::cout << "Stateful OV Model Compilation Complete" << std::endl;
123132
} else {
124133
obj = core.compile_model(ie_cnn_network, hw_target, device_config);
125134
}

0 commit comments

Comments
 (0)