Skip to content

Commit cd44d63

Browse files
authoredMar 17, 2025
feat: expose Python binding for KVEventPublisher. Use event pub/sub trait for KV events (ai-dynamo#169)
1 parent dcd83c0 commit cd44d63

File tree

17 files changed

+471
-275
lines changed

17 files changed

+471
-275
lines changed
 

‎lib/bindings/c/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ fn dynamo_create_kv_publisher(
148148
{
149149
Ok(drt) => {
150150
let backend = drt.namespace(namespace)?.component(component)?;
151-
KvEventPublisher::new(drt.clone(), backend, worker_id, kv_block_size)
151+
KvEventPublisher::new(backend, worker_id, kv_block_size)
152152
}
153153
Err(e) => Err(e),
154154
}

‎lib/bindings/python/rust/lib.rs

+1-4
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
7777
m.add_class::<llm::kv::EndpointKvMetrics>()?;
7878
m.add_class::<llm::kv::AggregatedMetrics>()?;
7979
m.add_class::<llm::kv::KvMetricsAggregator>()?;
80+
m.add_class::<llm::kv::KvEventPublisher>()?;
8081
m.add_class::<http::HttpService>()?;
8182
m.add_class::<http::HttpError>()?;
8283
m.add_class::<http::HttpAsyncEngine>()?;
@@ -227,10 +228,6 @@ impl Component {
227228
Ok(())
228229
})
229230
}
230-
231-
fn event_subject(&self, name: String) -> String {
232-
self.inner.event_subject(name)
233-
}
234231
}
235232

236233
#[pymethods]

‎lib/bindings/python/rust/llm/kv.rs

+114-7
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ use std::collections::HashMap;
1717

1818
use super::*;
1919
use llm_rs::kv_router::indexer::KvIndexerInterface;
20+
use rs::traits::events::EventSubscriber;
2021
use tracing;
2122

23+
use llm_rs::kv_router::{indexer::compute_block_hash_for_seq, protocols::*};
24+
2225
#[pyclass]
2326
pub(crate) struct KvRouter {
2427
inner: Arc<llm_rs::kv_router::KvRouter>,
@@ -119,6 +122,114 @@ impl KvMetricsPublisher {
119122
}
120123
}
121124

