-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Ability to fine tune custom model on conversable agents (#1787)
* uAbility to update_model on conversable agents * formatting * formatting * move code from conversable agent into samples/tools and add testing and README * forgot install step * fix * leave core lib unchanged and move everything to samples/tools * remove skip openai --------- Co-authored-by: Eric Zhu <[email protected]>
- Loading branch information
Showing
5 changed files
with
445 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions | ||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions | ||
|
||
name: SamplesToolsTests | ||
|
||
on: | ||
pull_request: | ||
branches: ["main"] | ||
paths: | ||
- "autogen/**" | ||
- "samples/tools/**" | ||
- ".github/workflows/samples-tools-tests.yml" | ||
- "setup.py" | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} | ||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} | ||
permissions: {} | ||
jobs: | ||
SamplesToolsFineTuningTests: | ||
runs-on: ${{ matrix.os }} | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
os: [ubuntu-latest, macos-latest] | ||
python-version: ["3.9", "3.10", "3.11"] | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install packages and dependencies for all tests | ||
run: | | ||
python -m pip install --upgrade pip wheel | ||
pip install -e . | ||
pip install pytest | ||
- name: Set AUTOGEN_USE_DOCKER based on OS | ||
shell: bash | ||
run: | | ||
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then | ||
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV | ||
fi | ||
- name: Test finetuning tools | ||
run: | | ||
pytest samples/tools/finetuning/tests/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Tools for fine-tuning the local models that power agents | ||
|
||
This directory aims to contain tools for fine-tuning the local models that power agents. | ||
|
||
## Fine tune a custom model client | ||
|
||
AutoGen supports the use of custom models to power agents [see blog post here](https://microsoft.github.io/autogen/blog/2024/01/26/Custom-Models). This directory contains a tool to provide feedback to that model, that can be used to fine-tune the model. | ||
|
||
The creator of the Custom Model Client will have to decide what kind of data is going to be fed back and how it will be used to fine-tune the model. This tool is designed to be flexible and allow for a wide variety of feedback mechanisms. | ||
|
||
Custom Model Client will have follow the protocol client defined in `update_model.py` `UpdateableModelClient` which is a subclass of `ModelClient` and adds the following method: | ||
|
||
```python | ||
def update_model( | ||
self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any | ||
) -> Dict[str, Any]: | ||
"""Optional method to learn from the preference data, if the model supports learning. Can be omitted. | ||
Learn from the preference data. | ||
Args: | ||
preference_data: The preference data. | ||
inference_messages: The messages that were used during inference between the agent that is being updated and another agent. | ||
**kwargs: other arguments. | ||
Returns: | ||
Dict of learning stats. | ||
""" | ||
``` | ||
|
||
The function provided in the file `update_model.py` is called by passing these arguments: | ||
|
||
- the agent whose model is to be updated | ||
- the preference data | ||
- the agent whose conversation is being used to provide the inference messages | ||
|
||
The function will find the conversation thread that occurred between the "update agent" and the "other agent", and call the `update_model` method of the model client. It will return a dictionary containing the update stats, inference messages, and preference data: | ||
|
||
```python | ||
{ | ||
"update_stats": <the dictionary returned by the custom model client implementation>, | ||
"inference_messages": <message used for inference>, | ||
"preference_data": <the preference data passed in when update_model was called> | ||
} | ||
``` | ||
|
||
**NOTES**: | ||
|
||
`inference_messages` will contain messages that were passed into the custom model client when `create` was called and a response was needed from the model. It is up to the author of the custom model client to decide which parts of the conversation are needed and how to use this data to fine-tune the model. | ||
|
||
If a conversation has been long-running before `update_model` is called, then the `inference_messages` will contain a conversation thread that was used for multiple inference steps. It is again up to the author of the custom model client to decide which parts of the conversation correspond to the preference data and how to use this data to fine-tune the model. | ||
|
||
An example of how to use this tool is shown below: | ||
|
||
```python | ||
from finetuning.update_model import update_model | ||
|
||
assistant = AssistantAgent( | ||
"assistant", | ||
system_message="You are a helpful assistant.", | ||
human_input_mode="NEVER", | ||
llm_config={ | ||
"config_list": [<the config list containing the custom model>], | ||
}, | ||
) | ||
|
||
assistant.register_model_client(model_client_cls=<TheCustomModelClientClass>) | ||
|
||
user_proxy = UserProxyAgent( | ||
"user_proxy", | ||
human_input_mode="NEVER", | ||
max_consecutive_auto_reply=1, | ||
code_execution_config=False, | ||
llm_config=False, | ||
) | ||
|
||
res = user_proxy.initiate_chat(assistant, message="the message") | ||
response_content = res.summary | ||
|
||
# Evaluate the summary here and provide feedback. Pretending I am going to perform DPO on the response. | ||
|
||
# preference_data will be passed on as-is to the custom model client's update_model implementation | ||
# so it should be in the format that the custom model client expects and is completely up to the author of the custom model client | ||
preference_data = [("this is what the response should have been like", response_content)] | ||
|
||
update_model_stats = update_model(assistant, preference_data, user_proxy) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .update_model import update_model | ||
|
||
__all__ = ["update_model"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from autogen import ConversableAgent, Agent, OpenAIWrapper, ModelClient | ||
from typing import Any, Dict, List, Protocol | ||
|
||
|
||
class UpdateableModelClient(ModelClient, Protocol): | ||
def update_model( | ||
self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any | ||
) -> Dict[str, Any]: | ||
"""Optional method to learn from the preference data, if the model supports learning. Can be omitted. | ||
Learn from the preference data. | ||
Args: | ||
preference_data: The preference data. | ||
inference_messages: The messages used for inference. | ||
**kwargs: other arguments. | ||
Returns: | ||
Dict of learning stats. | ||
""" | ||
... # pragma: no cover | ||
|
||
|
||
def _client_wrapper_update_model( | ||
oai_wrapper_client: OpenAIWrapper, | ||
preference_data: List[Any], | ||
inference_messages: List[Dict[str, Any]], | ||
**kwargs: Any, | ||
) -> Dict[str, Any]: | ||
"""Learn from the preference data. | ||
update_model is not supported for multiple model clients as it would be ambiguous which client was responsible for the inference messages. | ||
Args: | ||
oai_wrapper_client: The OpenAIWrapper client. | ||
preference_data: The preference data. | ||
inference_messages: The messages that were used during inference between the agent that is being updated and another agent. | ||
**kwargs: other arguments. | ||
Returns: | ||
Learning stats. | ||
Raises: | ||
ValueError: If multiple model clients are registered. | ||
NotImplementedError: If update_model is not implemented for the client. | ||
""" | ||
|
||
clients = oai_wrapper_client._clients | ||
|
||
if len(clients) != 1: | ||
raise ValueError("update_model is not supported for multiple model clients.") | ||
client = clients[0] | ||
if hasattr(client, "update_model") and callable(getattr(client, "update_model")): | ||
return client.update_model(preference_data, inference_messages, **kwargs) | ||
else: | ||
raise NotImplementedError(f"update_model is not implemented for {client.__class__.__name__}.") | ||
|
||
|
||
def update_model( | ||
update_agent: ConversableAgent, preference_data: List[Dict[str, Any]], other_agent: Agent, **kwargs | ||
) -> Dict[str, Any]: | ||
"""Update the model using the preference data and the conversation history. | ||
Args: | ||
update_agent (ConversableAgent): the agent whose model will be updated. | ||
preference_data (List[Dict]): a list of dictionaries containing the preference data. | ||
other_agent (Agent): the agent whose conversation history will be used to update the model. | ||
**kwargs: additional keyword arguments for the update model function. | ||
Returns: | ||
Dict: a dictionary containing the update stats, inference_messages, and preference data, like so: | ||
{ | ||
"update_stats": update_model_stats, | ||
"inference_messages": inference_messages, | ||
"preference_data": preference_data | ||
} | ||
Raises: | ||
ValueError: If no OpenAIWrapper client is found. | ||
ValueError: If multiple model clients are registered. | ||
NotImplementedError: If update_model is not implemented for the underlying client. | ||
""" | ||
if update_agent.client is None: | ||
raise ValueError("No OpenAIWrapper client is found.") | ||
inference_messages = update_agent._oai_messages[other_agent] | ||
update_model_stats = _client_wrapper_update_model( | ||
update_agent.client, preference_data, inference_messages, **kwargs | ||
) | ||
return { | ||
"update_stats": update_model_stats, | ||
"inference_messages": inference_messages, | ||
"preference_data": preference_data, | ||
} |
Oops, something went wrong.