@@ -201,115 +201,90 @@ def _set_links(self, span: Span):
201
201
instance = _extract_bound (instance )
202
202
parent_span = _get_nearest_llmobs_ancestor (span )
203
203
204
- step_idx = self ._set_input_links (instance , span , parent_span )
204
+ prev_traced_step_idx = self ._set_input_links (instance , span , parent_span )
205
205
206
- self ._set_output_links (span , parent_span , step_idx )
206
+ self ._set_output_links (span , parent_span , prev_traced_step_idx )
207
207
208
208
def _set_input_links (self , instance : Any , span : Span , parent_span : Union [Span , None ]) -> int :
209
209
"""
210
210
Sets input links (to: input) on the given span
211
211
1. If the instance associated with the span is not a step in a chain, link from its parent span (input->input)
212
212
2. If the instance associated with the span is a step in a chain, link from the last traced step in the chain
213
213
a. This could be multiple steps, if the last step was a RunnableParallel
214
- b. In this case, it would be an output->input relationship
214
+ b. If there was no previous traced step, link from the parent span (input->input)
215
+ b. Otherwise, it would be an output->input relationship with the previously traced span(s)
215
216
"""
216
217
if parent_span is None :
217
218
return - 1
218
219
219
220
is_step = id (instance ) in self ._chain_steps
220
221
221
- # defaults
222
- invoker_spans = [parent_span ]
223
- invoker_links_attributes = [{"from" : "input" , "to" : "input" }]
224
- has_parallel_steps = False
225
- step_idx = - 1
226
-
227
- links = []
228
-
229
222
if not is_step :
230
- self ._set_span_links (
231
- span ,
232
- [
233
- {
234
- "trace_id" : "{:x}" .format (span .trace_id ),
235
- "span_id" : str (invoker_spans [0 ].span_id ),
236
- "attributes" : invoker_links_attributes [0 ],
237
- }
238
- ],
239
- )
223
+ self ._set_span_links (span , [parent_span ], "input" , "input" )
240
224
241
- return step_idx
225
+ return - 1
242
226
243
- chain_instance = _extract_bound (self ._instances .get (invoker_spans [ 0 ] ))
227
+ chain_instance = _extract_bound (self ._instances .get (parent_span ))
244
228
steps = getattr (chain_instance , "steps" , [])
245
229
flatmap_chain_steps = _flattened_chain_steps (steps )
246
- for i , step in enumerate (flatmap_chain_steps ):
247
- if id (step ) == id (instance ) or (
248
- isinstance (step , list ) and any (id (sub_step ) == id (instance ) for sub_step in step )
249
- ):
250
- step_idx = i
251
- break
252
- for i in range (step_idx - 1 , - 1 , - 1 ):
253
- step = flatmap_chain_steps [i ]
254
- if id (step ) in self ._spans :
255
- invoker_span = self ._spans [id (step )]
256
- invoker_link_attributes = {"from" : "output" , "to" : "input" }
257
- break
258
- if isinstance (step , list ): # parallel steps in the list
259
- for parallel_step in step :
260
- if id (parallel_step ) in self ._spans :
261
- if not has_parallel_steps :
262
- invoker_spans = []
263
- invoker_links_attributes = []
264
- has_parallel_steps = True
265
-
266
- invoker_spans .append (self ._spans [id (parallel_step )])
267
- invoker_links_attributes .append ({"from" : "output" , "to" : "input" })
268
- break
269
-
270
- for link_data in zip (invoker_spans , invoker_links_attributes ):
271
- invoker_span , invoker_link_attributes = link_data
272
- if invoker_span is None :
273
- continue
274
- links .append (
275
- {
276
- "trace_id" : "{:x}" .format (span .trace_id ),
277
- "span_id" : str (invoker_span .span_id ),
278
- "attributes" : invoker_link_attributes ,
279
- }
280
- )
230
+ prev_traced_step_idx = self ._find_previous_traced_step_index (instance , flatmap_chain_steps )
281
231
282
- self ._set_span_links (span , links )
232
+ if prev_traced_step_idx == - 1 :
233
+ self ._set_span_links (span , [parent_span ], "input" , "input" )
283
234
284
- return step_idx
235
+ return prev_traced_step_idx
285
236
286
- def _set_output_links (self , span : Span , parent_span : Union [Span , None ], step_idx : int ) -> None :
237
+ invoker_spans = []
238
+ prev_traced_step = flatmap_chain_steps [prev_traced_step_idx ]
239
+ if isinstance (prev_traced_step , list ):
240
+ for parallel_step in prev_traced_step :
241
+ if id (parallel_step ) in self ._spans :
242
+ invoker_spans .append (self ._spans [id (parallel_step )])
243
+ else :
244
+ invoker_spans .append (self ._spans [id (prev_traced_step )])
245
+
246
+ self ._set_span_links (span , invoker_spans , "output" , "input" )
247
+
248
+ return prev_traced_step_idx
249
+
250
+ def _find_previous_traced_step_index (self , instance , flatmap_chain_steps ):
251
+ """
252
+ Finds the index in the list of steps of the last traced step in the chain before the current instance.
253
+ """
254
+ curr_idx = 0
255
+ curr_step = flatmap_chain_steps [0 ]
256
+ prev_traced_step_idx = - 1
257
+
258
+ while (
259
+ curr_idx < len (flatmap_chain_steps )
260
+ and id (curr_step ) != id (instance )
261
+ and not (isinstance (curr_step , list ) and any (id (sub_step ) == id (instance ) for sub_step in curr_step ))
262
+ ):
263
+ if id (curr_step ) in self ._spans or (
264
+ isinstance (curr_step , list ) and any (id (sub_step ) in self ._spans for sub_step in curr_step )
265
+ ):
266
+ prev_traced_step_idx = curr_idx
267
+ curr_idx += 1
268
+ curr_step = flatmap_chain_steps [curr_idx ]
269
+
270
+ return prev_traced_step_idx
271
+
272
+ def _set_output_links (self , span : Span , parent_span : Union [Span , None ], prev_traced_step_idx : int ) -> None :
287
273
"""
288
274
Sets the output links for the parent span of the given span (to: output)
289
275
This is done by removing repeated span links from steps in a chain.
290
- We add output->output span links at every step
276
+ We add output->output span links at every step.
291
277
"""
292
278
if parent_span is None :
293
279
return
294
280
295
281
parent_links = parent_span ._get_ctx_item (SPAN_LINKS ) or []
296
- pop_indecies = self ._get_popped_span_link_indecies (parent_span , parent_links , step_idx )
297
- parent_links = [link for i , link in enumerate (parent_links ) if i not in pop_indecies ]
298
-
299
- parent_span ._set_ctx_item (
300
- SPAN_LINKS ,
301
- parent_links
302
- + [
303
- {
304
- "trace_id" : "{:x}" .format (span .trace_id ),
305
- "span_id" : str (span .span_id ),
306
- "attributes" : {"from" : "output" , "to" : "output" },
307
- }
308
- ],
309
- )
282
+ pop_indecies = self ._get_popped_span_link_indecies (parent_span , parent_links , prev_traced_step_idx )
283
+
284
+ self ._set_span_links (parent_span , [span ], "output" , "output" , popped_span_link_indecies = pop_indecies )
310
285
311
286
def _get_popped_span_link_indecies (
312
- self , parent_span : Span , parent_links : List [Dict [str , Any ]], step_idx : int
287
+ self , parent_span : Span , parent_links : List [Dict [str , Any ]], prev_traced_step_idx : int
313
288
) -> List [int ]:
314
289
"""
315
290
Returns a list of indecies to pop from the parent span links list
@@ -321,7 +296,7 @@ def _get_popped_span_link_indecies(
321
296
"""
322
297
pop_indecies : List [int ] = []
323
298
parent_instance = self ._instances .get (parent_span )
324
- if not parent_instance :
299
+ if not parent_instance or prev_traced_step_idx == - 1 :
325
300
return pop_indecies
326
301
327
302
parent_instance = _extract_bound (parent_instance )
@@ -330,33 +305,47 @@ def _get_popped_span_link_indecies(
330
305
331
306
steps = getattr (parent_instance , "steps" , [])
332
307
flatmap_chain_steps = _flattened_chain_steps (steps )
333
- for i in range (step_idx - 1 , - 1 , - 1 ):
334
- step = flatmap_chain_steps [i ]
335
- if id (step ) in self ._spans :
336
- invoker_span_id = self ._spans [id (step )].span_id
337
- link_idx = next (
338
- (i for i , link in enumerate (parent_links ) if link ["span_id" ] == str (invoker_span_id )), None
339
- )
340
- if link_idx is not None :
341
- pop_indecies .append (link_idx )
342
- break
343
- if isinstance (step , list ): # parallel steps in the list
344
- for parallel_step in step :
345
- if id (parallel_step ) in self ._spans :
346
- invoker_span_id = self ._spans [id (parallel_step )].span_id
347
- link_idx = next (
348
- (i for i , link in enumerate (parent_links ) if link ["span_id" ] == str (invoker_span_id )),
349
- None ,
350
- )
351
- if link_idx is not None :
352
- pop_indecies .append (link_idx )
353
- break
308
+ prev_traced_step = flatmap_chain_steps [prev_traced_step_idx ]
309
+
310
+ if isinstance (prev_traced_step , list ):
311
+ for parallel_step in prev_traced_step :
312
+ if id (parallel_step ) in self ._spans :
313
+ invoker_span_id = self ._spans [id (parallel_step )].span_id
314
+ link_idx = next (
315
+ (i for i , link in enumerate (parent_links ) if link ["span_id" ] == str (invoker_span_id )), None
316
+ )
317
+ if link_idx is not None :
318
+ pop_indecies .append (link_idx )
319
+ else :
320
+ invoker_span_id = self ._spans [id (prev_traced_step )].span_id
321
+ link_idx = next ((i for i , link in enumerate (parent_links ) if link ["span_id" ] == str (invoker_span_id )), None )
322
+ if link_idx is not None :
323
+ pop_indecies .append (link_idx )
354
324
355
325
return pop_indecies
356
326
357
- def _set_span_links (self , span : Span , links : List [Dict [str , Any ]]) -> None :
327
+ def _set_span_links (
328
+ self ,
329
+ span : Span ,
330
+ from_spans : List [Span ],
331
+ link_from : str ,
332
+ link_to : str ,
333
+ popped_span_link_indecies : Optional [List [int ]] = None ,
334
+ ) -> None :
358
335
"""Sets the span links on the given span along with the existing links."""
359
336
existing_links = span ._get_ctx_item (SPAN_LINKS ) or []
337
+
338
+ if popped_span_link_indecies :
339
+ existing_links = [link for i , link in enumerate (existing_links ) if i not in popped_span_link_indecies ]
340
+
341
+ links = [
342
+ {
343
+ "trace_id" : "{:x}" .format (from_span .trace_id ),
344
+ "span_id" : str (from_span .span_id ),
345
+ "attributes" : {"from" : link_from , "to" : link_to },
346
+ }
347
+ for from_span in from_spans
348
+ ]
360
349
span ._set_ctx_item (SPAN_LINKS , existing_links + links )
361
350
362
351
def _llmobs_set_metadata (self , span : Span , model_provider : Optional [str ] = None ) -> None :
0 commit comments