|
# Standard library
from datetime import datetime
from os import environ

# Third-party
import requests
import replicate
from dotenv import load_dotenv

import autogen
from autogen import AssistantAgent, UserProxyAgent, config_list_from_json

# Load environment variables from a local .env file before reading them.
load_dotenv()

# Directory where generated images are written (assumed to exist and be
# writable — TODO confirm it is created elsewhere).
OUTPUT_FOLDER = environ["OUTPUT_FOLDER"]
# Read eagerly so a missing token fails fast at startup; the replicate client
# presumably picks this up from the environment itself.
REPLICATE_API_TOKEN = environ["REPLICATE_API_TOKEN"]
# Load the model configuration list from the OAI_CONFIG_LIST env var or file.
autogen_config_list = config_list_from_json(
    env_or_file="OAI_CONFIG_LIST",
    # filter_dict={
    #     # Function calling with GPT 3.5
    #     "model": ["gpt-3.5-turbo"],
    # }
)

# LLM config for the group chat manager.
# GroupChatManager is not allowed to make function/tool calls, so no
# "functions" schemas are attached here.
autogen_llm_config = {"config_list": autogen_config_list}
| 30 | + |
# LLM config for the assistant agents: the shared model config list plus the
# JSON schemas of the two tools the assistants may call.
_TEXT_TO_IMAGE_SCHEMA = {
    "name": "text_to_image_generation",
    "description": "use latest AI model to generate image based on a prompt, return the file path of image generated",
    "parameters": {
        "type": "object",
        "properties": {
            "prompt": {
                "type": "string",
                "description": "a great text to image prompt that describe the image",
            }
        },
        "required": ["prompt"],
    },
}

_IMAGE_REVIEW_SCHEMA = {
    "name": "image_review",
    "description": "review & critique the AI generated image based on original prompt, decide how can images & prompt can be improved",
    "parameters": {
        "type": "object",
        "properties": {
            "prompt": {
                "type": "string",
                "description": "the original prompt used to generate the image",
            },
            "image_file_path": {
                "type": "string",
                "description": "the image file path, make sure including the full file path & file extension",
            },
        },
        "required": ["prompt", "image_file_path"],
    },
}

autogen_llm_config_assistant = {
    "functions": [_TEXT_TO_IMAGE_SCHEMA, _IMAGE_REVIEW_SCHEMA],
    "config_list": autogen_config_list,
}
| 69 | + |
def text_to_image_generation(prompt: str) -> str:
    """Generate an image from *prompt* with Stability AI's SDXL on Replicate.

    Downloads the first generated image into OUTPUT_FOLDER and returns the
    saved file's path.

    Args:
        prompt: Text-to-image prompt describing the desired picture.

    Returns:
        Full path of the saved PNG file.

    Raises:
        RuntimeError: If the model returns no output or the download fails.
    """
    output = replicate.run(
        "stability-ai/sdxl:c221b2b8ef527988fb59bf24a8b97c4561f1c671f73bd389f866bfb27c061316",
        input={"prompt": prompt},
    )

    if not output:
        raise RuntimeError("Failed to generate the image.")

    # The model returns a sequence of image URLs; use the first one.
    image_url = output[0]
    print(f"generated image for {prompt}: {image_url}")

    # Name the file after the (truncated) prompt plus a timestamp so repeated
    # runs with the same prompt do not overwrite each other.
    current_time = datetime.now().strftime("%Y%m%d%H%M%S")
    shortened_prompt = prompt[:50]
    # Replace characters that are unsafe in filenames (e.g. '/'), so an
    # arbitrary prompt cannot break the path or escape OUTPUT_FOLDER.
    safe_prompt = "".join(
        c if c.isalnum() or c in " -_" else "_" for c in shortened_prompt
    )
    image_file_path = f"{OUTPUT_FOLDER}/{safe_prompt}_{current_time}.png"

    # A timeout prevents the script from hanging forever on a stalled download.
    response = requests.get(image_url, timeout=60)
    if response.status_code != 200:
        raise RuntimeError("Failed to download and save the image.")

    with open(image_file_path, "wb") as file:
        file.write(response.content)
    print(f"Image saved as '{image_file_path}'")
    return image_file_path
