This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` parameter that controls how much the cache is compressed. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline`, which is automatically registered as a transformers pipeline under the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:
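A minimal usage sketch (hedged: the checkpoint name and the `press`/`question` keywords below are illustrative assumptions to check against the pipeline's docstring):

```python
from transformers import pipeline

from kvpress import ExpectedAttentionPress

# Importing kvpress registers the "kv-press-text-generation" pipeline.
# The checkpoint below is only an example; any supported causal LM works.
pipe = pipeline(
    "kv-press-text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    device="cuda:0",
)

context = "A very long text you want to compress once and for all"
question = "A question about the compressed context"

press = ExpectedAttentionPress(compression_ratio=0.5)
answer = pipe(context, question=question, press=press)["answer"]
```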
In the snippet above, the compression is only applied on the context tokens so that you can evaluate the compression for different questions. Check the [Wikipedia notebook demo](notebooks/wikipedia_demo.ipynb) for a more detailed example.
> [!IMPORTANT]
> We focus on compression during the pre-filling phase, as the KV cache becomes a bottleneck for long-context sequences (100k - 1M tokens), which are essentially long prompts. This would typically apply to improving prompt caching systems.
> [!NOTE]
> To use the `ObservedAttentionPress`, pass `model_kwargs={"attn_implementation":"eager"}` in order to materialize the attention weights (this method is not compatible with flash attention).
We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the [FAQ](#faq) for more information on how presses work and how to create new ones or check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide.
## Available presses
All current presses are training-free. Several of them inherit from `ScorerPress` and rely on a score used to prune the KV pairs with the lowest importance:
- `RandomPress`: random score
- `KnormPress`: inverse norm of the key ([paper](https://arxiv.org/abs/2406.11430))
- `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469))
- `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
- `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453))
- `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
- `ObservedAttentionPress`: average attention weight observed during the pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048))

We also provide presses that rely on a different logic:

- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018))

Finally, we provide two special presses:

- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental)
- `ComposedPress`: a press that composes multiple presses together by chaining their forward hooks
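Since composing presses amounts to applying their compressions one after another, the effect can be illustrated with a framework-free toy sketch (plain Python/numpy, not the kvpress API):

```python
import numpy as np

def drop_half(kv):
    # toy "press": keep only the first half of the sequence axis
    return kv[:, :, : kv.shape[2] // 2]

def drop_last_quarter(kv):
    # toy "press": drop the last quarter of what remains
    n = kv.shape[2]
    return kv[:, :, : n - n // 4]

def composed(kv, presses):
    # chaining presses = applying their compressions sequentially,
    # which is what chaining forward hooks achieves in kvpress
    for press in presses:
        kv = press(kv)
    return kv

cache = np.zeros((1, 2, 64, 8))  # (batch, heads, seq_len, head_dim)
out = composed(cache, [drop_half, drop_last_quarter])
assert out.shape[2] == 24        # 64 -> 32 -> 24
```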
For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
## notebooks/new_press.ipynb
A press registers a forward hook to each attention layer during the pre-filling phase. Immediately after the forward pass, the hook is called, and it compresses the KV cache.
The easiest way to create a new press is to create a class that inherits from `ScorerPress` and implement a `score` method that computes a score for each key-value pair. The arguments of the `score` method are obtained from the forward hook.
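To make the scoring-then-pruning flow concrete, here is a framework-free sketch (numpy toy, not the actual kvpress code) of how per-position scores and a `compression_ratio` could combine:

```python
import numpy as np

def prune_by_score(keys, values, scores, compression_ratio):
    """Drop the fraction `compression_ratio` of KV pairs with the lowest scores."""
    # keys, values: (batch, heads, seq_len, head_dim); scores: (batch, heads, seq_len)
    seq_len = keys.shape[2]
    n_kept = int(seq_len * (1 - compression_ratio))
    # indices of the n_kept highest-scoring positions, per batch/head
    idx = np.argsort(scores, axis=-1)[..., -n_kept:]
    idx = np.sort(idx, axis=-1)  # preserve the original token ordering
    k = np.take_along_axis(keys, idx[..., None], axis=2)
    v = np.take_along_axis(values, idx[..., None], axis=2)
    return k, v

keys = np.random.rand(1, 2, 10, 4)
values = np.random.rand(1, 2, 10, 4)
scores = np.random.rand(1, 2, 10)
k, v = prune_by_score(keys, values, scores, 0.3)
assert k.shape == (1, 2, 7, 4)
```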
The notebook then implements an example press, `MyKnormPress(ScorerPress)`, whose `score` method scores each key-value pair from the norm of its key.
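A minimal framework-free sketch of such a key-norm scorer (numpy toy; the real `ScorerPress.score` signature takes additional hook arguments and differs from this):

```python
import numpy as np

class MyKnormPress:
    """Toy stand-in for a ScorerPress subclass implementing the KnormPress heuristic."""

    def score(self, keys):
        # keys: (batch, heads, seq_len, head_dim)
        # Heuristic: keys with a small L2 norm tend to receive high attention,
        # so we negate the norm to give low-norm keys a high importance score.
        return -np.linalg.norm(keys, axis=-1)

press = MyKnormPress()
scores = press.score(np.random.rand(1, 2, 5, 4))
assert scores.shape == (1, 2, 5)  # one score per key-value pair
```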
### 2.2 Updating the `compress` method
The `compress` method defined in `BasePress` contains the core logic of the compression and returns the compressed keys and values. For instance, in `ScorerPress`, `compress` calls the `score` method (which is specific to `ScorerPress`) and prunes the key-value pairs with the lowest scores.

The following example shows how it works: we re-implement the `StreamingLLMPress` in a more compact way.
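As a framework-free illustration of the idea (numpy toy, not the actual kvpress implementation), keeping only the first and last positions of a cache could look like:

```python
import numpy as np

def streaming_llm_compress(keys, values, n_first=4, n_last=16):
    """Keep only the first n_first and last n_last KV pairs along the sequence axis."""
    # keys, values: (batch, heads, seq_len, head_dim)
    seq_len = keys.shape[2]
    if seq_len <= n_first + n_last:
        return keys, values  # nothing to prune
    idx = np.concatenate([np.arange(n_first), np.arange(seq_len - n_last, seq_len)])
    return keys[:, :, idx], values[:, :, idx]

k = np.zeros((1, 2, 100, 8))
v = np.zeros((1, 2, 100, 8))
k2, v2 = streaming_llm_compress(k, v)
assert k2.shape == (1, 2, 20, 8)  # 4 first + 16 last tokens kept
```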
Note that the `compress` method is itself used in the `forward_hook` method, which ensures quantization is handled properly and that the compression is only performed during pre-filling. While we don't recommend changing the `forward_hook` method directly, you can still modify it if you need to!
All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to:

- register it in the `__init__.py` file of the repository
- add a test in [test_presses.py](tests/presses/test_presses.py)