forked from NVIDIA/kvpress
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtova_press.py
53 lines (40 loc) · 1.76 KB
/
tova_press.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from torch import nn
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.snapkv_press import SnapKVPress
@dataclass
class TOVAPress(ScorerPress):
    """
    TOVA (https://arxiv.org/abs/2401.06104) scores previous KV pairs using the attention
    weights of the last token, averaged across attention heads. This press was reviewed by
    Michael Hassid, one of the authors of the TOVA paper.
    Official implementation can be found here: https://github.com/schwartz-lab-NLP/TOVA/blob/main/src/tova_cache.py
    """

    compression_ratio: float = 0.0

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        if attentions is None:
            # Attention outputs were not returned by the model: recompute the
            # attention row of the last token (window of size 1) from scratch.
            last_attn = SnapKVPress.compute_window_attention(
                module, hidden_states, keys, 1, kwargs["position_embeddings"]
            )
        else:
            # Keep only the final query position and drop its self-attention entry.
            last_attn = attentions[..., -1:, :-1]

        # Average over attention heads, then replicate so there is one score row
        # per key/value head (keys are laid out as (batch, num_kv_heads, seq, dim)).
        head_avg = last_attn.mean(dim=1)
        head_avg = head_avg.repeat(1, keys.shape[1], 1)

        # Re-append a slot for the last token, scored at the current maximum so the
        # window is never pruned. This slightly differs from TOVA, which doesn't
        # enforce it, but the last attention weight is usually very high, so the
        # results should not change.
        return F.pad(head_avg, (0, 1), value=head_avg.max().item())