Skip to content

Commit d59672b

Browse files
authored
Add Dalle Support (#147) (#131)
* Add Dalle Support (#147)
1 parent d1769c1 commit d59672b

File tree

5 files changed

+189
-2
lines changed

5 files changed

+189
-2
lines changed

openai/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Optional
77

88
from openai.api_resources import (
9+
DALLE,
910
Answer,
1011
Classification,
1112
Completion,
@@ -50,6 +51,7 @@
5051
"Completion",
5152
"Customer",
5253
"Edit",
54+
"DALLE",
5355
"Deployment",
5456
"Embedding",
5557
"Engine",

openai/api_resources/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from openai.api_resources.classification import Classification # noqa: F401
33
from openai.api_resources.completion import Completion # noqa: F401
44
from openai.api_resources.customer import Customer # noqa: F401
5-
from openai.api_resources.edit import Edit # noqa: F401
5+
from openai.api_resources.dalle import DALLE # noqa: F401
66
from openai.api_resources.deployment import Deployment # noqa: F401
7+
from openai.api_resources.edit import Edit # noqa: F401
78
from openai.api_resources.embedding import Embedding # noqa: F401
89
from openai.api_resources.engine import Engine # noqa: F401
910
from openai.api_resources.error_object import ErrorObject # noqa: F401

openai/api_resources/dalle.py

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# WARNING: This interface is considered experimental and may changed in the future without warning.
2+
from typing import Any, List
3+
4+
import openai
5+
from openai import api_requestor, util
6+
from openai.api_resources.abstract import APIResource
7+
8+
9+
class DALLE(APIResource):
10+
OBJECT_NAME = "images"
11+
12+
@classmethod
13+
def _get_url(cls, action):
14+
return cls.class_url() + f"/{action}"
15+
16+
@classmethod
17+
def generations(
18+
cls,
19+
**params,
20+
):
21+
instance = cls()
22+
return instance.request("post", cls._get_url("generations"), params)
23+
24+
@classmethod
25+
def variations(
26+
cls,
27+
image,
28+
api_key=None,
29+
api_base=None,
30+
api_type=None,
31+
api_version=None,
32+
organization=None,
33+
**params,
34+
):
35+
requestor = api_requestor.APIRequestor(
36+
api_key,
37+
api_base=api_base or openai.api_base,
38+
api_type=api_type,
39+
api_version=api_version,
40+
organization=organization,
41+
)
42+
_, api_version = cls._get_api_type_and_version(api_type, api_version)
43+
44+
url = cls._get_url("variations")
45+
46+
files: List[Any] = []
47+
for key, value in params.items():
48+
files.append((key, (None, value)))
49+
files.append(("image", ("image", image, "application/octet-stream")))
50+
51+
response, _, api_key = requestor.request("post", url, files=files)
52+
53+
return util.convert_to_openai_object(
54+
response, api_key, api_version, organization
55+
)
56+
57+
@classmethod
58+
def edits(
59+
cls,
60+
image,
61+
mask,
62+
api_key=None,
63+
api_base=None,
64+
api_type=None,
65+
api_version=None,
66+
organization=None,
67+
**params,
68+
):
69+
requestor = api_requestor.APIRequestor(
70+
api_key,
71+
api_base=api_base or openai.api_base,
72+
api_type=api_type,
73+
api_version=api_version,
74+
organization=organization,
75+
)
76+
_, api_version = cls._get_api_type_and_version(api_type, api_version)
77+
78+
url = cls._get_url("edits")
79+
80+
files: List[Any] = []
81+
for key, value in params.items():
82+
files.append((key, (None, value)))
83+
files.append(("image", ("image", image, "application/octet-stream")))
84+
files.append(("mask", ("mask", mask, "application/octet-stream")))
85+
86+
response, _, api_key = requestor.request("post", url, files=files)
87+
88+
return util.convert_to_openai_object(
89+
response, api_key, api_version, organization
90+
)

openai/cli.py

+94
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,49 @@ def list(cls, args):
229229
print(file)
230230

231231

232+
class DALLE:
233+
@classmethod
234+
def generations(cls, args):
235+
resp = openai.DALLE.generations(
236+
prompt=args.prompt,
237+
model=args.model,
238+
size=args.size,
239+
num_images=args.num_images,
240+
response_format=args.response_format,
241+
)
242+
print(resp)
243+
244+
@classmethod
245+
def variations(cls, args):
246+
with open(args.image, "rb") as file_reader:
247+
buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
248+
resp = openai.DALLE.variations(
249+
image=buffer_reader,
250+
model=args.model,
251+
size=args.size,
252+
num_images=args.num_images,
253+
response_format=args.response_format,
254+
)
255+
print(resp)
256+
257+
@classmethod
258+
def edits(cls, args):
259+
with open(args.image, "rb") as file_reader:
260+
image_reader = BufferReader(file_reader.read(), desc="Upload progress")
261+
with open(args.mask, "rb") as file_reader:
262+
mask_reader = BufferReader(file_reader.read(), desc="Upload progress")
263+
resp = openai.DALLE.edits(
264+
image=image_reader,
265+
mask=mask_reader,
266+
prompt=args.prompt,
267+
model=args.model,
268+
size=args.size,
269+
num_images=args.num_images,
270+
response_format=args.response_format,
271+
)
272+
print(resp)
273+
274+
232275
class Search:
233276
@classmethod
234277
def prepare_data(cls, args, purpose):
@@ -983,6 +1026,57 @@ def help(args):
9831026
sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
9841027
sub.set_defaults(func=FineTune.cancel)
9851028

1029+
# DALLE
1030+
sub = subparsers.add_parser("dalle.generations")
1031+
sub.add_argument("-m", "--model", type=str, default="image-alpha-001")
1032+
sub.add_argument("-p", "--prompt", type=str, required=True)
1033+
sub.add_argument("-n", "--num-images", type=int, default=1)
1034+
sub.add_argument(
1035+
"-s", "--size", type=str, default="1024x1024", help="Size of the output image"
1036+
)
1037+
sub.add_argument("--response-format", type=str, default="url")
1038+
sub.set_defaults(func=DALLE.generations)
1039+
1040+
sub = subparsers.add_parser("dalle.edits")
1041+
sub.add_argument("-m", "--model", type=str, default="image-alpha-001")
1042+
sub.add_argument("-p", "--prompt", type=str, required=True)
1043+
sub.add_argument("-n", "--num-images", type=int, default=1)
1044+
sub.add_argument(
1045+
"-I",
1046+
"--image",
1047+
type=str,
1048+
required=True,
1049+
help="Image to modify. Should be a local path and a PNG encoded image.",
1050+
)
1051+
sub.add_argument(
1052+
"-s", "--size", type=str, default="1024x1024", help="Size of the output image"
1053+
)
1054+
sub.add_argument("--response-format", type=str, default="url")
1055+
sub.add_argument(
1056+
"-M",
1057+
"--mask",
1058+
type=str,
1059+
required=True,
1060+
help="Path to a mask image. It should be the same size as the image you're editing and a RGBA PNG image. The Alpha channel acts as the mask.",
1061+
)
1062+
sub.set_defaults(func=DALLE.edits)
1063+
1064+
sub = subparsers.add_parser("dalle.variations")
1065+
sub.add_argument("-m", "--model", type=str, default="image-alpha-001")
1066+
sub.add_argument("-n", "--num-images", type=int, default=1)
1067+
sub.add_argument(
1068+
"-I",
1069+
"--image",
1070+
type=str,
1071+
required=True,
1072+
help="Image to modify. Should be a local path and a PNG encoded image.",
1073+
)
1074+
sub.add_argument(
1075+
"-s", "--size", type=str, default="1024x1024", help="Size of the output image"
1076+
)
1077+
sub.add_argument("--response-format", type=str, default="url")
1078+
sub.set_defaults(func=DALLE.variations)
1079+
9861080

9871081
def wandb_register(parser):
9881082
subparsers = parser.add_subparsers(

openai/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
VERSION = "0.23.1"
1+
VERSION = "0.24.0"

0 commit comments

Comments
 (0)