| 99 | + |
| 100 | + |
def img_review(image_file_path: str, prompt: str) -> str:
    """Critique a generated image against its original prompt via LLaVA-13b.

    Args:
        image_file_path: Full path (including extension) of the image to review.
        prompt: The prompt originally used to generate the image.

    Returns:
        The model's textual critique.
    """
    # Context manager guarantees the file handle is closed even if the API
    # call raises (the original leaked the handle).
    with open(image_file_path, "rb") as image_file:
        output = replicate.run(
            "yorickvp/llava-13b:6bc1c7bb0d2a34e413301fee8f7cc728d2d4e75bfab186aa995f63292bda92fc",
            input={
                "image": image_file,
                "prompt": f"What is happening in the image? From scale 1 to 10, decide how similar the image is to the text prompt {prompt}?",
            },
        )

        # The model streams text chunks; join them while the file is still
        # open in case the output iterator is consumed lazily.
        result = "".join(output)

    print("CRITIC : ", result)

    return result
| 117 | + |
| 118 | + |
# Assistant that writes/refines prompts and calls the image-generation tool.
img_gen_assistant = AssistantAgent(
    name="text_to_img_prompt_expert",
    system_message=(
        "You are a text to image AI model expert, you will use "
        "text_to_image_generation function to generate image with prompt "
        "provided, and also improve prompt based on feedback provided until "
        "it is 10/10."
    ),
    llm_config=autogen_llm_config_assistant,
    function_map={"text_to_image_generation": text_to_image_generation},
)
| 128 | + |
# Assistant that critiques generated images and suggests prompt improvements.
# NOTE: the tool is declared to the model as "image_review" (see the function
# schema and the function_map below), so the system message must use that
# exact name — the original said "img_review", steering the model toward a
# tool name that does not exist.
img_critic_assistant = AssistantAgent(
    name="img_critic",
    system_message=(
        "You are an AI image critique, you will use image_review function to "
        "review the image generated by the text_to_img_prompt_expert against "
        "the original prompt, and provide feedback on how to improve the "
        "prompt."
    ),
    llm_config=autogen_llm_config_assistant,
    function_map={"image_review": img_review},
)
| 137 | + |
# User proxy that relays every turn to the human operator for approval.
user_proxy = UserProxyAgent(
    name="user_proxy",
    human_input_mode="ALWAYS",
)

# Alternative, mostly-autonomous proxy kept for reference:
# user_proxy = UserProxyAgent(
#     name="user_proxy",
#     human_input_mode="TERMINATE",
#     max_consecutive_auto_reply=10,
#     is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
#     code_execution_config={
#         "work_dir": "web",
#         "use_docker": False,
#     },  # Set use_docker=True if docker is available; docker is safer for running generated code.
#     llm_config=autogen_llm_config,
#     system_message="""Reply TERMINATE if the task has been solved at full satisfaction.
# Otherwise, reply CONTINUE, or the reason why the task is not solved yet.""",
# )
| 157 | + |
# Wire the three agents into a group chat driven by a manager agent.
groupchat = autogen.GroupChat(
    agents=[user_proxy, img_gen_assistant, img_critic_assistant],
    messages=[],
    max_round=50,
)

manager = autogen.GroupChatManager(
    groupchat=groupchat,
    llm_config=autogen_llm_config,
)

message = "A realistic image of a cute rabbit wearing sunglasses confidently driving a shiny red sports car on a sunny day with a scenic road in the background."
# message = "In Houston at 2pm, Sunny sky with few clouds. Current Temperature at 37"

# Manual smoke tests, kept for reference:
# text_to_image_generation(message)
# img_review('./output/A realistic image of a cute rabbit wearing sunglas_20240218135543.png', message)

# Start the conversation.
user_proxy.initiate_chat(manager, message=message)
0 commit comments