Skip to content

Update SDPA to PagedAttention transformation to support phi3 sliding window #29608

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class ov::pass::StateManagementPattern : public ov::pass::MatcherPass {
OPENVINO_MATCHER_PASS_RTTI("StateManagementPattern");
StateManagementPattern(ParameterVector& kv_parameters,
ParameterVector& model_remaining_params,
const std::shared_ptr<ov::op::v0::Constant>& sliding_window,
ParameterVector& parameters_to_remove,
int& layer_index,
ov::Output<Node> max_context_len,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
#include "openvino/op/concat.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/greater_eq.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/paged_attention.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/select.hpp"
Expand All @@ -28,6 +30,7 @@
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/variadic_split.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
Expand Down Expand Up @@ -173,6 +176,27 @@ static std::shared_ptr<ov::Node> handle_baichuan2_13b_alibi(
return res_alibi_slopes;
}

static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> handle_phi3_sliding_window() {
using namespace ov::pass::pattern;

auto offset = wrap_type<v0::Constant>();
auto t196 = wrap_type<v1::Add>({any_input(), offset});
auto t197 = pattern::optional<v0::Convert>(t196);
auto t200 = pattern::wrap_type<v4::Range>({t197, any_input(), any_input()});
auto t201 = pattern::wrap_type<v0::Unsqueeze>({t200, any_input()});
auto t202 = pattern::wrap_type<v1::GreaterEqual>({any_input(), t201});
auto t208 = pattern::wrap_type<v1::Select>({t202, any_input(), any_input()});
auto t209 = pattern::wrap_type<v1::Subtract>({any_input(), t208});
auto t210 = pattern::optional<v0::Convert>(t209);
auto t211 = pattern::wrap_type<v1::Select>({t210, any_input(), any_input()});
auto t213 = pattern::wrap_type<v0::Unsqueeze>({t211, any_input()});
auto t214 = pattern::wrap_type<v0::Unsqueeze>({t213, any_input()});
auto t218 = pattern::wrap_type<v3::Broadcast>({t214, any_input()});
auto t219 = pattern::wrap_type<v1::Select>({any_input(), any_input(), t218});
auto mask = pattern::wrap_type<v8::Slice>({t219, any_input(), any_input(), any_input(), any_input()});
return {mask, offset};
}

// Exactly copied the function from another file. Maybe should be moved to some general file
static std::shared_ptr<v0::Parameter> setName(std::shared_ptr<v0::Parameter> node, const std::string& name) {
// Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a
Expand Down Expand Up @@ -207,7 +231,6 @@ static node_tuple kv_read_and_concat(ov::Output<ov::Node> kv_current) {

ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_parameters,
ParameterVector& model_remaining_params,
const std::shared_ptr<ov::op::v0::Constant>& sliding_window,
ParameterVector& parameters_to_remove,
int& layer_index,
Output<Node> max_context_len,
Expand Down Expand Up @@ -297,15 +320,20 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
std::shared_ptr<ov::Node> baichuan2_13b_alibi, baichuan2_13b_alibi_mask;
std::tie(baichuan2_13b_alibi, baichuan2_13b_alibi_mask) = baichuan2_13b_alibi_pattern();

// Phi3-xxx-instruct case
std::shared_ptr<ov::Node> phi3_mask, phi3_offset;
std::tie(phi3_mask, phi3_offset) = handle_phi3_sliding_window();

auto q = pattern::any_input();
auto scale_input = pattern::any_input();

auto k_to_sdpa =
std::make_shared<pattern::op::Or>(OutputVector{k_concat, k_shaped, k_shaped_transposed, k_simply_shaped});
auto v_to_sdpa =
std::make_shared<pattern::op::Or>(OutputVector{v_concat, v_shaped, v_shaped_transposed, v_simply_shaped});

auto mask_to_sdpa = std::make_shared<pattern::op::Or>(
OutputVector{general_alibi_mask, jais_alibi_mask, baichuan2_13b_alibi_mask, pattern::any_input()});
OutputVector{phi3_mask, general_alibi_mask, jais_alibi_mask, baichuan2_13b_alibi_mask, pattern::any_input()});

auto sdpa_with_4_inputs =
pattern::wrap_type<v13::ScaledDotProductAttention>({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa});
Expand All @@ -317,7 +345,6 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
ov::matcher_pass_callback callback = [=,
&kv_parameters,
&model_remaining_params,
&sliding_window,
&parameters_to_remove,
&block_indices_inputs_for_each_layer,
&score_results,
Expand Down Expand Up @@ -492,6 +519,18 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par

OutputVector pa_arguments = {q_reshape, k_reshape, v_reshape, k_parameter, v_parameter};
pa_arguments.insert(pa_arguments.end(), model_remaining_params.begin(), model_remaining_params.end());

std::shared_ptr<Node> sliding_window;
if (pattern_map.count(phi3_offset)) {
auto offset = pattern_map.at(phi3_offset).get_node_shared_ptr();
if (offset->get_element_type() != element::i32) {
offset = std::make_shared<v0::Convert>(offset, element::i32);
}
sliding_window = std::make_shared<v1::Subtract>(v0::Constant::create(element::i32, Shape{}, {2}), offset);
} else {
sliding_window = v0::Constant::create(element::i32, Shape{}, {0});
}

std::initializer_list<std::shared_ptr<Node>> additional_params = {scale,
sliding_window,
alibi_slopes,
Expand Down
Loading
Loading