
Commit 715f8a7

authored Dec 10, 2024
Add ComposedPress and compress method (NVIDIA#29)
1 parent 9ecd556 commit 715f8a7

9 files changed: +202 −167 lines
 

README.md

+12 −5
@@ -19,7 +19,7 @@ pip install flash-attn --no-build-isolation
 
 ## Usage
 
-This repository provides a set of "presses" that compress the KV cache by pruning the least important key-value pairs in each attention head. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` parameter that controls the amount of pruning. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:
+This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` parameter that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:
 
 
 
@@ -41,7 +41,7 @@ answer = pipe(context, question=question, press=press)["answer"]
 In the snippet above, the compression is only applied on the context tokens so that you can evaluate the compression for different questions. Check the [Wikipedia notebook demo](notebooks/wikipedia_demo.ipynb) for a more detailed example.
 
 > [!IMPORTANT]
-> We focus on pruning during the pre-filling phase as the KV cache becomes a bottleneck for long-context sequences (100k - 1M tokens), which are essentially long context prompts. This would typically apply to improving prompt caching systems.
+> We focus on compression during the pre-filling phase as the KV cache becomes a bottleneck for long-context sequences (100k - 1M tokens), which are essentially long context prompts. This would typically apply to improving prompt caching systems.
 
 > [!NOTE]
 > To use the `ObservedAttentionPress`, use `model_kwargs={"attn_implementation":"eager"}` in order to materialize the attention weights (this method is not compatible with flash attention).
@@ -51,16 +51,23 @@ In the snippet above, the compression is only applied on the context tokens so t
 We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the [FAQ](#faq) for more information on how presses work and how to create new ones or check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide.
 
 ## Available presses
-All current presses are training free. We provide the following presses associated with the following scores:
+
+All current presses are training free. Several of them inherit from `ScorerPress` and rely on a score used to prune the KV pairs with the lowest importance:
 
 - `RandomPress`: random score
 - `KnormPress`: inverse norm of the key ([paper](https://arxiv.org/abs/2406.11430))
-- `ObservedAttentionPress`: average attention weight observed during the pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048))
 - `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469))
 - `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
 - `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453))
 - `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
-- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)). Can be combined with any of the presses above.
+- `ObservedAttentionPress`: average attention weight observed during the pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048))
+
+We also provide presses relying on a different logic:
+- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018))
+
+Finally we provide two special presses:
+- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental)
+- `ComposedPress`: a press that composes multiple presses together by chaining their forward hooks
 
 For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
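To make the README's usage pattern concrete, here is a minimal sketch of the pipeline call combined with the new `ComposedPress` (the model checkpoint and the context/question strings are placeholder assumptions, not part of this commit):

from transformers import pipeline

from kvpress import ComposedPress, KnormPress, ThinKPress

# The "kv-press-text-generation" pipeline is registered when kvpress is imported
pipe = pipeline("kv-press-text-generation", model="meta-llama/Meta-Llama-3.1-8B-Instruct")

# Chain a sequence-length press with a key-channel press
press = ComposedPress([KnormPress(compression_ratio=0.5), ThinKPress(key_channel_compression_ratio=0.5)])

context = "..."   # long context, compressed during pre-filling
question = "..."  # question answered against the compressed cache
answer = pipe(context, question=question, press=press)["answer"]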

kvpress/__init__.py

+2
@@ -13,9 +13,11 @@
 from kvpress.presses.snapkv_press import SnapKVPress
 from kvpress.presses.streaming_llm_press import StreamingLLMPress
 from kvpress.presses.think_press import ThinKPress
+from kvpress.presses.composed_press import ComposedPress
 
 __all__ = [
     "BasePress",
+    "ComposedPress",
     "ScorerPress",
     "ExpectedAttentionPress",
     "KnormPress",
kvpress/presses/base_press.py

+81 −7
@@ -9,22 +9,66 @@
 
 import torch
 from torch import nn
-from transformers import LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, PreTrainedModel, Qwen2ForCausalLM
+from transformers import (
+    LlamaForCausalLM,
+    MistralForCausalLM,
+    Phi3ForCausalLM,
+    PreTrainedModel,
+    Qwen2ForCausalLM,
+    QuantizedCache,
+)
 
 logger = logging.getLogger(__name__)
 
 
 @dataclass
 class BasePress:
     """
-    Base class for all pruning methods.
-    The `forward_hook` method is called after the forward pass of an attention layer.
-    Any pruning/updating method should be implemented in the derived class.
+    Base class for all KV cache compression methods.
+    The `forward_hook` method is called after the forward pass of an attention layer to update the cache.
     """
 
+    def compress(
+        self,
+        module: nn.Module,
+        hidden_states: torch.Tensor,
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        attentions: torch.Tensor,
+        kwargs: dict,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        The core logic of the compression method.
+
+        Parameters
+        ----------
+        module :
+            Transformer layer, see `hook` method for more details
+        hidden_states :
+            Hidden states of the layer
+        keys :
+            Keys of the cache (unquantized)
+        values :
+            Values of the cache (unquantized)
+        attentions :
+            Attention weights of the layer
+        kwargs :
+            Keyword arguments, as given to the forward pass of the layer
+
+        Returns
+        -------
+        tuple[torch.Tensor, torch.Tensor]
+            Updated keys and values
+        """
+
+        raise NotImplementedError("compress method must be implemented in subclass")
+
     def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
-        """Cache compression hook called after the forward pass of an attention layer.
-        The hook is applied only during the pre-filling phase if there is some pruning ratio.
+        """
+        Default forward hook called after the forward pass of an attention layer.
+        The hook calls the compress method to compress the KV cache while ensuring:
+        - compression is only applied during the pre-filling phase
+        - KV cache quantization is handled correctly
 
         Parameters
         ----------
@@ -40,8 +84,38 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
         Returns
         -------
         Modified output of the forward pass of the layer.
+
         """
-        raise NotImplementedError("forward_hook method must be implemented in the derived class")
+        # See e.g. LlamaDecoderLayer.forward for the output structure
+        if len(output) == 3:
+            _, attentions, cache = output
+        else:
+            attentions, cache = None, output[-1]
+
+        hidden_states = kwargs["hidden_states"]
+        q_len = hidden_states.shape[1]
+
+        # Don't compress after pre-filling
+        if cache.seen_tokens > q_len:
+            return output
+
+        if isinstance(cache, QuantizedCache):
+            keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
+            values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
+        else:
+            keys = cache.key_cache[module.layer_idx]
+            values = cache.value_cache[module.layer_idx]
+
+        keys, values = self.compress(module, hidden_states, keys, values, attentions, kwargs)
+
+        if isinstance(cache, QuantizedCache):
+            cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
+            cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
+        else:
+            cache.key_cache[module.layer_idx] = keys
+            cache.value_cache[module.layer_idx] = values
+
+        return output
 
     @contextmanager
     def __call__(self, model: PreTrainedModel) -> Generator:
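With this split, a new press only has to implement `compress`; the inherited `forward_hook` takes care of pre-filling detection and quantized caches. A minimal sketch of a hypothetical subclass (for illustration only, not part of this commit):

from kvpress.presses.base_press import BasePress


class EveryOtherPress(BasePress):  # hypothetical example
    """Keep every other KV pair, to show the compress contract."""

    def compress(self, module, hidden_states, keys, values, attentions, kwargs):
        # keys and values have shape (bsz, num_key_value_heads, seq_len, head_dim)
        return keys[:, :, ::2], values[:, :, ::2]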

kvpress/presses/composed_press.py

+21
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from kvpress.presses.base_press import BasePress
+
+
+@dataclass
+class ComposedPress(BasePress):
+    """
+    Chain multiple presses together to create a composed press
+    """
+
+    presses: list[BasePress]
+
+    def __post_init__(self):
+        self.compression_ratio = None
+
+    def forward_hook(self, module, input, kwargs, output):
+        self.compression_ratio = 1.0
+        for press in self.presses:
+            output = press.forward_hook(module, input, kwargs, output)
+            self.compression_ratio *= press.compression_ratio  # type: ignore
+        return output
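A quick usage sketch, using presses introduced elsewhere in this diff: the forward hooks are chained in list order, and after pre-filling `press.compression_ratio` holds the product of the individual ratios.

from kvpress import ComposedPress, SnapKVPress, ThinKPress

# First prune KV pairs with SnapKV, then zero out half of the key channels with ThinK
press = ComposedPress([SnapKVPress(0.5), ThinKPress(0.5)])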

kvpress/presses/scorer_press.py

+24 −75
@@ -7,7 +7,6 @@
 
 import torch
 from torch import nn
-from transformers import QuantizedCache
 
 from kvpress.presses.base_press import BasePress
 
@@ -18,8 +17,9 @@
 class ScorerPress(BasePress):
     """
     Default press method for using a score method.
-    The `forward_hook` method is called after the forward pass of an attention layer.
-    and updates the cache with the pruned KV pairs.
+    Any ScorerPress subclass must implement the `score` method that computes a tensor of scores for each key-value pair.
+    The KV pairs with the lowest scores will be pruned in the `compress` method.
+    The cache is uniformly pruned across all heads and layers using the compression_ratio parameter.
     """
 
     compression_ratio: float = 0.0
@@ -36,87 +36,36 @@ def score(
         attentions: torch.Tensor,
         kwargs,
     ) -> torch.Tensor:
-        """Compute the scores for each KV pair in the layer.
-
-        Parameters
-        ----------
-        module :
-            Transformer layer, see `hook` method for more details.
-        hidden_states :
-            Hidden states of the layer.
-        keys :
-            Keys of the cache. Note keys are after RoPE.
-        values :
-            Values of the cache.
-        attentions :
-            Attention weights of the layer.
-        kwargs :
-            Keyword arguments, as given to the forward pass of the layer.
-
-        Returns
-        -------
-        Scores for each KV pair in the layer, shape keys.shape[:-1].
         """
-        raise NotImplementedError
-
-    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
-        """
-        Default cache compression hook called after the forward pass of an attention layer.
-        The hook is applied only during the pre-filling phase if there is some pruning ratio.
-        This implementation allows to remove a constant number of KV pairs.
-
-        Parameters
-        ----------
-        module :
-            Transformer attention layer.
-        input :
-            Input to the hook. This is the input to the forward pass of the layer.
-        kwargs :
-            Keyword arguments, as given to the forward pass of the layer.
-        output :
-            Output of the hook. This is the original output of the forward pass of the layer.
-
-        Returns
-        -------
-        Modified output of the forward pass of the layer.
-
+        Compute a tensor of scores with shape (bsz, num_key_value_heads, q_len).
+        The KV pairs with the lowest scores will be pruned in the `compress` method.
         """
-        # See e.g. LlamaDecoderLayer.forward for the output structure
-        if len(output) == 3:
-            _, attentions, cache = output
-        else:
-            attentions, cache = None, output[-1]
-
-        hidden_states = kwargs["hidden_states"]
-        q_len = hidden_states.shape[1]
+        raise NotImplementedError
 
-        # Don't compress if the compression ratio is 0 or this is not pre-filling
-        if (self.compression_ratio == 0) or (cache.seen_tokens > q_len):
-            return output
+    def compress(
+        self,
+        module: nn.Module,
+        hidden_states: torch.Tensor,
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        attentions: torch.Tensor,
+        kwargs: dict,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
 
-        if isinstance(cache, QuantizedCache):
-            keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
-            values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
-        else:
-            keys = cache.key_cache[module.layer_idx]
-            values = cache.value_cache[module.layer_idx]
+        if self.compression_ratio == 0:
+            return keys, values
 
-        with torch.no_grad():
-            scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
+        # Compute scores
+        scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
 
-        # Prune KV pairs with the lowest scores
+        # Get indices of KV pairs with the lowest scores
+        q_len = hidden_states.shape[1]
         n_kept = int(q_len * (1 - self.compression_ratio))
         indices = scores.topk(n_kept, dim=-1).indices
         indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
 
-        # Update cache
+        # Prune keys and values
         keys = keys.gather(2, indices).contiguous()
         values = values.gather(2, indices).contiguous()
-        if isinstance(cache, QuantizedCache):
-            cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
-            cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
-        else:
-            cache.key_cache[module.layer_idx] = keys
-            cache.value_cache[module.layer_idx] = values
-
-        return output
+
+        return keys, values
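Under the new contract, a `ScorerPress` subclass only supplies a `score` method; `compress` handles the pruning. A sketch mirroring the key-norm idea of `KnormPress` (illustrative, not the repository implementation):

from dataclasses import dataclass

from kvpress.presses.scorer_press import ScorerPress


@dataclass
class NegativeKeyNormPress(ScorerPress):
    """Score each KV pair by the negative norm of its key, so that
    large-norm keys get the lowest scores and are pruned first."""

    def score(self, module, hidden_states, keys, values, attentions, kwargs):
        # keys: (bsz, num_key_value_heads, q_len, head_dim)
        # returns: (bsz, num_key_value_heads, q_len); lowest scores are pruned
        return -keys.norm(dim=-1)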

kvpress/presses/think_press.py

+18 −42
@@ -3,11 +3,9 @@
 
 
 from dataclasses import dataclass
-from typing import Optional
 
 import torch
 from torch import nn
-from transformers.cache_utils import QuantizedCache
 from transformers.models.llama.modeling_llama import rotate_half
 
 from kvpress.presses.base_press import BasePress
@@ -18,7 +16,7 @@ class ThinKPress(BasePress):
     """
     ThinK (https://arxiv.org/pdf/2407.21018) compresses the dimensions of the keys, and not the sequence length.
     Hence it can be combined with any other press that compresses the sequence length, e.g.
-    press = ThinKPress(compression_ratio=0.5, inner_press=SnapKVPress(compression_ratio=0.5))
+    press = ComposedPress([SnapKVPress(0.5), ThinKPress(0.5)])
 
     Here, we zero out the pruned dimensions resulting in no memory gain (the shape of the keys remains the same).
     To achieve memory savings, several options can be considered (see https://github.com/NVIDIA/kvpress/pull/18/),
@@ -28,7 +26,6 @@ class ThinKPress(BasePress):
     """
 
     key_channel_compression_ratio: float = 0.0
-    inner_press: Optional[BasePress] = None
     window_size: int = 32
 
     def compute_window_queries(self, module, hidden_states):
@@ -55,36 +52,26 @@ def compute_window_queries(self, module, hidden_states):
 
         return query_states
 
-    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
+    def compress(
+        self,
+        module: nn.Module,
+        hidden_states: torch.Tensor,
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        attentions: torch.Tensor,
+        kwargs: dict,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
-        We first apply the inner press, then we prune the key dimensions. If other similar presses are requested,
-        we will create a dedicated DimensionBasePress class to avoid code duplication.
+        If other similar presses are requested, we might create a generic compress method for dimension pruning
+        to avoid code duplication.
         """
 
-        # Apply the forward hook of the inner press
-        if self.inner_press is not None:
-            output = self.inner_press.forward_hook(module, input, kwargs, output)
+        if self.key_channel_compression_ratio == 0:
+            return keys, values
 
-        # Don't compress if the compression ratio is 0 or this is not pre-filling
-        cache = output[-1]
-        hidden_states = kwargs["hidden_states"]
-        q_len = hidden_states.shape[1]
-        assert q_len > self.window_size, "Query length should be greater than the window size"
-
-        if (self.key_channel_compression_ratio == 0) or (cache.seen_tokens > q_len):
-            return output
-
-        # Get keys
-        if isinstance(cache, QuantizedCache):
-            keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
-        else:
-            keys = cache.key_cache[module.layer_idx]
+        # Compute scores per dimension
         bsz, num_key_value_heads, q_len, head_dim = keys.shape
-
-        # ThinK specific code
         queries = self.compute_window_queries(module, kwargs["hidden_states"])
-
-        # Compute scores per dimension
         queries_norm = torch.pow(queries, 2).mean(dim=2)  # (bsz, num_heads, head_dim)
         queries_norm = queries_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, module.head_dim).mean(2)
         keys_norm = torch.pow(keys, 2).mean(dim=2)
@@ -96,23 +83,12 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
         indices = indices.unsqueeze(2).expand(-1, -1, q_len, -1)
         keys = keys.scatter_(-1, indices, 0)
 
-        # Update cache
-        if isinstance(cache, QuantizedCache):
-            cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
-        else:
-            cache.key_cache[module.layer_idx] = keys
-
-        return output
+        return keys, values
 
     @property
     def compression_ratio(self):
-        compression_ratio = self.key_channel_compression_ratio / 2
-        if self.inner_press is not None and hasattr(self.inner_press, "compression_ratio"):
-            compression_ratio += self.inner_press.compression_ratio
-        return compression_ratio
+        return self.key_channel_compression_ratio / 2
 
     @compression_ratio.setter
     def compression_ratio(self, value):
-        raise AttributeError(
-            "Cannot set the compression ratio of ThinKPress directly. " "Set key_channel_compression_ratio instead."
-        )
+        raise AttributeError(f"compression ratio cannot be set for {type(self).__name__}")
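The gather/scatter step above zeroes the lowest-scoring key channels in place without changing tensor shapes. A standalone sketch of that pattern with a simplified per-channel score (toy shapes; not the full ThinK scoring, which also uses the window queries):

import torch

bsz, num_kv_heads, q_len, head_dim = 1, 2, 16, 8
keys = torch.randn(bsz, num_kv_heads, q_len, head_dim)

# Simplified per-channel score: mean squared magnitude over the sequence
scores = keys.pow(2).mean(dim=2)  # (bsz, num_kv_heads, head_dim)

# Zero out the lowest-scoring half of the channels; the shape is unchanged
n_pruned = head_dim // 2
indices = (-scores).topk(n_pruned, dim=-1).indices
indices = indices.unsqueeze(2).expand(-1, -1, q_len, -1)
keys.scatter_(-1, indices, 0)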

notebooks/new_press.ipynb

+36 −35
@@ -10,10 +10,10 @@
    ]
   },
   {
-   "metadata": {},
    "cell_type": "code",
-   "outputs": [],
    "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
    "from dataclasses import dataclass\n",
    "\n",
@@ -25,10 +25,10 @@
    ]
   },
   {
-   "metadata": {},
    "cell_type": "code",
-   "outputs": [],
    "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
    "# Load pipeline\n",
    "\n",
@@ -39,10 +39,10 @@
    ]
   },
   {
-   "metadata": {},
    "cell_type": "code",
-   "outputs": [],
    "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
    "# Load data\n",
    "\n",
@@ -62,10 +62,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "\n",
-    "A press registers a forward hook to each attention layer during the pre-filling phase:\n",
-    "1. Immediately after the forward pass, the hook is called, and it computes a score for each key-value pair using the `press.score` method\n",
-    "2. The key-value pairs with the lowest scores are then removed based on the `compression_ratio` parameter"
+    "A press registers a forward hook to each attention layer during the pre-filling phase. Immediately after the forward pass, the hook is called, and it compresses the KV cache."
    ]
   },
   {
@@ -127,7 +124,6 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "\n",
    "The easiest way to create a new press is to create a class that inherits from `ScorerPress` and implement a `score` method that computes the score for each key-value pair.\n",
    "\n",
    "The arguments of the `score` method are obtained from the forward hook:\n",
@@ -140,10 +136,10 @@
    ]
   },
   {
-   "metadata": {},
    "cell_type": "code",
-   "outputs": [],
    "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
    "class MyKnormPress(ScorerPress):\n",
    "    def score(\n",
@@ -181,47 +177,42 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### 2.2 Updating the `forward_hook` method "
+    "### 2.2 Updating the `compress` method "
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "The `forward_hook` method defined in the `BasePress` class roughly works as follows:\n",
-    "\n",
-    "1. Get the scores\n",
-    "2. Update the key-value pairs based on the scores and the `compression_ratio`\n",
+    "The `compress` method defined in the `BasePress` contains the core logic of the compression and returns compressed keys and values. For instance, in the `ScorerPress` the `compress` method calls the `score` method (which is specific to `ScorerPress`) and prunes the key-value pairs based on the scores.\n",
    "\n",
-    "While we generally do not recommend to modify this method, the following example will show how it works. We will re-implement the `StreamingLLMPress` without using the `compression_ratio` parameter. In `StreamingLLM`, only the first `n_first` and last `n_last` key-value pairs are kept."
+    "The following example will show how it works. We will re-implement the `StreamingLLMPress` in a more compact way."
    ]
   },
   {
-   "metadata": {},
    "cell_type": "code",
-   "outputs": [],
    "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
    "@dataclass\n",
    "class MyStreamingLLMPress(BasePress):\n",
    "    n_first: int = 1\n",
    "    n_last: int = 8\n",
    "\n",
-    "    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):\n",
-    "\n",
-    "        # Get the cache (transformers.cache_utils.DynamicCache object)\n",
-    "        cache = output[-1]\n",
-    "        i = module.layer_idx\n",
-    "        keys, values = cache.key_cache[i], cache.value_cache[i]\n",
+    "    def compress(\n",
+    "        self,\n",
+    "        module: nn.Module,\n",
+    "        hidden_states: torch.Tensor,\n",
+    "        keys: torch.Tensor,\n",
+    "        values: torch.Tensor,\n",
+    "        attentions: torch.Tensor,\n",
+    "        kwargs: dict,\n",
+    "    ) -> tuple[torch.Tensor, torch.Tensor]:\n",
    "\n",
-    "        # Update the cache to only keep the first and last tokens\n",
    "        mask = torch.ones(keys.shape[-2], dtype=torch.bool, device=keys.device)\n",
    "        mask[self.n_first : -self.n_last] = False\n",
-    "        cache.key_cache[i] = keys[:, :, mask, :]\n",
-    "        cache.value_cache[i] = values[:, :, mask, :]\n",
-    "\n",
-    "        # Return the updated output (output[-1] has been modified in-place)\n",
-    "        return output\n",
+    "        return keys[:, :, mask, :], values[:, :, mask, :]\n",
    "\n",
    "\n",
    "for n_last in [2, 4, 8]:\n",
@@ -231,6 +222,13 @@
    "    print(f\"Answer: {pipe(context, question=question, press=press)['answer']}\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note that the `compress` method is itself used in the `forward_hook` method, which ensures quantization is handled properly and that compression is only performed during pre-filling. While we don't recommend changing the `forward_hook` method directly, you can still modify it if you need to!"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -242,7 +240,10 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to register it in the `__init__.py` file of repository and to add it in [test_presses.py](tests/presses/test_presses.py). We recommend not to update the `forward_hook` or `__call__` method unless necessary."
+    "All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to\n",
+    "- register it in the `__init__.py` file of the repository\n",
+    "- add a test in [test_presses.py](tests/presses/test_presses.py)\n",
+    "- update the README"
    ]
   }
  ],
@@ -262,7 +263,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.10.12"
+   "version": "3.12.3"
  }
 },
 "nbformat": 4,

tests/presses/test_presses.py

+8 −3
@@ -7,6 +7,7 @@
 from transformers import DynamicCache
 
 from kvpress import (
+    ComposedPress,
     ExpectedAttentionPress,
     KnormPress,
    ObservedAttentionPress,
@@ -20,9 +21,11 @@
 from tests.fixtures import unit_test_model, unit_test_model_output_attention  # noqa: F401
 
 
-def test_think_inner_press(unit_test_model):  # noqa: F811
-    press = ThinKPress(key_channel_compression_ratio=0.5, window_size=2, inner_press=KnormPress(0.5))
-    with press(unit_test_model):
+def test_composed_press(unit_test_model):  # noqa: F811
+    press1 = KnormPress(compression_ratio=0.5)
+    press2 = ThinKPress(key_channel_compression_ratio=0.5, window_size=2)
+    composed_press = ComposedPress([press1, press2])
+    with composed_press(unit_test_model):
         input_ids = unit_test_model.dummy_inputs["input_ids"]
         unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
 
@@ -39,6 +42,8 @@ def test_presses_run(unit_test_model):  # noqa: F811
     with press(unit_test_model):
         input_ids = unit_test_model.dummy_inputs["input_ids"]
         unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
+        # Check that the press has a compression_ratio attribute
+        assert hasattr(press, "compression_ratio")
 
 
 def test_presses_run_observed_attention(unit_test_model_output_attention):  # noqa: F811
