Skip to content

Commit

Permalink
.Net: Function calling stepwise planner improvements (#3857)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users by providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description

<!-- Describe your changes, the overall approach, and the underlying design.
These notes will help reviewers understand how your code works. Thanks! -->
- Move kernel parameter to `ExecuteAsync` instead of constructor
- Use semantic function for initial plan generation, with prompt and
settings loaded from YAML
- Check estimated token count against max tokens specified for planner

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---------

Co-authored-by: Ben Thomas <[email protected]>
Co-authored-by: Mark Wallace <[email protected]>
  • Loading branch information
3 people authored Dec 4, 2023
1 parent 512b20c commit cb12931
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 96 deletions.
1 change: 1 addition & 0 deletions dotnet/SK-dotnet.sln
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "planning", "planning", "{A7
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Extensions", "Extensions", "{3F4E0DC5-2241-4EF2-9F69-E7EC7834D349}"
ProjectSection(SolutionItems) = preProject
src\InternalUtilities\planning\Extensions\ChatHistoryExtensions.cs = src\InternalUtilities\planning\Extensions\ChatHistoryExtensions.cs
src\InternalUtilities\planning\Extensions\KernelFunctionMetadataExtensions.cs = src\InternalUtilities\planning\Extensions\KernelFunctionMetadataExtensions.cs
src\InternalUtilities\planning\Extensions\ReadOnlyFunctionCollectionPlannerExtensions.cs = src\InternalUtilities\planning\Extensions\ReadOnlyFunctionCollectionPlannerExtensions.cs
EndProjectSection
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ public static async Task RunAsync()
MaxIterations = 15,
MaxTokens = 4000,
};
var planner = new FunctionCallingStepwisePlanner(kernel, config);
var planner = new FunctionCallingStepwisePlanner(config);

