llmrails.py (1037 lines, 874 loc, 41.7 KB), forked from NVIDIA/NeMo-Guardrails
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLM Rails entry point."""
import asyncio
import importlib.util
import logging
import os
import re
import threading
import time
import warnings
from typing import Any, AsyncIterator, List, Optional, Tuple, Type, Union, cast
from langchain.llms.base import BaseLLM
from nemoguardrails.actions.llm.generation import LLMGenerationActions
from nemoguardrails.actions.llm.utils import get_colang_history
from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx
from nemoguardrails.colang import parse_colang_file
from nemoguardrails.colang.v1_0.runtime.flows import compute_context
from nemoguardrails.colang.v1_0.runtime.runtime import Runtime, RuntimeV1_0
from nemoguardrails.colang.v2_x.lang.utils import new_uuid
from nemoguardrails.colang.v2_x.runtime.flows import Action, State
from nemoguardrails.colang.v2_x.runtime.runtime import RuntimeV2_x
from nemoguardrails.colang.v2_x.runtime.serialization import (
json_to_state,
state_to_json,
)
from nemoguardrails.context import (
explain_info_var,
generation_options_var,
llm_stats_var,
streaming_handler_var,
)
from nemoguardrails.embeddings.index import EmbeddingsIndex
from nemoguardrails.kb.kb import KnowledgeBase
from nemoguardrails.llm.providers import get_llm_provider, get_llm_provider_names
from nemoguardrails.logging.explain import ExplainInfo
from nemoguardrails.logging.processing_log import compute_generation_log
from nemoguardrails.logging.stats import LLMStats
from nemoguardrails.logging.verbose import set_verbose
from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig
from nemoguardrails.rails.llm.options import (
GenerationLog,
GenerationOptions,
GenerationResponse,
)
from nemoguardrails.rails.llm.utils import get_history_cache_key
from nemoguardrails.streaming import StreamingHandler
from nemoguardrails.utils import get_or_create_event_loop, new_event_dict
log = logging.getLogger(__name__)
process_events_semaphore = asyncio.Semaphore(1)
class LLMRails:
"""Rails based on a given configuration."""
config: RailsConfig
llm: Optional[BaseLLM]
runtime: Runtime
def __init__(
self, config: RailsConfig, llm: Optional[BaseLLM] = None, verbose: bool = False
):
"""Initializes the LLMRails instance.
Args:
config: A rails configuration.
llm: An optional LLM engine to use.
verbose: Whether the logging should be verbose or not.
"""
self.config = config
self.llm = llm
self.verbose = verbose
if self.verbose:
set_verbose(True, llm_calls=True)
# We allow the user to register additional embedding search providers, so we keep
# an index of them.
self.embedding_search_providers = {}
# The default embeddings model uses FastEmbed.
self.default_embedding_model = "all-MiniLM-L6-v2"
self.default_embedding_engine = "FastEmbed"
# We keep a cache of the events history associated with a sequence of user messages.
# TODO: when we update the interface to allow returning a "state object", this
# should be removed
self.events_history_cache = {}
# Whether the main LLM supports streaming
self.main_llm_supports_streaming = False
# We also load the default flows from the `default_flows.yml` file in the current folder.
# But only for version 1.0.
# TODO: decide on the default flows for 2.x.
if config.colang_version == "1.0":
# We also load the default flows from the `llm_flows.co` file in the current folder.
current_folder = os.path.dirname(__file__)
default_flows_file = "llm_flows.co"
default_flows_path = os.path.join(current_folder, default_flows_file)
with open(default_flows_path, "r") as f:
default_flows_content = f.read()
default_flows = parse_colang_file(
default_flows_file, default_flows_content
)["flows"]
# We mark all the default flows as system flows.
for flow_config in default_flows:
flow_config["is_system_flow"] = True
# We add the default flows to the config.
self.config.flows.extend(default_flows)
# We also need to load the content from the components library.
library_path = os.path.join(os.path.dirname(__file__), "../../library")
for root, dirs, files in os.walk(library_path):
for file in files:
# Extract the full path for the file
full_path = os.path.join(root, file)
if file.endswith(".co"):
with open(full_path, "r", encoding="utf-8") as f:
content = parse_colang_file(
file, content=f.read(), version=config.colang_version
)
# We mark all the flows coming from the guardrails library as system flows.
for flow_config in content["flows"]:
flow_config["is_system_flow"] = True
# We load all the flows
self.config.flows.extend(content["flows"])
# And all the messages as well, if they have not been overwritten
for message_id, utterances in content.get(
"bot_messages", {}
).items():
if message_id not in self.config.bot_messages:
self.config.bot_messages[message_id] = utterances
# Last but not least, we mark all the flows that are used in any of the rails
# as system flows (so they don't end up in the prompt).
rail_flow_ids = (
config.rails.input.flows
+ config.rails.output.flows
+ config.rails.retrieval.flows
)
for flow_config in self.config.flows:
if flow_config.get("id") in rail_flow_ids:
flow_config["is_system_flow"] = True
# We also mark them as subflows by default, to simplify the syntax
flow_config["is_subflow"] = True
# We check if the configuration or any of the imported ones have config.py modules.
config_modules = []
for _path in list(self.config.imported_paths.values()) + [
self.config.config_path
]:
if _path:
filepath = os.path.join(_path, "config.py")
if os.path.exists(filepath):
filename = os.path.basename(filepath)
spec = importlib.util.spec_from_file_location(filename, filepath)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
config_modules.append(config_module)
# First, we initialize the runtime.
if config.colang_version == "1.0":
self.runtime = RuntimeV1_0(config=config, verbose=verbose)
elif config.colang_version == "2.x":
self.runtime = RuntimeV2_x(config=config, verbose=verbose)
else:
raise ValueError(f"Unsupported colang version: {config.colang_version}.")
# If we have a config_modules with an `init` function, we call it.
# We need to call this here because the `init` might register additional
# LLM providers.
for config_module in config_modules:
if hasattr(config_module, "init"):
config_module.init(self)
# If we have a customized embedding model, we'll use it.
for model in self.config.models:
if model.type == "embeddings":
self.default_embedding_model = model.model
self.default_embedding_engine = model.engine
break
# We run some additional checks on the config
self._validate_config()
# Next, we initialize the LLM engines (main engine and action engines if specified).
self._init_llms()
# Next, we initialize the LLM Generate actions and register them.
llm_generation_actions_class = (
LLMGenerationActions
if config.colang_version == "1.0"
else LLMGenerationActionsV2dotx
)
self.llm_generation_actions = llm_generation_actions_class(
config=config,
llm=self.llm,
llm_task_manager=self.runtime.llm_task_manager,
get_embedding_search_provider_instance=self._get_embeddings_search_provider_instance,
verbose=verbose,
)
# If there's already an action registered, we don't override.
self.runtime.register_actions(self.llm_generation_actions, override=False)
# Next, we initialize the Knowledge Base
# There are still some edge cases not covered by nest_asyncio.
# Using a separate thread always for now.
loop = get_or_create_event_loop()
if True or check_sync_call_from_async_loop():
t = threading.Thread(target=asyncio.run, args=(self._init_kb(),))
t.start()
t.join()
else:
loop.run_until_complete(self._init_kb())
# We also register the kb as a parameter that can be passed to actions.
self.runtime.register_action_param("kb", self.kb)
# Reference to the general ExplainInfo object.
self.explain_info = None
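# Illustrative usage sketch (not part of this module). It assumes a local
# config directory containing a `config.yml` (and optionally Colang flows);
# the path and prompt below are made up:
#
#     from nemoguardrails import LLMRails, RailsConfig
#
#     config = RailsConfig.from_path("./config")
#     rails = LLMRails(config, verbose=True)
#     reply = rails.generate(messages=[{"role": "user", "content": "Hello!"}])
#     print(reply["content"])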
def update_llm(self, llm):
"""Replace the main LLM with the provided one.
Arguments:
llm: The new LLM that should be used.
"""
self.llm = llm
self.llm_generation_actions.llm = llm
self.runtime.register_action_param("llm", llm)
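# Illustrative sketch of swapping the main LLM after construction. The
# `langchain_openai` import and model name are assumptions; any
# LangChain-compatible LLM object can be passed:
#
#     from langchain_openai import ChatOpenAI
#
#     rails.update_llm(ChatOpenAI(model="gpt-4o-mini"))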
def _validate_config(self):
"""Runs additional validation checks on the config."""
existing_flows_names = set([flow.get("id") for flow in self.config.flows])
for flow_name in self.config.rails.input.flows:
if flow_name not in existing_flows_names:
raise ValueError(
f"The provided input rail flow `{flow_name}` does not exist"
)
for flow_name in self.config.rails.output.flows:
if flow_name not in existing_flows_names:
raise ValueError(
f"The provided output rail flow `{flow_name}` does not exist"
)
for flow_name in self.config.rails.retrieval.flows:
if flow_name not in existing_flows_names:
raise ValueError(
f"The provided retrieval rail flow `{flow_name}` does not exist"
)
# If both passthrough mode and single call mode are specified, we raise an exception.
if self.config.passthrough and self.config.rails.dialog.single_call.enabled:
raise ValueError(
f"The passthrough mode and the single call dialog rails mode can't be used at the same time. "
f"The single call mode needs to use an altered prompt when prompting the LLM. "
)
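# Illustrative sketch of a configuration this check rejects: an input rail
# referencing a flow that is not defined anywhere. The YAML fragment is made
# up and a `models` section is omitted for brevity:
#
#     config = RailsConfig.from_content(yaml_content="""
#     rails:
#       input:
#         flows:
#           - some undefined flow
#     """)
#     LLMRails(config)  # ValueError: The provided input rail flow `some undefined flow` does not exist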
async def _init_kb(self):
"""Initializes the knowledge base."""
self.kb = None
if not self.config.docs:
return
documents = [doc.content for doc in self.config.docs]
self.kb = KnowledgeBase(
documents=documents,
config=self.config.knowledge_base,
get_embedding_search_provider_instance=self._get_embeddings_search_provider_instance,
)
self.kb.init()
await self.kb.build()
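# Illustrative sketch: the documents typically come from markdown files placed
# in a `kb/` folder of the config directory, which end up in `config.docs`:
#
#     config = RailsConfig.from_path("./config")  # ./config/kb/*.md
#     rails = LLMRails(config)                    # rails.kb is built from config.docs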
def _init_llms(self):
"""
Initializes the right LLM engines based on the configuration.
Multiple LLM engines and types can be specified in the config.
The main LLM engine is the one that will be used for all the core guardrails generations.
Other LLM engines can be specified for use in specific actions.
The reason we provide an option for decoupling the main LLM engine from the action LLM
is to allow for flexibility in using specialized LLM engines for specific actions.
"""
# If we already have a pre-configured one,
# we just need to register the LLM as an action param.
if self.llm is not None:
self.runtime.register_action_param("llm", self.llm)
return
for llm_config in self.config.models:
if llm_config.type == "embeddings":
pass
else:
if llm_config.engine not in get_llm_provider_names():
msg = f"Unknown LLM engine: {llm_config.engine}."
if llm_config.engine == "openai":
msg += " Please install langchain-openai using `pip install langchain-openai`."
raise Exception(msg)
provider_cls = get_llm_provider(llm_config)
# We need to compute the kwargs for initializing the LLM
kwargs = llm_config.parameters
# We also need to pass the model, if specified
if llm_config.model:
# Some LLM providers use `model_name` instead of model. For backward compatibility
# we keep this hard-coded mapping.
if llm_config.engine in [
"azure",
"openai",
"gooseai",
"nlpcloud",
"petals",
"trt_llm",
"vertexai",
]:
kwargs["model_name"] = llm_config.model
elif llm_config.engine == "nvidia_ai_endpoints":
kwargs["model"] = llm_config.model
else:
# The `__fields__` attribute is computed dynamically by pydantic.
if "model" in provider_cls.__fields__:
kwargs["model"] = llm_config.model
if self.config.streaming:
if "streaming" in provider_cls.__fields__:
kwargs["streaming"] = True
self.main_llm_supports_streaming = True
else:
log.warning(
f"The provider {provider_cls.__name__} does not support streaming."
)
if llm_config.type == "main" or len(self.config.models) == 1:
self.llm = provider_cls(**kwargs)
self.runtime.register_action_param("llm", self.llm)
else:
model_name = f"{llm_config.type}_llm"
setattr(self, model_name, provider_cls(**kwargs))
self.runtime.register_action_param(
model_name, getattr(self, model_name)
)
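# Illustrative sketch of the `models` section this method consumes; the engine
# and model names are examples only (the embeddings entry mirrors the defaults
# used above):
#
#     config = RailsConfig.from_content(yaml_content="""
#     models:
#       - type: main
#         engine: openai
#         model: gpt-3.5-turbo-instruct
#       - type: embeddings
#         engine: FastEmbed
#         model: all-MiniLM-L6-v2
#     """)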
def _get_embeddings_search_provider_instance(
self, esp_config: Optional[EmbeddingSearchProvider] = None
) -> EmbeddingsIndex:
if esp_config is None:
esp_config = EmbeddingSearchProvider()
if esp_config.name == "default":
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex
return BasicEmbeddingsIndex(
embedding_model=esp_config.parameters.get(
"embedding_model", self.default_embedding_model
),
embedding_engine=esp_config.parameters.get(
"embedding_engine", self.default_embedding_engine
),
cache_config=esp_config.cache,
# We make sure we also pass additional relevant params.
**{
k: v
for k, v in esp_config.parameters.items()
if k in ["use_batching", "max_batch_size", "matx_batch_hold"]
and v is not None
},
)
else:
if esp_config.name not in self.embedding_search_providers:
raise Exception(f"Unknown embedding search provider: {esp_config.name}")
else:
kwargs = esp_config.parameters
return self.embedding_search_providers[esp_config.name](**kwargs)
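# Illustrative sketch: a non-default provider is looked up by name in
# `self.embedding_search_providers`, so a `config.py` `init` function (picked
# up in the constructor above) can register one. `MyEmbeddingsIndex` is a
# hypothetical EmbeddingsIndex subclass:
#
#     def init(app):
#         app.embedding_search_providers["my_provider"] = MyEmbeddingsIndex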
def _get_events_for_messages(self, messages: List[dict], state: Any):
"""Return the list of events corresponding to the provided messages.
Tries to find the longest prefix of messages for which a list of events is
already cached; the remaining messages are converted as is.
The cache exists so that we can reuse events generated in previous turns,
which would be expensive to recompute (e.g., they may involve multiple
LLM calls).
Once an explicit state object is supported, this mechanism can be removed.
Args:
messages: The list of messages.
state: The state object (used for Colang 2.x, e.g., to resolve tool call actions).
Returns:
A list of events.
"""
events = []
if self.config.colang_version == "1.0":
# We try to find the longest prefix of messages for which we have a cache
# of events.
p = len(messages) - 1
while p > 0:
cache_key = get_history_cache_key(messages[0:p])
if cache_key in self.events_history_cache:
events = self.events_history_cache[cache_key].copy()
break
p -= 1
# For the rest of the messages, we transform them directly into events.
# TODO: Move this to separate function once more types of messages are supported.
for idx in range(p, len(messages)):
msg = messages[idx]
if msg["role"] == "user":
events.append(
{
"type": "UtteranceUserActionFinished",
"final_transcript": msg["content"],
}
)
# If it's not the last message, we also need to add the `UserMessage` event
if idx != len(messages) - 1:
events.append(
{
"type": "UserMessage",
"text": msg["content"],
}
)
elif msg["role"] == "assistant":
action_uid = new_uuid()
start_event = new_event_dict(
"StartUtteranceBotAction",
script=msg["content"],
action_uid=action_uid,
)
finished_event = new_event_dict(
"UtteranceBotActionFinished",
final_script=msg["content"],
is_success=True,
action_uid=action_uid,
)
events.extend([start_event, finished_event])
elif msg["role"] == "context":
events.append({"type": "ContextUpdate", "data": msg["content"]})
elif msg["role"] == "event":
events.append(msg["event"])
else:
for idx in range(len(messages)):
msg = messages[idx]
if msg["role"] == "user":
events.append(
{
"type": "UtteranceUserActionFinished",
"final_transcript": msg["content"],
}
)
elif msg["role"] == "assistant":
raise ValueError(
"Providing `assistant` messages as input is not supported for Colang 2.0 configurations."
)
elif msg["role"] == "context":
events.append({"type": "ContextUpdate", "data": msg["content"]})
elif msg["role"] == "event":
events.append(msg["event"])
elif msg["role"] == "tool":
action_uid = msg["tool_call_id"]
return_value = msg["content"]
action: Action = state.actions[action_uid]
events.append(
new_event_dict(
f"{action.name}Finished",
action_uid=action_uid,
action_name=action.name,
status="success",
is_success=True,
return_value=return_value,
events=[],
)
)
return events
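# Illustrative mapping (Colang 1.0): how a message list is converted to events
# by the code above; values are abbreviated:
#
#     {"role": "context", "content": {"user_name": "John"}}
#         -> {"type": "ContextUpdate", "data": {"user_name": "John"}}
#     {"role": "user", "content": "Hello!"}
#         -> {"type": "UtteranceUserActionFinished", "final_transcript": "Hello!"}
#     {"role": "assistant", "content": "Hi!"}
#         -> StartUtteranceBotAction + UtteranceBotActionFinished events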
async def generate_async(
self,
prompt: Optional[str] = None,
messages: Optional[List[dict]] = None,
options: Optional[Union[dict, GenerationOptions]] = None,
state: Optional[Union[dict, State]] = None,
streaming_handler: Optional[StreamingHandler] = None,
return_context: bool = False,
) -> Union[str, dict, GenerationResponse, Tuple[dict, dict]]:
"""Generate a completion or a next message.
The format for messages is the following:
```python
[
{"role": "context", "content": {"user_name": "John"}},
{"role": "user", "content": "Hello! How are you?"},
{"role": "assistant", "content": "I am fine, thank you!"},
{"role": "event", "event": {"type": "UserSilent"}},
...
]
```
Args:
prompt: The prompt to be used for completion.
messages: The history of messages to be used to generate the next message.
options: Options specific for the generation.
state: The state object that should be used as the starting point.
streaming_handler: If specified, and the config supports streaming, the
provided handler will be used for streaming.
return_context: Whether to return the context at the end of the run.
Returns:
The completion (when a prompt is provided) or the next message.
System messages are not yet supported."""
# If a state object is specified, then we switch to "generation options" mode.
# This is because we want the output to be a GenerationResponse which will contain
# the output state.
if state is not None:
# We deserialize the state if needed.
if isinstance(state, dict) and state.get("version", "1.0") == "2.x":
state = json_to_state(state["state"])
if options is None:
options = GenerationOptions()
# We allow options to be specified both as a dict and as an object.
if options and isinstance(options, dict):
options = GenerationOptions(**options)
# Save the generation options in the current async context.
generation_options_var.set(options)
if return_context:
warnings.warn(
"The `return_context` argument is deprecated and will be removed in 0.9.0. "
"Use `GenerationOptions.output_vars = True` instead.",
DeprecationWarning,
stacklevel=2,
)
# And we use the generation options mechanism instead.
if options is None:
options = GenerationOptions()
options.output_vars = True
if streaming_handler:
streaming_handler_var.set(streaming_handler)
# Initialize the object with additional explanation information.
# We allow this to also be set externally. This is useful when multiple parallel
# requests are made.
explain_info = explain_info_var.get()
if explain_info is None:
explain_info = ExplainInfo()
explain_info_var.set(explain_info)
# We also keep a general reference to this object
self.explain_info = explain_info
if prompt is not None:
# Currently, we transform the prompt request into a single turn conversation
messages = [{"role": "user", "content": prompt}]
# If we have generation options, we also add them to the context
if options:
messages = [
{"role": "context", "content": {"generation_options": options.dict()}}
] + messages
# If the last message is from the assistant, rather than the user, then
# we move that to the `$bot_message` variable. This is to enable a more
# convenient interface. (only when dialog rails are disabled)
if (
messages[-1]["role"] == "assistant"
and options
and options.rails.dialog is False
):
# We already have the first message with a context update, so we use that
messages[0]["content"]["bot_message"] = messages[-1]["content"]
messages = messages[0:-1]
# TODO: Add support to load back history of events, next to history of messages
# This is important as without it, the LLM prediction is not as good.
t0 = time.time()
# Initialize the LLM stats
llm_stats = LLMStats()
llm_stats_var.set(llm_stats)
processing_log = []
# The array of events corresponding to the provided sequence of messages.
events = self._get_events_for_messages(messages, state)
if self.config.colang_version == "1.0":
# If we had a state object, we also need to prepend the events from the state.
state_events = []
if state:
assert isinstance(state, dict)
state_events = state["events"]
# Compute the new events.
new_events = await self.runtime.generate_events(
state_events + events, processing_log=processing_log
)
output_state = None
else:
# In generation mode, by default the bot response is an instant action.
instant_actions = ["UtteranceBotAction"]
if self.config.rails.actions.instant_actions is not None:
instant_actions = self.config.rails.actions.instant_actions
# Cast this explicitly to avoid certain warnings
runtime: RuntimeV2_x = cast(RuntimeV2_x, self.runtime)
# Compute the new events.
# In generation mode, the processing is always blocking, i.e., it waits for
# all local actions (sync and async).
new_events, output_state = await runtime.process_events(
events, state=state, instant_actions=instant_actions, blocking=True
)
# We also encode the output state as a JSON
output_state = {"state": state_to_json(output_state), "version": "2.x"}
# Extract and join all the messages from StartUtteranceBotAction events as the response.
responses = []
response_tool_calls = []
response_events = []
new_extra_events = []
# The processing is different for Colang 1.0 and 2.0
if self.config.colang_version == "1.0":
for event in new_events:
if event["type"] == "StartUtteranceBotAction":
# Check if we need to remove a message
if event["script"] == "(remove last message)":
responses = responses[0:-1]
else:
responses.append(event["script"])
else:
for event in new_events:
start_action_match = re.match(r"Start(.*Action)", event["type"])
if start_action_match:
action_name = start_action_match[1]
# TODO: is there an elegant way to extract just the arguments?
arguments = {
k: v
for k, v in event.items()
if k != "type"
and k != "uid"
and k != "event_created_at"
and k != "source_uid"
and k != "action_uid"
}
response_tool_calls.append(
{
"id": event["action_uid"],
"type": "function",
"function": {"name": action_name, "arguments": arguments},
}
)
elif event["type"] == "UtteranceBotActionFinished":
responses.append(event["final_script"])
else:
# We just append the event
response_events.append(event)
new_message = {"role": "assistant", "content": "\n".join(responses)}
if response_tool_calls:
new_message["tool_calls"] = response_tool_calls
if response_events:
new_message["events"] = response_events
if self.config.colang_version == "1.0":
events.extend(new_events)
events.extend(new_extra_events)
# If a state object is not used, then we use the implicit caching
if state is None:
# Save the new events in the history and update the cache
cache_key = get_history_cache_key(messages + [new_message])
self.events_history_cache[cache_key] = events
else:
output_state = {"events": events}
# If logging is enabled, we log the conversation
# TODO: add support for logging flag
explain_info.colang_history = get_colang_history(events)
if self.verbose:
log.info(f"Conversation history so far: \n{explain_info.colang_history}")
total_time = time.time() - t0
log.info(
"--- :: Total processing took %.2f seconds. LLM Stats: %s"
% (total_time, llm_stats)
)
# If there is a streaming handler, we make sure we close it now
streaming_handler = streaming_handler_var.get()
if streaming_handler:
# print("Closing the stream handler explicitly")
await streaming_handler.push_chunk(None)
# If we have generation options, we prepare a GenerationResponse instance.
if options:
# If a prompt was used, we only need to return the content of the message.
if prompt:
res = GenerationResponse(response=new_message["content"])
else:
res = GenerationResponse(response=[new_message])
if self.config.colang_version == "1.0":
# If output variables are specified, we extract their values
if options.output_vars:
context = compute_context(events)
if isinstance(options.output_vars, list):
# If we have only a selection of keys, we filter to only that.
res.output_data = {
k: context.get(k) for k in options.output_vars
}
else:
# Otherwise, we return the full context
res.output_data = context
# If the `return_context` is used, then we return a tuple to keep
# the interface compatible.
# TODO: remove this in 0.10.0.
if return_context:
return new_message, context
_log = compute_generation_log(processing_log)
# Include information about activated rails and LLM calls if requested
if options.log.activated_rails or options.log.llm_calls:
res.log = GenerationLog()
# We always include the stats
res.log.stats = _log.stats
if options.log.activated_rails:
res.log.activated_rails = _log.activated_rails
if options.log.llm_calls:
res.log.llm_calls = []
for activated_rail in _log.activated_rails:
for executed_action in activated_rail.executed_actions:
res.log.llm_calls.extend(executed_action.llm_calls)
# Include internal events if requested
if options.log.internal_events:
if res.log is None:
res.log = GenerationLog()
res.log.internal_events = new_events
# Include the Colang history if requested
if options.log.colang_history:
if res.log is None:
res.log = GenerationLog()
res.log.colang_history = get_colang_history(events)
# Include the raw llm output if requested
if options.llm_output:
# Currently, we include the output from the generation LLM calls.
for activated_rail in _log.activated_rails:
if activated_rail.type == "generation":
for executed_action in activated_rail.executed_actions:
for llm_call in executed_action.llm_calls:
res.llm_output = llm_call.raw_response
else:
if options.output_vars:
raise ValueError(
"The `output_vars` option is not supported for Colang 2.0 configurations."
)
if (
options.log.activated_rails
or options.log.llm_calls
or options.log.internal_events
or options.log.colang_history
):
raise ValueError(
"The `log` option is not supported for Colang 2.0 configurations."
)
if options.llm_output:
raise ValueError(
"The `llm_output` option is not supported for Colang 2.0 configurations."
)
# Include the state
if state is not None:
res.state = output_state
return res
else:
# If a prompt is used, we only return the content of the message.
if prompt:
return new_message["content"]
else:
return new_message
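# Illustrative sketch of the generation-options path; the option values are
# examples only ("relevant_chunks" is a context variable name assumed here):
#
#     options = GenerationOptions(output_vars=["relevant_chunks"], log={"llm_calls": True})
#     res = await rails.generate_async(
#         messages=[{"role": "user", "content": "Hello!"}],
#         options=options,
#     )
#     print(res.response, res.output_data, res.log.llm_calls)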
def stream_async(
self,
prompt: Optional[str] = None,
messages: Optional[List[dict]] = None,
) -> AsyncIterator[str]:
"""Simplified interface for getting directly the streamed tokens from the LLM."""
streaming_handler = StreamingHandler()
asyncio.create_task(
self.generate_async(
prompt=prompt,
messages=messages,
streaming_handler=streaming_handler,
)
)
return streaming_handler
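# Illustrative consumption sketch (assumes `streaming: True` in the config and
# a main LLM that supports streaming):
#
#     async for chunk in rails.stream_async(
#         messages=[{"role": "user", "content": "Tell me a joke."}]
#     ):
#         print(chunk, end="", flush=True)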
def generate(
self,
prompt: Optional[str] = None,
messages: Optional[List[dict]] = None,
return_context: bool = False,
options: Optional[Union[dict, GenerationOptions]] = None,
state: Optional[dict] = None,
):
"""Synchronous version of generate_async."""
if check_sync_call_from_async_loop():
raise RuntimeError(
"You are using the sync `generate` inside async code. "
"You should replace with `await generate_async(...)` or use `nest_asyncio.apply()`."
)
loop = get_or_create_event_loop()
return loop.run_until_complete(
self.generate_async(
prompt=prompt,
messages=messages,
options=options,
state=state,
return_context=return_context,
)
)
async def generate_events_async(
self,
events: List[dict],
) -> List[dict]:
"""Generate the next events based on the provided history.
The format for events is the following:
```python
[
{"type": "...", ...},
...
]
```
Args:
events: The history of events to be used to generate the next events.
Returns:
The newly generated event(s).
"""
t0 = time.time()
# Initialize the LLM stats
llm_stats = LLMStats()
llm_stats_var.set(llm_stats)
# Compute the new events.
processing_log = []
new_events = await self.runtime.generate_events(
events, processing_log=processing_log
)
# If logging is enabled, we log the conversation
# TODO: add support for logging flag
if self.verbose:
history = get_colang_history(events)
log.info(f"Conversation history so far: \n{history}")
log.info("--- :: Total processing took %.2f seconds." % (time.time() - t0))
log.info("--- :: Stats: %s" % llm_stats)
return new_events
def generate_events(
self,
events: List[dict],
) -> List[dict]:
"""Synchronous version of `LLMRails.generate_events_async`."""
if check_sync_call_from_async_loop():
raise RuntimeError(
"You are using the sync `generate_events` inside async code. "
"You should replace with `await generate_events_async(...)` or use `nest_asyncio.apply()`."
)
loop = get_or_create_event_loop()
return loop.run_until_complete(self.generate_events_async(events=events))
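# Illustrative sketch of driving the rails at the event level; the event types
# mirror the ones produced elsewhere in this module:
#
#     new_events = rails.generate_events(events=[
#         {"type": "UtteranceUserActionFinished", "final_transcript": "Hello!"}
#     ])
#     # the reply is typically carried by a StartUtteranceBotAction event's `script`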
async def process_events_async(
self, events: List[dict], state: Optional[dict] = None
) -> Tuple[List[dict], dict]:
"""Process a sequence of events in a given state.
The events will be processed one by one, in the input order.
Args:
events: A sequence of events that needs to be processed.
state: The state that should be used as the starting point. If not provided,
a clean state will be used.
Returns:
(output_events, output_state) Returns a sequence of output events and an output
state.
"""
t0 = time.time()
llm_stats = LLMStats()
llm_stats_var.set(llm_stats)
# Compute the new events.
# We need to protect 'process_events' to be called only once at a time
# TODO (cschueller): Why is this?
async with process_events_semaphore:
output_events, output_state = await self.runtime.process_events(
events, state
)
took = time.time() - t0
# Skip the log output when there were no events to process (or processing was just too fast).
if took > 0.1:
log.info("--- :: Total processing took %.2f seconds." % took)
log.info("--- :: Stats: %s" % llm_stats)
return output_events, output_state
def process_events(
self, events: List[dict], state: Optional[dict] = None
) -> Tuple[List[dict], dict]:
"""Synchronous version of `LLMRails.process_events_async`."""
if check_sync_call_from_async_loop():
raise RuntimeError(
"You are using the sync `generate_events` inside async code. "
"You should replace with `await generate_events_async(...)."
)
loop = get_or_create_event_loop()
return loop.run_until_complete(self.process_events_async(events, state))
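# Illustrative sketch of an event-driven interaction loop where the returned
# state is fed back into the next call (event payloads are examples only):
#
#     state = None
#     events = [new_event_dict("UtteranceUserActionFinished", final_transcript="Hi!")]
#     output_events, state = rails.process_events(events, state)
#     # next turn: build new input events and call process_events again with `state`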
def register_action(self, action: callable, name: Optional[str] = None):
"""Register a custom action for the rails configuration."""