Commit: customize prompt settings
lpinheiroms committed Jan 16, 2025
1 parent 8f3c42c commit 96cd3d3
Showing 1 changed file with 101 additions and 17 deletions.
@@ -38,6 +38,8 @@ class SKChatCompletionAdapter(ChatCompletionClient):
- Pass in a `Kernel` and any supported Semantic Kernel `ChatCompletionClientBase` connector.
- Provide tools (via Autogen `Tool` or `ToolSchema`) for function calls during chat completion.
- Stream responses or retrieve them in a single request.
+ - Provide prompt settings to control the chat completion behavior either globally through the constructor
+ or on a per-request basis through the `extra_create_args` dictionary.
Args:
sk_client (ChatCompletionClientBase):
@@ -50,10 +52,12 @@ class SKChatCompletionAdapter(ChatCompletionClient):
import asyncio
from semantic_kernel import Kernel
from semantic_kernel.memory.null_memory import NullMemory
- from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import AzureChatCompletion
- from semantic_kernel.connectors.ai.google.google_ai import GoogleAIChatCompletion
- from semantic_kernel.connectors.ai.ollama import OllamaChatCompletion
- from semantic_kernel.connectors.ai.ollama.ollama_prompt_execution_settings import OllamaChatPromptExecutionSettings
+ from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import (
+ AzureChatCompletion,
+ AzureChatPromptExecutionSettings,
+ )
+ from semantic_kernel.connectors.ai.google.google_ai import GoogleAIChatCompletion, GoogleAIChatPromptExecutionSettings
+ from semantic_kernel.connectors.ai.ollama import OllamaChatCompletion, OllamaChatPromptExecutionSettings
from autogen_core.models import SystemMessage, UserMessage, LLMMessage
from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter
from autogen_core import CancellationToken
@@ -96,7 +100,10 @@ async def main():
api_key = "<AZURE_OPENAI_API_KEY>"
azure_client = AzureChatCompletion(deployment_name=deployment_name, endpoint=endpoint, api_key=api_key)
- azure_adapter = SKChatCompletionAdapter(sk_client=azure_client)
+ azure_request_settings = AzureChatPromptExecutionSettings(
+ options={"temperature": 0.8},
+ )
+ azure_adapter = SKChatCompletionAdapter(sk_client=azure_client, default_prompt_settings=azure_request_settings)
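# (Editor's note, not part of the commit: these constructor defaults apply to
# every request made through azure_adapter unless the request supplies its own
# "prompt_execution_settings" in extra_create_args, as step 5 below shows.)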
# ----------------------------------------------------------------
# Example B: Google Gemini
@@ -114,7 +121,15 @@ async def main():
host="http://localhost:11434",
ai_model_id="llama3.1",
)
- ollama_adapter = SKChatCompletionAdapter(sk_client=ollama_client)
+ request_settings = OllamaChatPromptExecutionSettings(
+ # For model-specific settings, specify them in the options dictionary.
+ # For more information on the available options, refer to the Ollama API documentation:
+ # https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+ options={
+ "temperature": 0.8,
+ },
+ )
+ ollama_adapter = SKChatCompletionAdapter(sk_client=ollama_client, default_prompt_settings=request_settings)
# 3) Create a tool and register it with the kernel
calc_tool = CalculatorTool()
@@ -126,9 +141,13 @@ async def main():
]
# 5) Invoke chat completion with the Azure adapter (as an example)
- # Provide the kernel in extra_create_args, and pass the tool.
+ # Provide the kernel in extra_create_args, and pass the tool and optional prompt settings.
# The same pattern applies to Google or Ollama adapters.
- result = await azure_adapter.create(messages=messages, tools=[calc_tool], extra_create_args={"kernel": kernel})
+ result = await azure_adapter.create(
+ messages=messages,
+ tools=[calc_tool],
+ extra_create_args={"kernel": kernel, "prompt_execution_settings": azure_request_settings},
+ )
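# (Editor's note: the "prompt_execution_settings" passed here override the
# adapter's default_prompt_settings for this request only.)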
# Print or use the result
print("Result content:", result.content)
@@ -143,7 +162,15 @@ async def main():
asyncio.run(main())
"""

- def __init__(self, sk_client: ChatCompletionClientBase, model_info: Optional[ModelInfo] = None):
+ def __init__(
+ self,
+ sk_client: ChatCompletionClientBase,
+ model_info: Optional[ModelInfo] = None,
+ service_id: Optional[str] = None,
+ default_prompt_settings: Optional[PromptExecutionSettings] = None,
+ ):
+ self._service_id = service_id
+ self._default_prompt_settings = default_prompt_settings
self._sk_client = sk_client
self._model_info = model_info or ModelInfo(
vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
@@ -181,23 +208,26 @@ def _convert_to_chat_history(self, messages: Sequence[LLMMessage]) -> ChatHistory:
return chat_history

def _build_execution_settings(
- self, extra_create_args: Mapping[str, Any], tools: Sequence[Tool | ToolSchema]
+ self, default_prompt_settings: Optional[PromptExecutionSettings], tools: Sequence[Tool | ToolSchema]
) -> PromptExecutionSettings:
"""Build PromptExecutionSettings from extra_create_args"""
- # Extract service_id if provided, otherwise use None
- service_id = extra_create_args.get("service_id")

+ if default_prompt_settings is not None:
+ prompt_args: dict[str, Any] = default_prompt_settings.prepare_settings_dict() # type: ignore
+ else:
+ prompt_args = {}

# If tools are available, configure function choice behavior with auto_invoke disabled
function_choice_behavior = None
if tools:
function_choice_behavior = FunctionChoiceBehavior.Auto( # type: ignore
- auto_invoke=extra_create_args.get("auto_invoke", False)
+ auto_invoke=False
)

# Create settings with remaining args as extension_data
settings = PromptExecutionSettings(
- service_id=service_id,
- extension_data=dict(extra_create_args),
+ service_id=self._service_id,
+ extension_data=prompt_args,
function_choice_behavior=function_choice_behavior,
)
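(Editor's illustration, not part of the commit: a minimal sketch of the flow above,
using the public semantic-kernel API and the AzureChatPromptExecutionSettings class
imported in the class docstring example. The defaults' provider-specific fields are
serialized and carried as extension_data on a generic PromptExecutionSettings.)

    from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
    from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import (
        AzureChatPromptExecutionSettings,
    )

    defaults = AzureChatPromptExecutionSettings(temperature=0.8)
    prompt_args = defaults.prepare_settings_dict()  # provider-specific fields as a plain dict
    settings = PromptExecutionSettings(
        service_id=None,  # the adapter passes self._service_id here
        extension_data=prompt_args,  # handed through to the target SK connector
        function_choice_behavior=None,  # Auto(auto_invoke=False) when tools are supplied
    )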

@@ -260,6 +290,30 @@ async def create(
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
"""
Create a chat completion using the Semantic Kernel client.
The `extra_create_args` dictionary can include two special keys:
1) `"kernel"` (required):
An instance of :class:`semantic_kernel.Kernel` used to execute the request.
If not provided, a ValueError is raised.
2) `"prompt_execution_settings"` (optional):
An instance of a :class:`PromptExecutionSettings` subclass corresponding to the
underlying Semantic Kernel client (e.g., `AzureChatPromptExecutionSettings`,
`GoogleAIChatPromptExecutionSettings`). If not provided, the adapter's default
prompt settings will be used.
Args:
messages: The list of LLM messages to send.
tools: The tools that may be invoked during the chat.
json_output: Whether the model is expected to return JSON.
extra_create_args: Additional arguments to control the chat completion behavior.
cancellation_token: Token allowing cancellation of the request.
Returns:
CreateResult: The result of the chat completion.
"""
if "kernel" not in extra_create_args:
raise ValueError("kernel is required in extra_create_args")

@@ -270,7 +324,10 @@
chat_history = self._convert_to_chat_history(messages)

# Build execution settings from extra args and tools
- settings = self._build_execution_settings(extra_create_args, tools)
+ user_settings = extra_create_args.get("prompt_execution_settings", None)
+ if user_settings is None:
+ user_settings = self._default_prompt_settings
+ settings = self._build_execution_settings(user_settings, tools)

# Sync tools with kernel
self._sync_tools_with_kernel(kernel, tools)
@@ -313,6 +370,30 @@ async def create_stream(
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:
"""
Create a streaming chat completion using the Semantic Kernel client.
The `extra_create_args` dictionary can include two special keys:
1) `"kernel"` (required):
An instance of :class:`semantic_kernel.Kernel` used to execute the request.
If not provided, a ValueError is raised.
2) `"prompt_execution_settings"` (optional):
An instance of a :class:`PromptExecutionSettings` subclass corresponding to the
underlying Semantic Kernel client (e.g., `AzureChatPromptExecutionSettings`,
`GoogleAIChatPromptExecutionSettings`). If not provided, the adapter's default
prompt settings will be used.
Args:
messages: The list of LLM messages to send.
tools: The tools that may be invoked during the chat.
json_output: Whether the model is expected to return JSON.
extra_create_args: Additional arguments to control the chat completion behavior.
cancellation_token: Token allowing cancellation of the request.
Yields:
Union[str, CreateResult]: Either a string chunk of the response or a CreateResult containing function calls.
"""
if "kernel" not in extra_create_args:
raise ValueError("kernel is required in extra_create_args")
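(Editor's illustration, not part of the commit: streaming usage is not shown in the
class docstring example; a minimal sketch reusing the kernel, messages, and
azure_adapter defined there.)

    async for chunk in azure_adapter.create_stream(
        messages=messages,
        extra_create_args={"kernel": kernel},
    ):
        if isinstance(chunk, str):
            print(chunk, end="")  # incremental text chunk
        else:
            print(chunk.content)  # final CreateResult, which may contain function calls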

@@ -321,7 +402,10 @@ async def create_stream(
raise ValueError("kernel must be an instance of semantic_kernel.kernel.Kernel")

chat_history = self._convert_to_chat_history(messages)
- settings = self._build_execution_settings(extra_create_args, tools)
+ user_settings = extra_create_args.get("prompt_execution_settings", None)
+ if user_settings is None:
+ user_settings = self._default_prompt_settings
+ settings = self._build_execution_settings(user_settings, tools)
self._sync_tools_with_kernel(kernel, tools)

prompt_tokens = 0