foreach (var question in questions)
{
FunctionCallingStepwisePlannerResult result = await planner.ExecuteAsync(question);
FunctionCallingStepwisePlannerResult result = await planner.ExecuteAsync(kernel, question);
Console.WriteLine($"Q: {question}\nA: {result.FinalAnswer}");

// You can uncomment the line below to see the planner's process for completing the request.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ internal static class ChatHistoryExtensions
/// <summary>
/// Returns the number of tokens in the chat history.
/// </summary>
// <param name="chatHistory">The chat history.</param>
// <param name="additionalMessage">An additional message to include in the token count.</param>
// <param name="skipStart">The index to start skipping messages.</param>
// <param name="skipCount">The number of messages to skip.</param>
// <param name="tokenCounter">The token counter to use.</param>
/// <param name="chatHistory">The chat history.</param>
/// <param name="additionalMessage">An additional message to include in the token count.</param>
/// <param name="skipStart">The index to start skipping messages.</param>
/// <param name="skipCount">The number of messages to skip.</param>
/// <param name="tokenCounter">The token counter to use.</param>
internal static int GetTokenCount(this ChatHistory chatHistory, string? additionalMessage = null, int skipStart = 0, int skipCount = 0, TextChunker.TokenCounter? tokenCounter = null)
{
return tokenCounter is null ?
Expand Down
7 changes: 6 additions & 1 deletion dotnet/src/Planners/Planners.OpenAI/Planners.OpenAI.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
</PropertyGroup>

<ItemGroup>
<EmbeddedResource Include="Stepwise\InitialPlanPrompt.txt">
<None Remove="Stepwise\GeneratePlan.yaml" />
</ItemGroup>

<ItemGroup>
<EmbeddedResource Include="Stepwise\GeneratePlan.yaml">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</EmbeddedResource>
<EmbeddedResource Include="Stepwise\StepPrompt.txt">
Expand All @@ -29,6 +33,7 @@
<ItemGroup>
<ProjectReference Include="..\..\Connectors\Connectors.AI.OpenAI\Connectors.AI.OpenAI.csproj" />
<ProjectReference Include="..\..\Functions\Functions.OpenAPI\Functions.OpenAPI.csproj" />
<ProjectReference Include="..\..\Functions\Functions.Yaml\Functions.Yaml.csproj" />
<ProjectReference Include="..\..\SemanticKernel.Abstractions\SemanticKernel.Abstractions.csproj" />
<ProjectReference Include="..\..\SemanticKernel.Core\SemanticKernel.Core.csproj" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,52 +21,44 @@ public sealed class FunctionCallingStepwisePlanner
/// <summary>
/// Initialize a new instance of the <see cref="FunctionCallingStepwisePlanner"/> class.
/// </summary>
/// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
/// <param name="config">The planner configuration.</param>
public FunctionCallingStepwisePlanner(
Kernel kernel,
FunctionCallingStepwisePlannerConfig? config = null)
{
Verify.NotNull(kernel);
this._kernel = kernel;
this._chatCompletionService = kernel.GetService<IChatCompletionService>();

ILoggerFactory loggerFactory = kernel.LoggerFactory;

// Initialize prompt renderer
this._promptTemplateFactory = new KernelPromptTemplateFactory(loggerFactory);

// Set up Config with default values and excluded plugins
this.Config = config ?? new();
this.Config.ExcludedPlugins.Add(RestrictedPluginName);

this._initialPlanPrompt = this.Config.GetPromptTemplate?.Invoke() ?? EmbeddedResource.Read("Stepwise.InitialPlanPrompt.txt");
this._generatePlanYaml = this.Config.GetPromptTemplate?.Invoke() ?? EmbeddedResource.Read("Stepwise.GeneratePlan.yaml");
this._stepPrompt = this.Config.GetStepPromptTemplate?.Invoke() ?? EmbeddedResource.Read("Stepwise.StepPrompt.txt");

// Create context and logger
this._logger = loggerFactory.CreateLogger(this.GetType());
this.Config.ExcludedPlugins.Add(StepwisePlannerPluginName);
}

/// <summary>
/// Execute a plan
/// </summary>
/// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
/// <param name="question">The question to answer</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>Result containing the model's response message and chat history.</returns>
public async Task<FunctionCallingStepwisePlannerResult> ExecuteAsync(
Kernel kernel,
string question,
CancellationToken cancellationToken = default)
{
Verify.NotNullOrWhiteSpace(question);
Verify.NotNull(kernel);
IChatCompletionService chatCompletion = kernel.GetService<IChatCompletionService>();
ILoggerFactory loggerFactory = kernel.LoggerFactory;
ILogger logger = loggerFactory.CreateLogger(this.GetType());
var promptTemplateFactory = new KernelPromptTemplateFactory(loggerFactory);
var stepExecutionSettings = this.Config.ExecutionSettings ?? new OpenAIPromptExecutionSettings();

// Add the final answer function
this._kernel.ImportPluginFromObject<UserInteraction>();
// Clone the kernel so that we can add planner-specific plugins without affecting the original kernel instance
var clonedKernel = kernel.Clone();
clonedKernel.ImportPluginFromObject<UserInteraction>();

// Request completion for initial plan
var chatHistoryForPlan = await this.BuildChatHistoryForInitialPlanAsync(question, cancellationToken).ConfigureAwait(false);
string initialPlan = await this._chatCompletionService.GetChatMessageContentAsync(chatHistoryForPlan, null /* execution settings */, this._kernel, cancellationToken).ConfigureAwait(false);
// Create and invoke a kernel function to generate the initial plan
var initialPlan = await this.GeneratePlanAsync(question, clonedKernel, logger, cancellationToken).ConfigureAwait(false);

var chatHistoryForSteps = await this.BuildChatHistoryForStepAsync(question, initialPlan, cancellationToken).ConfigureAwait(false);
var chatHistoryForSteps = await this.BuildChatHistoryForStepAsync(question, initialPlan, clonedKernel, chatCompletion, promptTemplateFactory, cancellationToken).ConfigureAwait(false);

for (int i = 0; i < this.Config.MaxIterations; i++)
{
Expand All @@ -78,11 +70,11 @@ public async Task<FunctionCallingStepwisePlannerResult> ExecuteAsync(

// For each step, request another completion to select a function for that step
chatHistoryForSteps.AddUserMessage(StepwiseUserMessage);
var chatMessage = (OpenAIChatMessageContent)await this.GetCompletionWithFunctionsAsync(chatHistoryForSteps, cancellationToken).ConfigureAwait(false);
chatHistoryForSteps.AddAssistantMessage(chatMessage!.Content);
var chatResult = await this.GetCompletionWithFunctionsAsync(chatHistoryForSteps, clonedKernel, chatCompletion, stepExecutionSettings, logger, cancellationToken).ConfigureAwait(false);
chatHistoryForSteps.AddAssistantMessage(chatResult);

// Check for function response
if (!this.TryGetFunctionResponse(chatMessage, out OpenAIFunctionResponse? functionResponse, out string? functionResponseError))
if (!this.TryGetFunctionResponse(chatResult, out OpenAIFunctionResponse? functionResponse, out string? functionResponseError))
{
// No function response found. Either AI returned a chat message, or something went wrong when parsing the function.
// Log the error (if applicable), then let the planner continue.
Expand Down Expand Up @@ -114,12 +106,12 @@ public async Task<FunctionCallingStepwisePlannerResult> ExecuteAsync(
}

// Look up function in kernel
if (this._kernel.Plugins.TryGetFunctionAndArguments(functionResponse, out KernelFunction? pluginFunction, out KernelArguments? arguments))
if (clonedKernel.Plugins.TryGetFunctionAndArguments(functionResponse, out KernelFunction? pluginFunction, out KernelArguments? arguments))
{
try
{
// Execute function and add to result to chat history
var result = (await this._kernel.InvokeAsync(pluginFunction, arguments, cancellationToken).ConfigureAwait(false)).GetValue<object>();
var result = (await clonedKernel.InvokeAsync(pluginFunction, arguments, cancellationToken).ConfigureAwait(false)).GetValue<object>();
chatHistoryForSteps.AddFunctionMessage(ParseObjectAsString(result), functionResponse.FullyQualifiedName);
}
catch (KernelException)
Expand All @@ -145,67 +137,71 @@ public async Task<FunctionCallingStepwisePlannerResult> ExecuteAsync(
#region private

private async Task<ChatMessageContent> GetCompletionWithFunctionsAsync(
ChatHistory chatHistory,
CancellationToken cancellationToken)
ChatHistory chatHistory,
Kernel kernel,
IChatCompletionService chatCompletion,
OpenAIPromptExecutionSettings openAIExecutionSettings,
ILogger logger,
CancellationToken cancellationToken)
{
var executionSettings = this.PrepareOpenAIExecutionSettingsWithFunctions();
return await this._chatCompletionService.GetChatMessageContentAsync(chatHistory, executionSettings, this._kernel, cancellationToken).ConfigureAwait(false);
}
openAIExecutionSettings.FunctionCallBehavior = FunctionCallBehavior.EnableKernelFunctions;

private async Task<string> GetFunctionsManualAsync(CancellationToken cancellationToken)
{
return await this._kernel.Plugins.GetJsonSchemaFunctionsManualAsync(this.Config, null, this._logger, false, cancellationToken).ConfigureAwait(false);
await this.ValidateTokenCountAsync(chatHistory, kernel, logger, openAIExecutionSettings, cancellationToken).ConfigureAwait(false);
return await chatCompletion.GetChatMessageContentAsync(chatHistory, openAIExecutionSettings, kernel, cancellationToken).ConfigureAwait(false);
}

private OpenAIPromptExecutionSettings PrepareOpenAIExecutionSettingsWithFunctions()
private async Task<string> GetFunctionsManualAsync(Kernel kernel, ILogger logger, CancellationToken cancellationToken)
{
var executionSettings = this.Config.ModelSettings ?? new OpenAIPromptExecutionSettings();
executionSettings.FunctionCallBehavior = FunctionCallBehavior.EnableKernelFunctions;
return executionSettings;
return await kernel.Plugins.GetJsonSchemaFunctionsManualAsync(this.Config, null, logger, false, cancellationToken).ConfigureAwait(false);
}

private async Task<ChatHistory> BuildChatHistoryForInitialPlanAsync(
string goal,
CancellationToken cancellationToken)
// Create and invoke a kernel function to generate the initial plan
private async Task<string> GeneratePlanAsync(string question, Kernel kernel, ILogger logger, CancellationToken cancellationToken)
{
var chatHistory = new ChatHistory();

var arguments = new KernelArguments();
string functionsManual = await this.GetFunctionsManualAsync(cancellationToken).ConfigureAwait(false);
arguments[AvailableFunctionsKey] = functionsManual;
string systemMessage = await this._promptTemplateFactory.Create(new PromptTemplateConfig(this._initialPlanPrompt)).RenderAsync(this._kernel, arguments, cancellationToken).ConfigureAwait(false);

chatHistory.AddSystemMessage(systemMessage);
chatHistory.AddUserMessage(goal);

return chatHistory;
var generatePlanFunction = kernel.CreateFunctionFromPromptYaml(this._generatePlanYaml, pluginName: StepwisePlannerPluginName);
string functionsManual = await this.GetFunctionsManualAsync(kernel, logger, cancellationToken).ConfigureAwait(false);
var generatePlanArgs = new KernelArguments
{
[AvailableFunctionsKey] = functionsManual,
[GoalKey] = question
};
var generatePlanResult = await kernel.InvokeAsync(generatePlanFunction, generatePlanArgs, cancellationToken).ConfigureAwait(false);
return generatePlanResult.GetValue<string>() ?? throw new KernelException("Failed get a completion for the plan.");
}

private async Task<ChatHistory> BuildChatHistoryForStepAsync(
string goal,
string initialPlan,
Kernel kernel,
IChatCompletionService chatCompletion,
KernelPromptTemplateFactory promptTemplateFactory,
CancellationToken cancellationToken)
{
var chatHistory = new ChatHistory();

// Add system message with context about the initial goal/plan
var arguments = new KernelArguments();
arguments[GoalKey] = goal;
arguments[InitialPlanKey] = initialPlan;
var systemMessage = await this._promptTemplateFactory.Create(new PromptTemplateConfig(this._stepPrompt)).RenderAsync(this._kernel, arguments, cancellationToken).ConfigureAwait(false);
var arguments = new KernelArguments
{
[GoalKey] = goal,
[InitialPlanKey] = initialPlan
};
var systemMessage = await promptTemplateFactory.Create(new PromptTemplateConfig(this._stepPrompt)).RenderAsync(kernel, arguments, cancellationToken).ConfigureAwait(false);

chatHistory.AddSystemMessage(systemMessage);

return chatHistory;
}

private bool TryGetFunctionResponse(OpenAIChatMessageContent chatMessage, [NotNullWhen(true)] out OpenAIFunctionResponse? functionResponse, out string? errorMessage)
private bool TryGetFunctionResponse(ChatMessageContent chatMessage, [NotNullWhen(true)] out OpenAIFunctionResponse? functionResponse, out string? errorMessage)
{
OpenAIChatMessageContent? openAiChatMessage = chatMessage as OpenAIChatMessageContent;
Verify.NotNull(openAiChatMessage, nameof(openAiChatMessage));

functionResponse = null;
errorMessage = null;
try
{
functionResponse = chatMessage.GetOpenAIFunctionResponse();
functionResponse = openAiChatMessage.GetOpenAIFunctionResponse();
}
catch (JsonException)
{
Expand Down Expand Up @@ -266,35 +262,47 @@ private static string ParseObjectAsString(object? valueObj)
return resultStr;
}

private async Task ValidateTokenCountAsync(
ChatHistory chatHistory,
Kernel kernel,
ILogger logger,
OpenAIPromptExecutionSettings openAIExecutionSettings,
CancellationToken cancellationToken)
{
string functionManual = string.Empty;

// If using functions, get the functions manual to include in token count estimate
if (openAIExecutionSettings.FunctionCallBehavior == FunctionCallBehavior.EnableKernelFunctions)
{
functionManual = await this.GetFunctionsManualAsync(kernel, logger, cancellationToken).ConfigureAwait(false);
}

var tokenCount = chatHistory.GetTokenCount(additionalMessage: functionManual);
if (tokenCount >= this.Config.MaxPromptTokens)
{
throw new KernelException("ChatHistory is too long to get a completion. Try reducing the available functions.");
}
}

/// <summary>
/// The configuration for the StepwisePlanner
/// </summary>
private FunctionCallingStepwisePlannerConfig Config { get; }

// Context used to access the list of functions in the kernel
private readonly Kernel _kernel;
private readonly IChatCompletionService _chatCompletionService;
private readonly ILogger? _logger;

/// <summary>
/// The prompt (system message) used to generate the initial set of steps to perform.
/// The prompt YAML for generating the initial stepwise plan.
/// </summary>
private readonly string _initialPlanPrompt;
private readonly string _generatePlanYaml;

/// <summary>
/// The prompt (system message) for performing the steps.
/// </summary>
private readonly string _stepPrompt;

/// <summary>
/// The prompt renderer to use for the system step
/// </summary>
private readonly KernelPromptTemplateFactory _promptTemplateFactory;

/// <summary>
/// The name to use when creating semantic functions that are restricted from plan creation
/// </summary>
private const string RestrictedPluginName = "OpenAIFunctionsStepwisePlanner_Excluded";
private const string StepwisePlannerPluginName = "StepwisePlanner_Excluded";

/// <summary>
/// The user message to add to the chat history for each step of the plan.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ public FunctionCallingStepwisePlannerConfig()
this.MaxTokens = 4000;
}

/// <summary>
/// The ratio of tokens to allocate to the completion request. (completion / (prompt + completion))
/// </summary>
public double MaxTokensRatio { get; set; } = 0.1;

internal int MaxCompletionTokens { get { return (int)(this.MaxTokens * this.MaxTokensRatio); } }

internal int MaxPromptTokens { get { return (int)(this.MaxTokens * (1 - this.MaxTokensRatio)); } }

/// <summary>
/// Delegate to get the prompt template string for the step execution phase.
/// </summary>
Expand All @@ -34,7 +43,7 @@ public FunctionCallingStepwisePlannerConfig()
public int MinIterationTimeMs { get; set; }

/// <summary>
/// The configuration to use for the prompt template.
/// The prompt execution settings to use for the step execution phase.
/// </summary>
public OpenAIPromptExecutionSettings? ModelSettings { get; set; }
public OpenAIPromptExecutionSettings? ExecutionSettings { get; set; }
}
Loading

0 comments on commit cb12931

Please sign in to comment.