[CPU] Support group beam search (#21983)
* support group beam search

* support dyn batch without set_state

* apply review comments

* strides may be incorrect when batch changed

---------

Co-authored-by: Yu Xu <[email protected]>
luo-cheng2021 and yuxu42 authored Jan 8, 2024
1 parent 49231c0 commit ceeafaf
Showing 5 changed files with 383 additions and 12 deletions.
8 changes: 6 additions & 2 deletions src/plugins/intel_cpu/src/memory_state.cpp
@@ -164,8 +164,12 @@ VariableStateKVcache::VariableStateKVcache(
}

ov::SoPtr<ov::ITensor> VariableStateKVcache::get_state() const {
-OPENVINO_ASSERT(m_internal_mem && m_hidden_state, "KVState internal memory is not initialized");
-OPENVINO_ASSERT(!is_reset_state(), "KVState is undefined after reset");
+if (!m_internal_mem || !m_hidden_state || is_reset_state()) {
+    auto new_desc = to_static(get_external_desc());
+    auto external_mem = std::make_shared<Memory>(get_engine(), new_desc);
+    return std::make_shared<Tensor>(external_mem);
+}

auto actual_internal_desc = m_internal_mem->getDescWithType<BlockedMemoryDesc>();
auto&& dims = actual_internal_desc->getShape().getStaticDims();

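The change above removes the hard assertions from `VariableStateKVcache::get_state()`: querying a reset (or never-assigned) KV-cache state now returns an empty, statically shaped tensor instead of throwing. A minimal usage sketch of the new behavior, assuming a hypothetical stateful model file `model_with_kv_cache.xml`:

```cpp
#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // "model_with_kv_cache.xml" is a placeholder for any stateful model whose
    // variables are backed by VariableStateKVcache on the CPU plugin.
    auto compiled = core.compile_model("model_with_kv_cache.xml", "CPU");
    auto request = compiled.create_infer_request();

    for (auto&& state : request.query_state()) {
        state.reset();                     // the KV cache becomes undefined
        ov::Tensor t = state.get_state();  // previously threw; now yields an
                                           // empty tensor with a static shape
    }
}
```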
148 changes: 144 additions & 4 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -824,13 +824,153 @@ void ScaledDotProductAttention::assignState(const std::shared_ptr<VariableStateKVcache>
}
}

void ScaledDotProductAttention::resetBeamTablePastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v, const MemoryPtr& mem_beam_idx) {
std::vector<size_t> order = {0, 1, 2, 3};
if (!m_config.config.permute_axes.empty()) {
order = m_config.config.permute_axes;
}
PlainTensor beam_idx, old_beam_table_k;
auto old_hidden_state_k = m_k_state->hidden_state_mem();
beam_idx.reset(mem_beam_idx);

auto inputNumber = getOriginalInputsNumber();
auto&& v_dims = getParentEdgeAt(inputNumber - 1)->getMemory().getStaticDims();
size_t L0 = v_dims.at(order[2]);
auto B_state = v_dims.at(order[0]);
old_beam_table_k.reset(old_hidden_state_k);

PlainTensor cur_k;
PlainTensor cur_v;
cur_k.reset(mem_cur_k);
cur_v.reset(mem_cur_v);
cur_k = cur_k.permute(order);
cur_v = cur_v.permute(order);
auto B = cur_k.size(0);
auto H = cur_k.size(1);
auto L1 = cur_k.size(2);
auto S = cur_k.size(3);
auto reverse = [&order] (const std::vector<size_t>& cur) {
std::vector<size_t> result(cur.size());
for (size_t i = 0; i < cur.size(); i++) {
result[order[i]] = cur[i];
}
return result;
};

// 1. check beam idx if it's valid
auto* table = beam_idx.data<int32_t>();
for (size_t i = 0; i < B; i++) {
OPENVINO_ASSERT(static_cast<size_t>(table[i]) < B_state, "beam_idx[", i, "]=", table[i],
" should less than batch of previous pastkv: ", B_state);
}

// 2. resize pastkv
{
auto shape = {B, H, (L0 + L1) * 2, S};
auto mem_desc = std::make_shared<CpuBlockedMemoryDesc>(m_kvcache_precision,
Shape(reverse(shape)),
shape,
order);
auto new_internal_mem_k = std::make_shared<Memory>(getEngine(), mem_desc);
auto new_internal_mem_v = std::make_shared<Memory>(getEngine(), mem_desc);

PlainTensor new_pastk, new_pastv, old_past_k, old_past_v;
new_pastk.reset(new_internal_mem_k);
new_pastv.reset(new_internal_mem_v);
new_pastk = new_pastk.permute(order);
new_pastv = new_pastv.permute(order);
if (L0 > 0) {
auto old_internal_mem_k = m_k_state->internal_state_mem();
auto old_internal_mem_v = m_v_state->internal_state_mem();
old_past_k.reset(old_internal_mem_k);
old_past_v.reset(old_internal_mem_v);
old_past_k = old_past_k.permute(order);
old_past_v = old_past_v.permute(order);
parallel_for3d(B, H, L0, [&](size_t b, size_t h, size_t m) {
auto idx = static_cast<size_t>(table[b]);
auto b_kv = static_cast<size_t>(old_beam_table_k.at<int32_t>({idx, m}));
memcpy(&new_pastk.at<char>({b, h, m}),
&old_past_k.at<char>({b_kv, h, m}),
S * old_past_k.m_element_size);
memcpy(&new_pastv.at<char>({b, h, m}),
&old_past_v.at<char>({b_kv, h, m}),
S * old_past_v.m_element_size);
});
}

auto new_shape = {B, H, (L0 + L1), S};
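// Re-describe the buffer with the logical shape B x H x (L0 + L1) x S while
// keeping the strides of the 2x-length allocation above; strides derived from
// the new shape would mis-address rows after a batch change (see the commit
// note "strides may be incorrect when batch changed").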
mem_desc = std::make_shared<CpuBlockedMemoryDesc>(m_kvcache_precision,
Shape(reverse(new_shape)),
new_shape,
order,
0,
VectorDims{},
mem_desc->getStrides());
new_internal_mem_k->redefineDesc(mem_desc);
new_internal_mem_v->redefineDesc(mem_desc);
attn_memcpy(cur_k, cur_v, new_pastk.slice(2, L0, L0 + L1), new_pastv.slice(2, L0, L0 + L1));

m_k_state->assign_internal_state(new_internal_mem_k);
m_v_state->assign_internal_state(new_internal_mem_v);
m_k_state->assign_internal_state_max_size(B * H * (L0 + L1) * 2 * S);
m_v_state->assign_internal_state_max_size(B * H * (L0 + L1) * 2 * S);
}
// 3. create beam table
{
auto mem_desc = std::make_shared<CpuBlockedMemoryDesc>(ov::element::i32, Shape{B, (L0 + L1) * 2});

auto new_hidden_state_k = std::make_shared<Memory>(getEngine(), mem_desc);
auto new_hidden_state_v = std::make_shared<Memory>(getEngine(), mem_desc);
PlainTensor new_beam_table_k, new_beam_table_v;
new_beam_table_k.reset(new_hidden_state_k);
new_beam_table_v.reset(new_hidden_state_v);

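// The physical gather above already materialized the beam reordering, so the
// fresh beam table is simply the identity mapping: row b points at batch item b.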
for (size_t b = 0; b < B; b++) {
for (size_t l = 0; l < L0 + L1; l++) {
new_beam_table_k.at<int32_t>({b, l}) = b;
new_beam_table_v.at<int32_t>({b, l}) = b;
}
}

std::vector<size_t> new_shape{B, (L0 + L1)};
mem_desc = std::make_shared<CpuBlockedMemoryDesc>(ov::element::i32,
Shape(new_shape),
new_shape,
VectorDims{0, 1},
0,
VectorDims{},
mem_desc->getStrides());
new_hidden_state_k->redefineDesc(mem_desc);
new_hidden_state_v->redefineDesc(mem_desc);

m_k_state->assign_hidden_state(new_hidden_state_k);
m_v_state->assign_hidden_state(new_hidden_state_v);
m_k_state->assign_hidden_state_max_size(B * (L0 + L1) * 2);
m_v_state->assign_hidden_state_max_size(B * (L0 + L1) * 2);
}
}
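`resetBeamTablePastkv` rebuilds the cache when the batch changes: for each new batch item `b` it resolves the source row through `beam_idx` and the old beam table, then copies the gathered K/V rows into freshly allocated buffers. A minimal sketch of that double-indirection gather on plain arrays (illustrative names, single attention head; not the node's actual tensor layout):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// new_past[b][m] = old_past[old_beam_table[beam_idx[b]][m]][m]
void gather_past_rows(const std::vector<float>& old_past,          // [B_state, L0, S]
                      std::vector<float>& new_past,                // [B, L0, S]
                      const std::vector<int32_t>& beam_idx,        // [B]
                      const std::vector<int32_t>& old_beam_table,  // [B_state, L0]
                      size_t B, size_t L0, size_t S) {
    for (size_t b = 0; b < B; ++b) {
        for (size_t m = 0; m < L0; ++m) {
            size_t src = old_beam_table[beam_idx[b] * L0 + m];
            std::memcpy(&new_past[(b * L0 + m) * S],
                        &old_past[(src * L0 + m) * S],
                        S * sizeof(float));
        }
    }
}
```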

void ScaledDotProductAttention::gatherConcatPastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v, const MemoryPtr& mem_beam_idx) {
PlainTensor cur_k;
cur_k.reset(mem_cur_k);
-if (!m_config.config.permute_axes.empty())
-    cur_k = cur_k.permute(m_config.config.permute_axes);
auto inputNumber = getOriginalInputsNumber();
auto&& v_dims = getParentEdgeAt(inputNumber - 1)->getMemory().getStaticDims();
size_t B_state;
if (!m_config.config.permute_axes.empty()) {
cur_k = cur_k.permute(m_config.config.permute_axes);
B_state = v_dims.at(m_config.config.permute_axes[0]);
} else {
B_state = v_dims.at(0);
}

auto B = cur_k.size(0);
auto L1 = cur_k.size(2);
if (B != B_state) {
resetBeamTablePastkv(mem_cur_k, mem_cur_v, mem_beam_idx);
return;
}

-updateBeamTable(mem_beam_idx, cur_k.size(2));
+updateBeamTable(mem_beam_idx, L1);
updatePastkv(mem_cur_k, mem_cur_v);
}
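`gatherConcatPastkv` is the dispatch point: when the incoming batch `B` matches the state batch `B_state`, the beam table and past K/V are extended in place; when it differs (group beam search and dynamic batching can change the effective batch between steps), everything is rebuilt via `resetBeamTablePastkv`. The same decision in a compact, hypothetical free-function form:

```cpp
#include <cstddef>

enum class KvUpdate { InPlace, Rebuild };

// Batch mismatch means beam_idx would index into a differently sized cache,
// so an incremental concat would read the wrong rows.
KvUpdate choose_kv_update(size_t B, size_t B_state) {
    return B == B_state ? KvUpdate::InPlace : KvUpdate::Rebuild;
}
```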

@@ -858,7 +998,7 @@ void ScaledDotProductAttention::updateBeamTable(const MemoryPtr& mem_beam_idx, size_t new_q_len)
OPENVINO_ASSERT(B == B_state, "beam idx batch: ", B, " is not equal to batch of state: ", B_state);
OPENVINO_ASSERT(B * (L0 + L1) > 0, "B or (L0+L1) is zero, B: ", B, ", L0: ", L0, ", L1: ", L1);
// resize buffer
-if (B * (L0 + L1) > m_k_state->hidden_state_max_size()) {
+if (is_reset || B * (L0 + L1) > m_k_state->hidden_state_max_size()) {
auto mem_desc = std::make_shared<CpuBlockedMemoryDesc>(ov::element::i32, Shape{B, (L0 + L1) * 2});

auto new_hidden_state_k = std::make_shared<Memory>(getEngine(), mem_desc);
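Note the `(L0 + L1) * 2` in the resize: the beam table (and, below, the past-KV buffers) are over-allocated to twice the currently needed length so that subsequent decode steps usually append without reallocating; with `is_reset` the buffers are now also rebuilt after a state reset. A minimal sketch of the amortized-doubling policy, under the simplifying assumption of a single capacity counter:

```cpp
#include <cstddef>

// Grow only when the needed size exceeds capacity, and allocate double the
// need, mirroring the (L0 + L1) * 2 sizing in updateBeamTable/updatePastkv.
size_t next_capacity(size_t needed, size_t capacity) {
    return needed > capacity ? needed * 2 : capacity;
}
```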
@@ -981,7 +1121,7 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v)
OPENVINO_ASSERT(B == B_state, "pastkv batch: ", B, " is not equal to batch of state: ", B_state);
OPENVINO_ASSERT(B * (L0 + L1) > 0, "B or (L0+L1) is zero, B: ", B, ", L0: ", L0, ", L1: ", L1);
// resize buffer
-if (B * H * (L0 + L1) * S > m_k_state->internal_state_max_size()) {
+if (is_reset || B * H * (L0 + L1) * S > m_k_state->internal_state_max_size()) {
auto new_shape = {B, H, (L0 + L1) * 2, S};
auto mem_desc = std::make_shared<CpuBlockedMemoryDesc>(m_kvcache_precision,
Shape(reverse(new_shape)),
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.h
@@ -56,6 +56,7 @@ class ScaledDotProductAttention : public Node {
void updateBeamTable(const MemoryPtr& mem_beam_idx, size_t new_q_len);
void updatePastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v);
ov::element::Type getRuntimePrecision() const override;
void resetBeamTablePastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v, const MemoryPtr& mem_beam_idx);

struct Config {
ScaledDotProductAttentionWithKVCache::Config config;
@@ -46,13 +46,8 @@ void ov::intel_cpu::ScaledDotProductAttentionWithKVCache::validate_and_infer_types(
"shape not compatiable at index ",
i);
}
-} else if (i == length_index) {
-    continue;
} else {
-    NODE_VALIDATION_CHECK(this,
-                          q_ps[i].compatible(past_kv_ps[i]),
-                          "shape not compatiable at index ",
-                          i);
+    continue;
}
}
past_kv_ps[length_index] += q_ps[length_index];
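The shape-inference change drops the per-dimension compatibility check against past KV, presumably because group beam search lets the batch differ between steps, and keeps only the length accumulation: the output KV length is the past length plus the current query length. In sketch form, for static dimensions:

```cpp
#include <cstdint>

// past_kv_ps[length_index] += q_ps[length_index], for static lengths:
int64_t concat_kv_length(int64_t past_len, int64_t q_len) {
    return past_len + q_len;  // e.g. 7 cached tokens + 1 new token -> 8
}
```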