Skip to content

Commit c1a298d

Browse files
authored
add remote disk support (#2)
* add remote disk support * add simple readme * Improve disk loading speed by using naive open & read/write * change comment
1 parent e6975c1 commit c1a298d

File tree

5 files changed

+326
-9
lines changed

5 files changed

+326
-9
lines changed

README.md

+6
Original file line numberDiff line numberDiff line change
@@ -1 +1,7 @@
11
# lmcache-server
2+
## Start lmcache-server
3+
```
4+
python3 -m lmcache_server.server localhost <port> <storage>
5+
<port>: an arbitrary port
6+
<storage>: "cpu" for in-memory storage; any other value is treated as a disk path (e.g., remote_disk/)
7+
```

lmcache_server/server.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
import torch
55
from io import BytesIO
66
from lmcache.protocol import ClientMetaMessage, ServerMetaMessage, Constants
7+
from lmcache_server.storage_backend import CreateStorageBackend
78

89
class LMCacheServer:
9-
def __init__(self, host, port):
10+
def __init__(self, host, port, device):
1011
self.host = host
1112
self.port = port
12-
self.data_store = {}
13+
#self.data_store = {}
14+
self.data_store = CreateStorageBackend(device)
1315
self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1416
self.server_socket.bind((host, port))
1517
self.server_socket.listen()
@@ -36,15 +38,17 @@ def handle_client(self, client_socket):
3638
t0 = time.perf_counter()
3739
s = self.receive_all(client_socket, meta.length)
3840
t1 = time.perf_counter()
39-
self.data_store[meta.key] = s
41+
#self.data_store[meta.key] = s
42+
self.data_store.put(meta.key, s)
4043
t2 = time.perf_counter()
4144
#client_socket.sendall(ServerMetaMessage(Constants.SERVER_SUCCESS, 0).serialize())
4245
#t3 = time.perf_counter()
4346
print(f"Time to receive data: {t1 - t0}, time to store data: {t2 - t1}")
4447

4548
case Constants.CLIENT_GET:
4649
t0 = time.perf_counter()
47-
data_string = self.data_store.get(meta.key, None)
50+
#data_string = self.data_store.get(meta.key, None)
51+
data_string = self.data_store.get(meta.key)
4852
t1 = time.perf_counter()
4953
if data_string is not None:
5054
client_socket.sendall(ServerMetaMessage(Constants.SERVER_SUCCESS, len(data_string)).serialize())
@@ -56,11 +60,12 @@ def handle_client(self, client_socket):
5660
client_socket.sendall(ServerMetaMessage(Constants.SERVER_FAIL, 0).serialize())
5761

5862
case Constants.CLIENT_EXIST:
59-
code = Constants.SERVER_SUCCESS if meta.key in self.data_store else Constants.SERVER_FAIL
63+
#code = Constants.SERVER_SUCCESS if meta.key in self.data_store else Constants.SERVER_FAIL
64+
code = Constants.SERVER_SUCCESS if meta.key in self.data_store.list_keys() else Constants.SERVER_FAIL
6065
client_socket.sendall(ServerMetaMessage(code, 0).serialize())
6166

6267
case Constants.CLIENT_LIST:
63-
keys = list(self.data_store.keys())
68+
keys = list(self.data_store.list_keys())
6469
data = "\n".join(keys).encode()
6570
client_socket.sendall(ServerMetaMessage(Constants.SERVER_SUCCESS, len(data)).serialize())
6671
client_socket.sendall(data)
@@ -80,13 +85,17 @@ def run(self):
8085

8186
if __name__ == "__main__":
    import sys

    # Usage: <host> <port> [<storage>] — storage defaults to "cpu".
    if len(sys.argv) not in [3, 4]:
        print(f"Usage: {sys.argv[0]} <host> <port> <storage>(default:cpu)")
        sys.exit(1)  # exit() relies on the site module; sys.exit is always available

    host = sys.argv[1]
    port = int(sys.argv[2])
    # Optional third argument selects the storage backend (see CreateStorageBackend).
    device = sys.argv[3] if len(sys.argv) == 4 else "cpu"

    server = LMCacheServer(host, port, device)
    server.run()
92101

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from lmcache_server.storage_backend.abstract_backend import LMSBackendInterface
2+
from lmcache_server.storage_backend.local_backend import LMSLocalBackend, LMSLocalDiskBackend
3+
from lmcache.logging import init_logger
4+
5+
logger = init_logger(__name__)
6+
7+
8+
def CreateStorageBackend(
9+
device: str
10+
) -> LMSBackendInterface:
11+
match device:
12+
case "cpu":
13+
# cpu only
14+
logger.info("Initializing cpu-only cache server")
15+
return LMSLocalBackend()
16+
17+
case _:
18+
# cpu only
19+
logger.info("Initializing disk-only cache server")
20+
return LMSLocalDiskBackend(path=device)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import abc
2+
import torch
3+
from lmcache.logging import init_logger
4+
from typing import Tuple, Optional, Iterator, List
5+
6+
logger = init_logger(__name__)
7+
8+
class LMSBackendInterface(metaclass=abc.ABCMeta):
    """
    Abstract interface that every storage backend of the cache server
    implements (in-memory, disk, etc.).
    """

    @abc.abstractmethod
    def put(
            self,
            key: str,
            kv_chunk_bytes: bytes,
            blocking = True,
        ) -> None:
        """
        Store the KV cache of the tokens into the cache server.

        Input:
            key: the key of the token chunk, in the format of str
            kv_chunk_bytes: the serialized kv cache (bytes) of the token chunk
            blocking: whether to block the call before the operation is completed

        Returns:
            None

        Note:
            The KV cache should NOT have the "batch" dimension.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def contains(
            self,
            key: str,
        ) -> bool:
        """
        Query whether a key is present in the cache.

        Input:
            key: the key of the token chunk

        Returns:
            True if the key is in the cache, False otherwise
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get(
            self,
            key: str,
        ) -> Optional[torch.Tensor]:
        """
        Retrieve the KV cache chunk by the given key.

        Input:
            key: the key of the token chunk, including prefix hash and format

        Output:
            the kv cache of the token chunk, in the format of a big tensor
            None if the key is not found
        """
        raise NotImplementedError

    @abc.abstractmethod
    def list_keys(
            self,
        ) -> List[str]:
        """
        List all keys currently stored in the backend.

        Output:
            a list of all stored keys
            (the original docstring here was a copy-paste of get()'s)
        """
        raise NotImplementedError


    def close(self):
        """
        Do the cleanup things
        Children classes should override this method if necessary
        """
        pass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
from typing import Tuple, Optional, Iterator, List
2+
from safetensors import safe_open
3+
from safetensors.torch import save_file
4+
import re
5+
import io
6+
import torch
7+
import redis
8+
import os
9+
import pickle
10+
11+
from lmcache_server.storage_backend.abstract_backend import LMSBackendInterface
12+
from lmcache.logging import init_logger
13+
from lmcache.utils import _lmcache_nvtx_annotate
14+
15+
logger = init_logger(__name__)
16+
17+
class LMSLocalBackend(LMSBackendInterface):
    """
    Cache engine for storing the KV cache of the tokens in the local cpu/gpu memory.
    """
    def __init__(
            self,
        ):
        """
        Initialize an empty in-memory key -> bytes store.
        """
        super().__init__()

        # key -> serialized KV chunk bytes
        self.dict = {}

    def list_keys(
            self
        ) -> List[str]:
        """
        List all keys currently stored in the backend.
        """
        return list(self.dict.keys())

    def contains(
            self,
            key: str,
        ) -> bool:
        """
        Check if the cache engine contains the key.

        Input:
            key: the key of the token chunk, including prefix hash and format

        Returns:
            True if the cache engine contains the key, False otherwise
        """
        return key in self.dict

    def put(
            self,
            key: str,
            kv_chunk_bytes: bytes,
            blocking: bool = True,
        ) -> None:
        """
        Store the KV cache of the tokens into the cache engine.

        Input:
            key: the key of the token chunk, including prefix hash and format
            kv_chunk_bytes: the serialized kv cache of the token chunk
            blocking: ignored here — this backend always stores synchronously

        Returns:
            None

        Note:
            The KV cache should NOT have the "batch" dimension.
        """
        if not blocking:
            # logger.warn is deprecated; warning is the supported spelling.
            logger.warning("Non-blocking is not implemented for local backend")
        self.dict[key] = kv_chunk_bytes


    @_lmcache_nvtx_annotate
    def get(
            self,
            key: str,
        ) -> Optional[bytes]:
        """
        Retrieve the KV cache chunk by the given key.

        Input:
            key: the key of the token chunk, including prefix hash and format
        Output:
            the serialized kv cache of the token chunk
            None if the key is not found
        """
        return self.dict.get(key, None)
92+
93+
94+
# TODO(Jiayi): need to optimize disk loading
# current impl. with "naive open read/write" might not be efficient (better than torch.load)
class LMSLocalDiskBackend(LMSBackendInterface):
    """
    Cache engine for storing the KV cache of the tokens in the local disk.
    """
    def __init__(
            self,
            path: str,
        ):
        """
        Input:
            path: directory in which cache files are stored; created if missing.
        """
        super().__init__()

        self.path = path
        # exist_ok avoids the check-then-create race of exists()+makedirs().
        os.makedirs(self.path, exist_ok=True)
        # Keys whose chunks have been written to disk by this process.
        self.filenames = set()

    def list_keys(
            self
        ) -> List[str]:
        """
        List all keys currently stored in the backend.
        """
        return list(self.filenames)

    def contains(
            self,
            key: str,
        ) -> bool:
        """
        Check if the cache engine contains the key.

        Input:
            key: the key of the token chunk, including prefix hash and format

        Returns:
            True if the cache engine contains the key, False otherwise
        """
        return key in self.filenames

    def _key_to_path(
            self,
            key: str,
        ) -> str:
        """
        Convert key to path_name

        Input:
            key: the key of the token chunk, including prefix hash and format

        Returns:
            returns the path name
        """
        # os.path.join handles a missing trailing separator on self.path;
        # plain concatenation silently wrote files outside the directory.
        return os.path.join(self.path, key.replace("/", "-") + ".bin")


    def put(
            self,
            key: str,
            kv_chunk_bytes: bytes,
            blocking: bool = True,
        ) -> None:
        """
        Store the KV cache of the tokens into the cache engine.

        Input:
            key: the key of the token chunk, including prefix hash and format
            kv_chunk_bytes: the serialized kv cache of the token chunk
            blocking: ignored here — this backend always writes synchronously

        Returns:
            None

        Note:
            The KV cache should NOT have the "batch" dimension.
        """
        if not blocking:
            # logger.warn is deprecated; warning is the supported spelling.
            logger.warning("Non-blocking is not implemented for local backend")
        self.filenames.add(key)
        logger.info(f"Saving cache to {self._key_to_path(key)}")
        # Naive open/write is faster here than torch.save (see class TODO).
        with open(self._key_to_path(key), "wb") as binary_file:
            binary_file.write(kv_chunk_bytes)


    @_lmcache_nvtx_annotate
    def get(
            self,
            key: str,
        ) -> Optional[bytes]:
        """
        Retrieve the KV cache chunk by the given key.

        Input:
            key: the key of the token chunk, including prefix hash and format
        Output:
            the serialized kv cache of the token chunk
            None if the key is not found
        """
        if key not in self.filenames:
            return None

        with open(self._key_to_path(key), "rb") as binary_file:
            return binary_file.read()

0 commit comments

Comments
 (0)