125+
#[pyclass]
126+
pub(crate) struct KvEventPublisher {
127+
inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
128+
warning_count: u32,
129+
}
130+
131+
#[pymethods]
132+
impl KvEventPublisher {
133+
#[new]
134+
fn new(component: Component, worker_id: i64, kv_block_size: usize) -> PyResult<Self> {
135+
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
136+
component.inner.clone(),
137+
worker_id,
138+
kv_block_size,
139+
)
140+
.map_err(to_pyerr)?;
141+
Ok(Self {
142+
inner: inner.into(),
143+
warning_count: 0,
144+
})
145+
}
146+
147+
#[allow(clippy::too_many_arguments)]
148+
#[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None))]
149+
fn publish_stored(
150+
&mut self,
151+
_py: Python,
152+
event_id: u64,
153+
token_ids: Vec<u32>,
154+
num_block_tokens: Vec<u64>,
155+
block_hashes: Vec<u64>,
156+
lora_id: u64,
157+
parent_hash: Option<u64>,
158+
) -> PyResult<()> {
159+
let event = KvCacheEvent {
160+
event_id,
161+
data: KvCacheEventData::Stored(KvCacheStoreData {
162+
parent_hash: parent_hash.map(ExternalSequenceBlockHash),
163+
blocks: self.create_stored_blocks(
164+
&token_ids,
165+
&num_block_tokens,
166+
&block_hashes,
167+
lora_id,
168+
),
169+
}),
170+
};
171+
172+
self.inner.publish(event).map_err(to_pyerr)
173+
}
174+
175+
fn publish_removed(&self, _py: Python, event_id: u64, block_hashes: Vec<u64>) -> PyResult<()> {
176+
let block_hashes: Vec<ExternalSequenceBlockHash> = block_hashes
177+
.iter()
178+
.map(|&v| ExternalSequenceBlockHash(v))
179+
.collect();
180+
let event = KvCacheEvent {
181+
event_id,
182+
data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
183+
};
184+
185+
self.inner.publish(event).map_err(to_pyerr)
186+
}
187+
}
188+
189+
impl KvEventPublisher {
190+
fn create_stored_block_from_parts(
191+
&self,
192+
block_hash: u64,
193+
token_ids: &[u32],
194+
_lora_id: u64,
195+
) -> KvCacheStoredBlockData {
196+
let tokens_hash = compute_block_hash_for_seq(token_ids, self.inner.kv_block_size())[0];
197+
KvCacheStoredBlockData {
198+
block_hash: ExternalSequenceBlockHash(block_hash),
199+
tokens_hash,
200+
}
201+
}
202+
203+
fn create_stored_blocks(
204+
&mut self,
205+
token_ids: &[u32],
206+
num_block_tokens: &[u64],
207+
block_hashes: &[u64],
208+
lora_id: u64,
209+
) -> Vec<KvCacheStoredBlockData> {
210+
let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();
211+
212+
let mut token_offset: usize = 0;
213+
for (num_tokens_it, block_hash_it) in num_block_tokens.iter().zip(block_hashes.iter()) {
214+
if (self.warning_count < 3) && (*num_tokens_it != self.inner.kv_block_size() as u64) {
215+
tracing::warn!(
216+
"Block not published. Block size must be {} tokens to be published. Block size is: {}",
217+
self.inner.kv_block_size(),
218+
*num_tokens_it
219+
);
220+
self.warning_count += 1;
221+
break;
222+
}
223+
224+
let tokens = &token_ids[token_offset..(token_offset + *num_tokens_it as usize)];
225+
blocks.push(self.create_stored_block_from_parts(*block_hash_it, tokens, lora_id));
226+
token_offset += *num_tokens_it as usize;
227+
}
228+
229+
blocks
230+
}
231+
}
232+
122233
#[pyclass]
123234
#[derive(Clone)]
124235
pub(crate) struct OverlapScores {
@@ -149,21 +260,17 @@ impl KvIndexer {
149260
fn new(component: Component, kv_block_size: usize) -> PyResult<Self> {
150261
let runtime = pyo3_async_runtimes::tokio::get_runtime();
151262
runtime.block_on(async {
152-
let kv_subject = component
153-
.inner
154-
.event_subject(llm_rs::kv_router::KV_EVENT_SUBJECT);
155263
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
156264
llm_rs::kv_router::indexer::KvIndexer::new(
157265
component.inner.drt().runtime().child_token(),
158266
kv_block_size,
159267
)
160268
.into();
269+
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
270+
// error checking below will be different.
161271
let mut kv_events_rx = component
162272
.inner
163-
.drt()
164-
.nats_client()
165-
.client()
166-
.subscribe(kv_subject)
273+
.subscribe(llm_rs::kv_router::KV_EVENT_SUBJECT)
167274
.await
168275
.map_err(to_pyerr)?;
169276
let kv_events_tx = inner.event_sender();

‎lib/bindings/python/src/dynamo/_core.pyi

+24-5
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,6 @@ class Component:
102102
"""
103103
...
104104

105-
def event_subject(self, name: str) -> str:
106-
"""
107-
Create an event subject
108-
"""
109-
...
110105

111106
class Endpoint:
112107
"""
@@ -354,6 +349,30 @@ class KvMetricsAggregator:
354349
"""
355350
...
356351

352+
class KvEventPublisher:
353+
"""
354+
A KV event publisher will publish KV events corresponding to the component.
355+
"""
356+
357+
...
358+
359+
def __init__(self, component: Component, worker_id: int, kv_block_size: int) -> None:
360+
"""
361+
Create a `KvEventPublisher` object
362+
"""
363+
364+
def publish_stored(self, event_id, int, token_ids: List[int], num_block_tokens: List[int], block_hashes: List[int], lora_id: int, parent_hash: Optional[int] = None) -> None:
365+
"""
366+
Publish a KV stored event.
367+
"""
368+
...
369+
370+
def publish_removed(self, event_id, int, block_hashes: List[int]) -> None:
371+
"""
372+
Publish a KV removed event.
373+
"""
374+
...
375+
357376
class HttpService:
358377
"""
359378
A HTTP service for dynamo applications.

‎lib/bindings/python/src/dynamo/llm/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from dynamo._core import HttpAsyncEngine as HttpAsyncEngine
1919
from dynamo._core import HttpError as HttpError
2020
from dynamo._core import HttpService as HttpService
21+
from dynamo._core import KvEventPublisher as KvEventPublisher
2122
from dynamo._core import KvIndexer as KvIndexer
2223
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
2324
from dynamo._core import KvMetricsPublisher as KvMetricsPublisher

‎lib/bindings/python/src/dynamo/runtime/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
# import * causes "unable to detect undefined names"
2525
from dynamo._core import Backend as Backend
2626
from dynamo._core import Client as Client
27+
from dynamo._core import Component as Component
2728
from dynamo._core import DistributedRuntime as DistributedRuntime
28-
from dynamo._core import KvRouter as KvRouter
2929
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
3030
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor
3131

‎lib/bindings/python/tests/test_kv_bindings.py

+44-7
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,13 @@
2424

2525
import pytest
2626

27-
from dynamo.llm import KvIndexer, KvMetricsAggregator, KvMetricsPublisher
28-
from dynamo.runtime import DistributedRuntime
27+
from dynamo.llm import (
28+
KvEventPublisher,
29+
KvIndexer,
30+
KvMetricsAggregator,
31+
KvMetricsPublisher,
32+
)
33+
from dynamo.runtime import Component, DistributedRuntime
2934

3035
pytestmark = pytest.mark.pre_merge
3136

@@ -64,14 +69,14 @@ async def test_event_handler(distributed_runtime):
6469
kv_block_size = 32
6570
namespace = "kv_test"
6671
component = "event"
72+
kv_listener = distributed_runtime.namespace(namespace).component(component)
73+
await kv_listener.create_service()
6774

6875
# publisher
6976
worker_id = 233
70-
event_publisher = EventPublisher(namespace, component, worker_id, kv_block_size)
77+
event_publisher = EventPublisher(kv_listener, worker_id, kv_block_size)
7178

7279
# indexer
73-
kv_listener = distributed_runtime.namespace(namespace).component(component)
74-
await kv_listener.create_service()
7580
indexer = KvIndexer(kv_listener, kv_block_size)
7681

7782
test_token = [3] * kv_block_size
@@ -93,16 +98,48 @@ async def test_event_handler(distributed_runtime):
9398
scores = await indexer.find_matches_for_request(test_token, lora_id)
9499
assert not scores.scores
95100

96-
event_publisher.shutdown()
101+
102+
class EventPublisher:
103+
def __init__(self, component: Component, worker_id: int, kv_block_size: int):
104+
self.publisher = KvEventPublisher(component, worker_id, kv_block_size)
105+
self.event_id_counter = 0
106+
self.block_hashes: List[int] = []
107+
108+
def store_event(self, tokens, lora_id):
109+
parent_hash = self.event_id_counter if self.event_id_counter > 0 else None
110+
self.publisher.publish_stored(
111+
self.event_id_counter, # event_id
112+
tokens, # token_ids
113+
[
114+
len(tokens),
115+
], # num_block_tokens
116+
[
117+
self.event_id_counter,
118+
], # block_hashes
119+
lora_id, # lora_id
120+
parent_hash, # parent_hash
121+
)
122+
self.block_hashes.append(self.event_id_counter)
123+
self.event_id_counter += 1
124+
125+
def remove_event(self):
126+
self.publisher.publish_removed(
127+
self.event_id_counter, # event_id
128+
[
129+
self.block_hashes[-1],
130+
], # block_hashes
131+
)
132+
self.event_id_counter += 1
97133

98134

135+
# [TODO] to be deprecated
99136
# KV events
100137
class DynamoResult:
101138
OK = 0
102139
ERR = 1
103140

104141

105-
class EventPublisher:
142+
class CtypesEventPublisher:
106143
def __init__(
107144
self, namespace: str, component: str, worker_id: int, kv_block_size: int
108145
):

‎lib/llm/src/kv_router.rs

+20-28
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,18 @@ pub mod scoring;
2929

3030
use crate::kv_router::{
3131
indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
32-
metrics_aggregator::collect_endpoints,
32+
metrics_aggregator::collect_endpoints_task,
3333
scheduler::KvScheduler,
3434
scoring::ProcessedEndpoints,
3535
};
3636

37-
// this should be discovered from the backend
37+
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
38+
39+
// [gluo TODO] shouldn't need to be public
40+
// this should be discovered from the component
3841
pub const KV_EVENT_SUBJECT: &str = "kv_events";
3942
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
43+
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
4044

4145
pub struct KvRouter {
4246
// properties of request plane
@@ -54,49 +58,37 @@ pub struct KvRouter {
5458
impl KvRouter {
5559
pub async fn from_runtime(
5660
runtime: DistributedRuntime,
57-
backend: Component,
61+
component: Component,
5862
kv_block_size: usize,
5963
) -> Result<Arc<Self>> {
60-
let nats_client = runtime.nats_client();
61-
let service_name = backend.service_name();
62-
let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
63-
let namespace = runtime.namespace(backend.namespace())?;
64-
65-
tracing::info!("Component Namespace {}", backend.namespace());
66-
tracing::info!("Component Service Name {}", service_name);
67-
tracing::info!("KV Subject {}", kv_subject);
68-
Self::new(
69-
nats_client,
70-
service_name,
71-
kv_subject,
72-
namespace,
73-
kv_block_size,
74-
)
75-
.await
64+
let namespace = runtime.namespace(component.namespace().name())?;
65+
66+
tracing::info!("Component Namespace {}", component.namespace());
67+
tracing::info!("Component Service Name {}", component.service_name());
68+
tracing::info!("KV Subject {}.{}", component.subject(), KV_EVENT_SUBJECT);
69+
Self::new(component, namespace, kv_block_size).await
7670
}
7771

7872
pub async fn new(
79-
nats_client: dynamo_runtime::transports::nats::Client,
80-
service_name: String,
81-
kv_subject: String,
73+
component: Component,
8274
namespace: Namespace,
8375
kv_block_size: usize,
8476
) -> Result<Arc<Self>> {
8577
let cancellation_token = CancellationToken::new();
8678
let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128);
8779

88-
tokio::spawn(collect_endpoints(
89-
nats_client.clone(),
90-
service_name.clone(),
80+
tokio::spawn(collect_endpoints_task(
81+
component.clone(),
9182
ep_tx,
9283
cancellation_token.clone(),
9384
));
9485

9586
let indexer = KvIndexer::new(cancellation_token.clone(), kv_block_size);
9687
let scheduler = KvScheduler::start(ep_rx, namespace, kv_block_size).await?;
9788

98-
tracing::debug!("subscribing to kv events: {}", kv_subject);
99-
let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?;
89+
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
90+
// error checking below will be different.
91+
let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
10092
let kv_events_tx = indexer.event_sender();
10193

10294
tokio::spawn(async move {
@@ -120,7 +112,7 @@ impl KvRouter {
120112
});
121113

122114
Ok(Arc::new(Self {
123-
service_name,
115+
service_name: component.service_name(),
124116
cancellation_token,
125117
scheduler,
126118
indexer,

‎lib/llm/src/kv_router/metrics_aggregator.rs

+59-38
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
use std::sync::{Arc, Mutex};
1717

1818
pub use crate::kv_router::protocols::ForwardPassMetrics;
19+
use crate::kv_router::KV_METRICS_ENDPOINT;
1920

20-
use crate::kv_router::scheduler::{Endpoint, Service};
21+
use crate::kv_router::scheduler::Endpoint;
2122
use crate::kv_router::ProcessedEndpoints;
2223
use dynamo_runtime::component::Component;
23-
use std::time::Duration;
24+
use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result};
2425
use tokio_util::sync::CancellationToken;
2526

2627
pub struct KvMetricsAggregator {
@@ -32,9 +33,8 @@ impl KvMetricsAggregator {
3233
pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self {
3334
let (ep_tx, mut ep_rx) = tokio::sync::mpsc::channel(128);
3435

35-
tokio::spawn(collect_endpoints(
36-
component.drt().nats_client().clone(),
37-
component.service_name(),
36+
tokio::spawn(collect_endpoints_task(
37+
component.clone(),
3838
ep_tx,
3939
cancellation_token.clone(),
4040
));
@@ -80,13 +80,41 @@ impl KvMetricsAggregator {
8080
}
8181
}
8282

83+
/// [gluo TODO] 'collect_endpoints' is from component/metrics,
84+
/// should consolidate these functions into generic metrics aggregator
85+
/// functions and shared by KvMetricsAggregator and component/metrics.
86+
/// Collect endpoints from a component
8387
pub async fn collect_endpoints(
84-
nats_client: dynamo_runtime::transports::nats::Client,
85-
service_name: String,
88+
component: &Component,
89+
subject: &str,
90+
timeout: Duration,
91+
) -> Result<Vec<EndpointInfo>> {
92+
// Collect stats from each backend
93+
let stream = component.scrape_stats(timeout).await?;
94+
95+
// Filter the stats by the service subject
96+
let endpoints = stream
97+
.into_endpoints()
98+
.filter(|e| e.subject.starts_with(subject))
99+
.collect::<Vec<_>>();
100+
tracing::debug!("Endpoints: {endpoints:?}");
101+
102+
if endpoints.is_empty() {
103+
tracing::warn!("No endpoints found matching subject {subject}");
104+
}
105+
106+
Ok(endpoints)
107+
}
108+
109+
pub async fn collect_endpoints_task(
110+
component: Component,
86111
ep_tx: tokio::sync::mpsc::Sender<ProcessedEndpoints>,
87112
cancel: CancellationToken,
88113
) {
89114
let backoff_delay = Duration::from_millis(100);
115+
let scrape_timeout = Duration::from_millis(300);
116+
let endpoint = component.endpoint(KV_METRICS_ENDPOINT);
117+
let service_subject = endpoint.subject();
90118

91119
loop {
92120
tokio::select! {
@@ -95,48 +123,41 @@ pub async fn collect_endpoints(
95123
break;
96124
}
97125
_ = tokio::time::sleep(backoff_delay) => {
98-
tracing::trace!("collecting endpoints for service: {}", service_name);
99-
let values = match nats_client
100-
.get_endpoints(&service_name, Duration::from_millis(300))
101-
.await
102-
{
103-
Ok(v) => v,
104-
Err(e) => {
105-
tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e);
106-
continue;
107-
}
108-
};
109-
110-
tracing::debug!("values: {:?}", values);
111-
let services: Vec<Service> = values
112-
.into_iter()
113-
.filter(|v| !v.is_empty())
114-
.filter_map(|v| match serde_json::from_slice::<Service>(&v) {
115-
Ok(service) => Some(service),
126+
tracing::trace!("collecting endpoints for service: {}", service_subject);
127+
let unfiltered_endpoints =
128+
match collect_endpoints(&component, &service_subject, scrape_timeout).await
129+
{
130+
Ok(v) => v,
116131
Err(e) => {
117-
tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e);
118-
None
132+
tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_subject, e);
133+
continue;
119134
}
120-
})
121-
.collect();
122-
tracing::debug!("services: {:?}", services);
135+
};
136+
tracing::debug!("unfiltered endpoints: {:?}", unfiltered_endpoints);
123137

124-
let endpoints: Vec<Endpoint> = services
138+
let endpoints: Vec<Endpoint> = unfiltered_endpoints
125139
.into_iter()
126-
.flat_map(|s| s.endpoints)
127140
.filter(|s| s.data.is_some())
128-
.map(|s| Endpoint {
129-
name: s.name,
130-
subject: s.subject,
131-
data: s.data.unwrap(),
132-
})
141+
.filter_map(|s|
142+
match s.data.unwrap().decode::<ForwardPassMetrics>() {
143+
Ok(data) => Some(Endpoint {
144+
name: s.name,
145+
subject: s.subject,
146+
data,
147+
}),
148+
Err(e) => {
149+
tracing::debug!("skip endpoint data that can't be parsed as ForwardPassMetrics: {:?}", e);
150+
None
151+
}
152+
}
153+
)
133154
.collect();
134155
tracing::debug!("endpoints: {:?}", endpoints);
135156

136157
tracing::trace!(
137158
"found {} endpoints for service: {}",
138159
endpoints.len(),
139-
service_name
160+
service_subject
140161
);
141162

142163
let processed = ProcessedEndpoints::new(endpoints);

‎lib/llm/src/kv_router/publisher.rs

+12-19
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@
1313
// See the License for the specific language governing permissions and
1414
// limitations under the License.
1515

16-
use crate::kv_router::{indexer::RouterEvent, protocols::*, KV_EVENT_SUBJECT};
16+
use crate::kv_router::{indexer::RouterEvent, protocols::*, KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT};
1717
use async_trait::async_trait;
18+
use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider};
1819
use dynamo_runtime::{
1920
component::Component,
2021
pipeline::{
2122
network::Ingress, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream,
2223
SingleIn,
2324
},
2425
protocols::annotated::Annotated,
25-
DistributedRuntime, Error, Result,
26+
Error, Result,
2627
};
2728
use futures::stream;
2829
use std::sync::Arc;
@@ -35,16 +36,11 @@ pub struct KvEventPublisher {
3536
}
3637

3738
impl KvEventPublisher {
38-
pub fn new(
39-
drt: DistributedRuntime,
40-
backend: Component,
41-
worker_id: i64,
42-
kv_block_size: usize,
43-
) -> Result<Self> {
39+
pub fn new(component: Component, worker_id: i64, kv_block_size: usize) -> Result<Self> {
4440
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
4541
let p = KvEventPublisher { tx, kv_block_size };
4642

47-
start_publish_task(drt, backend, worker_id, rx);
43+
start_publish_task(component, worker_id, rx);
4844
Ok(p)
4945
}
5046

@@ -59,21 +55,18 @@ impl KvEventPublisher {
5955
}
6056

6157
fn start_publish_task(
62-
drt: DistributedRuntime,
63-
backend: Component,
58+
component: Component,
6459
worker_id: i64,
6560
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
6661
) {
67-
let client = drt.nats_client().client().clone();
68-
let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
69-
log::info!("Publishing KV Events to subject: {}", kv_subject);
62+
let component_clone = component.clone();
63+
log::info!("Publishing KV Events to subject: {}", KV_EVENT_SUBJECT);
7064

71-
_ = drt.runtime().secondary().spawn(async move {
65+
_ = component.drt().runtime().secondary().spawn(async move {
7266
while let Some(event) = rx.recv().await {
7367
let router_event = RouterEvent::new(worker_id, event);
74-
let data = serde_json::to_string(&router_event).unwrap();
75-
client
76-
.publish(kv_subject.to_string(), data.into())
68+
component_clone
69+
.publish(KV_EVENT_SUBJECT, &router_event)
7770
.await
7871
.unwrap();
7972
}
@@ -105,7 +98,7 @@ impl KvMetricsPublisher {
10598
let handler = Ingress::for_engine(handler)?;
10699

107100
component
108-
.endpoint("load_metrics")
101+
.endpoint(KV_METRICS_ENDPOINT)
109102
.endpoint_builder()
110103
.stats_handler(move |_| {
111104
let metrics = metrics_rx.borrow_and_update().clone();

‎lib/llm/src/kv_router/scheduler.rs

+2-25
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,8 @@ pub enum KvSchedulerError {
4343
SubscriberShutdown,
4444
}
4545

46-
#[derive(Debug, Clone, Serialize, Deserialize)]
47-
pub struct FlexibleEndpoint {
48-
pub name: String,
49-
pub subject: String,
50-
pub data: Option<ForwardPassMetrics>,
51-
}
52-
46+
/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
47+
/// is cleaned (not optional)
5348
#[derive(Debug, Clone, Serialize, Deserialize)]
5449
pub struct Endpoint {
5550
pub name: String,
@@ -72,24 +67,6 @@ impl Endpoint {
7267
}
7368
}
7469

75-
#[derive(Debug, Clone, Serialize, Deserialize)]
76-
pub struct FlexibleService {
77-
pub name: String,
78-
pub id: String,
79-
pub version: String,
80-
pub started: String,
81-
pub endpoints: Vec<FlexibleEndpoint>,
82-
}
83-
84-
#[derive(Debug, Clone, Serialize, Deserialize)]
85-
pub struct Service {
86-
pub name: String,
87-
pub id: String,
88-
pub version: String,
89-
pub started: String,
90-
pub endpoints: Vec<FlexibleEndpoint>,
91-
}
92-
9370
pub struct SchedulingRequest {
9471
isl_tokens: usize,
9572
overlap: OverlapScores,

‎lib/runtime/src/component.rs

+15-16
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ use std::{collections::HashMap, sync::Arc};
6161
use validator::{Validate, ValidationError};
6262

6363
mod client;
64+
#[allow(clippy::module_inception)]
65+
mod component;
6466
mod endpoint;
6567
mod namespace;
6668
mod registry;
@@ -115,12 +117,12 @@ pub struct Component {
115117
// todo - restrict the namespace to a-z0-9-_A-Z
116118
/// Namespace
117119
#[builder(setter(into))]
118-
namespace: String,
120+
namespace: Namespace,
119121
}
120122

121123
impl std::fmt::Display for Component {
122124
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123-
write!(f, "{}.{}", self.namespace, self.name)
125+
write!(f, "{}.{}", self.namespace.name(), self.name)
124126
}
125127
}
126128

@@ -138,30 +140,21 @@ impl RuntimeProvider for Component {
138140

139141
impl Component {
140142
pub fn etcd_path(&self) -> String {
141-
format!("{}/components/{}", self.namespace, self.name)
143+
format!("{}/components/{}", self.namespace.name(), self.name)
142144
}
143145

144146
pub fn service_name(&self) -> String {
145-
Slug::from_string(format!("{}|{}", self.namespace, self.name)).to_string()
146-
}
147-
148-
// todo - move to EventPlane
149-
pub fn event_subject(&self, name: impl AsRef<str>) -> String {
150-
format!("{}.events.{}", self.service_name(), name.as_ref())
147+
Slug::from_string(format!("{}|{}", self.namespace.name(), self.name)).to_string()
151148
}
152149

153150
pub fn path(&self) -> String {
154-
format!("{}/{}", self.namespace, self.name)
151+
format!("{}/{}", self.namespace.name(), self.name)
155152
}
156153

157-
pub fn namespace(&self) -> &str {
154+
pub fn namespace(&self) -> &Namespace {
158155
&self.namespace
159156
}
160157

161-
pub fn drt(&self) -> &DistributedRuntime {
162-
&self.drt
163-
}
164-
165158
pub fn endpoint(&self, endpoint: impl Into<String>) -> Endpoint {
166159
Endpoint {
167160
component: self.clone(),
@@ -300,6 +293,12 @@ impl RuntimeProvider for Namespace {
300293
}
301294
}
302295

296+
impl std::fmt::Display for Namespace {
297+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298+
write!(f, "{}", self.name)
299+
}
300+
}
301+
303302
impl Namespace {
304303
pub(crate) fn new(runtime: DistributedRuntime, name: String) -> Result<Self> {
305304
Ok(NamespaceBuilder::default()
@@ -312,7 +311,7 @@ impl Namespace {
312311
pub fn component(&self, name: impl Into<String>) -> Result<Component> {
313312
Ok(ComponentBuilder::from_runtime(self.runtime.clone())
314313
.name(name)
315-
.namespace(self.name.clone())
314+
.namespace(self.clone())
316315
.build()?)
317316
}
318317

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
use anyhow::Context;
17+
use async_trait::async_trait;
18+
use futures::stream::StreamExt;
19+
use futures::{Stream, TryStreamExt};
20+
21+
use super::*;
22+
23+
use crate::traits::events::{EventPublisher, EventSubscriber};
24+
25+
#[async_trait]
26+
impl EventPublisher for Component {
27+
fn subject(&self) -> String {
28+
format!("namespace.{}.component.{}", self.namespace.name, self.name)
29+
}
30+
31+
async fn publish(
32+
&self,
33+
event_name: impl AsRef<str> + Send + Sync,
34+
event: &(impl Serialize + Send + Sync),
35+
) -> Result<()> {
36+
let bytes = serde_json::to_vec(event)?;
37+
self.publish_bytes(event_name, bytes).await
38+
}
39+
40+
async fn publish_bytes(
41+
&self,
42+
event_name: impl AsRef<str> + Send + Sync,
43+
bytes: Vec<u8>,
44+
) -> Result<()> {
45+
let subject = format!("{}.{}", self.subject(), event_name.as_ref());
46+
Ok(self
47+
.drt()
48+
.nats_client()
49+
.client()
50+
.publish(subject, bytes.into())
51+
.await?)
52+
}
53+
}
54+
55+
#[async_trait]
56+
impl EventSubscriber for Component {
57+
async fn subscribe(
58+
&self,
59+
event_name: impl AsRef<str> + Send + Sync,
60+
) -> Result<async_nats::Subscriber> {
61+
let subject = format!("{}.{}", self.subject(), event_name.as_ref());
62+
Ok(self.drt().nats_client().client().subscribe(subject).await?)
63+
}
64+
65+
async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>(
66+
&self,
67+
event_name: impl AsRef<str> + Send + Sync,
68+
) -> Result<impl Stream<Item = Result<T>> + Send> {
69+
let subscriber = self.subscribe(event_name).await?;
70+
71+
// Transform the subscriber into a stream of deserialized events
72+
let stream = subscriber.map(move |msg| {
73+
serde_json::from_slice::<T>(&msg.payload)
74+
.with_context(|| format!("Failed to deserialize event payload: {:?}", msg.payload))
75+
});
76+
77+
Ok(stream)
78+
}
79+
}
80+
81+
#[cfg(feature = "integration")]
82+
#[cfg(test)]
83+
mod tests {
84+
use super::*;
85+
86+
// todo - make a distributed runtime fixture
87+
// todo - two options - fully mocked or integration test
88+
#[tokio::test]
89+
async fn test_publish() {
90+
let rt = Runtime::from_current().unwrap();
91+
let dtr = DistributedRuntime::from_settings(rt.clone()).await.unwrap();
92+
let ns = dtr.namespace("test".to_string()).unwrap();
93+
let cp = ns.component("component".to_string()).unwrap();
94+
cp.publish("test", &"test".to_string()).await.unwrap();
95+
rt.shutdown();
96+
}
97+
98+
#[tokio::test]
99+
async fn test_subscribe() {
100+
let rt = Runtime::from_current().unwrap();
101+
let dtr = DistributedRuntime::from_settings(rt.clone()).await.unwrap();
102+
let ns = dtr.namespace("test".to_string()).unwrap();
103+
let cp = ns.component("component".to_string()).unwrap();
104+
105+
// Create a subscriber
106+
let mut subscriber = ns.subscribe("test").await.unwrap();
107+
108+
// Publish a message
109+
cp.publish("test", &"test_message".to_string())
110+
.await
111+
.unwrap();
112+
113+
// Receive the message
114+
if let Some(msg) = subscriber.next().await {
115+
let received = String::from_utf8(msg.payload.to_vec()).unwrap();
116+
assert_eq!(received, "\"test_message\"");
117+
}
118+
119+
rt.shutdown();
120+
}
121+
}

‎lib/runtime/src/component/endpoint.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ impl EndpointConfigBuilder {
113113
let info = ComponentEndpointInfo {
114114
component: endpoint.component.name.clone(),
115115
endpoint: endpoint.name.clone(),
116-
namespace: endpoint.component.namespace.clone(),
116+
namespace: endpoint.component.namespace.name.clone(),
117117
lease_id: lease.id(),
118118
transport: TransportType::NatsTcp(endpoint.subject_to(lease.id())),
119119
};

‎lib/runtime/src/component/namespace.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ mod tests {
101101
let ns = dtr.namespace("test".to_string()).unwrap();
102102

103103
// Create a subscriber
104-
let subscriber = ns.subscribe("test").await.unwrap();
104+
let mut subscriber = ns.subscribe("test").await.unwrap();
105105

106106
// Publish a message
107107
ns.publish("test", &"test_message".to_string())

‎lib/runtime/src/service.rs

+54-6
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,28 @@ impl EndpointInfo {
7373
i64::from_str_radix(id, 16).map_err(|e| error!("Invalid id format: {}", e))
7474
}
7575
}
76+
77+
// TODO: This is _really_ close to the async_nats::service::Stats object,
78+
// but it's missing a few fields like "name", so use a temporary struct
79+
// for easy deserialization. Ideally, this type already exists or can
80+
// be exposed in the library somewhere.
81+
/// Stats structure returned from NATS service API
7682
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
77-
pub struct Metrics(pub serde_json::Value);
83+
pub struct Metrics {
84+
// Standard NATS Service API fields
85+
pub average_processing_time: f64,
86+
pub last_error: String,
87+
pub num_errors: u64,
88+
pub num_requests: u64,
89+
pub processing_time: u64,
90+
pub queue_group: String,
91+
// Field containing custom stats handler data
92+
pub data: serde_json::Value,
93+
}
7894

7995
impl Metrics {
8096
pub fn decode<T: for<'de> Deserialize<'de>>(self) -> Result<T> {
81-
serde_json::from_value(self.0).map_err(Into::into)
97+
serde_json::from_value(self.data).map_err(Into::into)
8298
}
8399
}
84100

@@ -153,12 +169,28 @@ mod tests {
153169
EndpointInfo {
154170
name: "endpoint1".to_string(),
155171
subject: "subject1".to_string(),
156-
data: Some(Metrics(serde_json::json!({"key": "value1"}))),
172+
data: Some(Metrics {
173+
average_processing_time: 0.1,
174+
last_error: "none".to_string(),
175+
num_errors: 0,
176+
num_requests: 10,
177+
processing_time: 100,
178+
queue_group: "group1".to_string(),
179+
data: serde_json::json!({"key": "value1"}),
180+
}),
157181
},
158182
EndpointInfo {
159183
name: "endpoint2-foo".to_string(),
160184
subject: "subject2".to_string(),
161-
data: Some(Metrics(serde_json::json!({"key": "value1"}))),
185+
data: Some(Metrics {
186+
average_processing_time: 0.1,
187+
last_error: "none".to_string(),
188+
num_errors: 0,
189+
num_requests: 10,
190+
processing_time: 100,
191+
queue_group: "group1".to_string(),
192+
data: serde_json::json!({"key": "value1"}),
193+
}),
162194
},
163195
],
164196
},
@@ -171,12 +203,28 @@ mod tests {
171203
EndpointInfo {
172204
name: "endpoint1".to_string(),
173205
subject: "subject1".to_string(),
174-
data: Some(Metrics(serde_json::json!({"key": "value1"}))),
206+
data: Some(Metrics {
207+
average_processing_time: 0.1,
208+
last_error: "none".to_string(),
209+
num_errors: 0,
210+
num_requests: 10,
211+
processing_time: 100,
212+
queue_group: "group1".to_string(),
213+
data: serde_json::json!({"key": "value1"}),
214+
}),
175215
},
176216
EndpointInfo {
177217
name: "endpoint2-bar".to_string(),
178218
subject: "subject2".to_string(),
179-
data: Some(Metrics(serde_json::json!({"key": "value2"}))),
219+
data: Some(Metrics {
220+
average_processing_time: 0.1,
221+
last_error: "none".to_string(),
222+
num_errors: 0,
223+
num_requests: 10,
224+
processing_time: 100,
225+
queue_group: "group1".to_string(),
226+
data: serde_json::json!({"key": "value2"}),
227+
}),
180228
},
181229
],
182230
},

‎lib/runtime/src/transports/nats.rs

-116
Original file line numberDiff line numberDiff line change
@@ -108,35 +108,6 @@ impl Client {
108108
Ok(subscription)
109109
}
110110

111-
// todo - deprecate - move to service subscriber
112-
pub async fn get_endpoints(
113-
&self,
114-
service_name: &str,
115-
timeout: time::Duration,
116-
) -> Result<Vec<Bytes>, anyhow::Error> {
117-
let subject = format!("$SRV.STATS.{}", service_name);
118-
let reply_subject = format!("_INBOX.{}", nuid::next());
119-
let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
120-
121-
let deadline = tokio::time::Instant::now() + timeout;
122-
123-
// Publish the request with the reply-to subject
124-
self.client
125-
.publish_with_reply(subject, reply_subject, "".into())
126-
.await?;
127-
128-
// Set a timeout to gather responses
129-
let mut responses = Vec::new();
130-
// let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
131-
132-
while let Ok(Some(message)) = time::timeout_at(deadline, subscription.next()).await {
133-
// log::debug!("get endpoint received message before timeout: {:?}", message);
134-
responses.push(message.payload);
135-
}
136-
137-
Ok(responses)
138-
}
139-
140111
// /// create a new stream
141112
// async fn get_or_create_work_queue_stream(
142113
// &self,
@@ -272,35 +243,6 @@ impl Client {
272243
// Ok(())
273244
// }
274245

275-
// pub async fn get_endpoints(
276-
// &self,
277-
// service_name: &str,
278-
// timeout: Duration,
279-
// ) -> Result<Vec<Bytes>, anyhow::Error> {
280-
// let subject = format!("$SRV.STATS.{}", service_name);
281-
// let reply_subject = format!("_INBOX.{}", nuid::next());
282-
// let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
283-
284-
// // Publish the request with the reply-to subject
285-
// self.client
286-
// .publish_with_reply(subject, reply_subject, "".into())
287-
// .await?;
288-
289-
// // Set a timeout to gather responses
290-
// let mut responses = Vec::new();
291-
// // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
292-
293-
// let start = time::Instant::now();
294-
// while let Ok(Some(message)) = time::timeout(timeout, subscription.next()).await {
295-
// responses.push(message.payload);
296-
// if start.elapsed() > timeout {
297-
// break;
298-
// }
299-
// }
300-
301-
// Ok(responses)
302-
// }
303-
304246
// pub fn frontend_client(&self, request_id: String) -> SpecializedClient {
305247
// SpecializedClient::new(self.client.clone(), ClientKind::Frontend, request_id)
306248
// }
@@ -691,35 +633,6 @@ mod tests {
691633
// assert_eq!(initial_work_queue_count, work_queue_count);
692634
// }
693635

694-
// pub async fn get_endpoints(
695-
// &self,
696-
// service_name: &str,
697-
// timeout: Duration,
698-
// ) -> Result<Vec<Bytes>, anyhow::Error> {
699-
// let subject = format!("$SRV.STATS.{}", service_name);
700-
// let reply_subject = format!("_INBOX.{}", nuid::next());
701-
// let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
702-
703-
// // Publish the request with the reply-to subject
704-
// self.client
705-
// .publish_with_reply(subject, reply_subject, "".into())
706-
// .await?;
707-
708-
// // Set a timeout to gather responses
709-
// let mut responses = Vec::new();
710-
// // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
711-
712-
// let start = time::Instant::now();
713-
// while let Ok(Some(message)) = time::timeout(timeout, subscription.next()).await {
714-
// responses.push(message.payload);
715-
// if start.elapsed() > timeout {
716-
// break;
717-
// }
718-
// }
719-
720-
// Ok(responses)
721-
// }
722-
723636
// async fn connect(config: Arc<Config>) -> Result<NatsClient> {
724637
// let client = ClientOptions::builder()
725638
// .server(config.nats_address.clone())
@@ -852,35 +765,6 @@ mod tests {
852765
// pub fn service_builder(&self) -> NatsServiceBuilder {
853766
// self.client.service_builder()
854767
// }
855-
856-
// pub async fn get_endpoints(
857-
// &self,
858-
// service_name: &str,
859-
// timeout: Duration,
860-
// ) -> Result<Vec<Bytes>, anyhow::Error> {
861-
// let subject = format!("$SRV.STATS.{}", service_name);
862-
// let reply_subject = format!("_INBOX.{}", nuid::next());
863-
// let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
864-
865-
// // Publish the request with the reply-to subject
866-
// self.client
867-
// .publish_with_reply(subject, reply_subject, "".into())
868-
// .await?;
869-
870-
// // Set a timeout to gather responses
871-
// let mut responses = Vec::new();
872-
// // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
873-
874-
// let start = tokio::time::Instant::now();
875-
// while let Ok(Some(message)) = tokio::time::timeout(timeout, subscription.next()).await {
876-
// responses.push(message.payload);
877-
// if start.elapsed() > timeout {
878-
// break;
879-
// }
880-
// }
881-
882-
// Ok(responses)
883-
// }
884768
// }
885769

886770
// #[derive(Debug, Clone, Serialize, Deserialize)]

0 commit comments

Comments
 (0)
Please sign in to comment.