Files
dalle/generations.py
2024-05-09 19:10:16 -04:00

203 lines
8.2 KiB
Python

import argparse
import os
import sys
import webbrowser
import tkinter as tk
from tkinter import messagebox
import openai
from openai import OpenAI
def api(prompt, model, n, quality, response_format, size, style, user):
    """Request image generation from the OpenAI Images API.

    A client is created with the key read from the ``OPENAI_API_KEY``
    environment variable, and every argument is forwarded unchanged to
    ``client.images.generate``.

    Returns:
        Whatever ``client.images.generate`` returns (its ``.data`` items
        are consumed by the GUI front end).
    """
    api_key = os.environ.get('OPENAI_API_KEY')
    request = {
        'prompt': prompt,
        'model': model,
        'n': n,
        'quality': quality,
        'response_format': response_format,
        'size': size,
        'style': style,
        'user': user,
    }
    return OpenAI(api_key=api_key).images.generate(**request)
class CLI:
    """Command-line front end: parses ``sys.argv`` and calls ``api``."""

    # Sizes each model accepts (per the OpenAI Images API reference). The
    # GUI enforces the same restriction by disabling menu entries.
    _MODEL_SIZES = {
        'dall-e-2': ('256x256', '512x512', '1024x1024'),
        'dall-e-3': ('1024x1024', '1792x1024', '1024x1792'),
    }
    # Largest image count per request for each model; dall-e-3 allows one.
    _MODEL_MAX_N = {'dall-e-2': 10, 'dall-e-3': 1}

    @staticmethod
    def run_cli():
        """Parse command-line options, generate images, and print the response.

        Exits through ``argparse`` (``SystemExit`` with status 2) when the
        prompt is missing, or when the chosen model does not support the
        requested image count or size. This happens before any request is
        sent.
        """
        parser = argparse.ArgumentParser(description='Generate images using OpenAI.')
        # Without required=True a missing -p silently sent prompt=None.
        parser.add_argument('-p', '--prompt', type=str, required=True, help='A text description of the desired image.')
        parser.add_argument('-m', '--model', type=str, default='dall-e-2', choices=['dall-e-2', 'dall-e-3'], help='The model to use for image generation.')
        parser.add_argument('-n', type=int, default=1, choices=range(1, 11), help='The number of images to generate.')
        parser.add_argument('-q', '--quality', type=str, default='standard', choices=['standard', 'hd'], help='The quality of the image.')
        parser.add_argument('-rf', '--response_format', type=str, default='url', choices=['url', 'b64_json'], help='The format in which the images are returned.')
        parser.add_argument('-s', '--size', type=str, default='1024x1024', choices=['256x256', '512x512', '1024x1024', '1792x1024', '1024x1792'], help='The size of the generated images.')
        parser.add_argument('-st', '--style', type=str, default='vivid', choices=['vivid', 'natural'], help='The style of the generated images.')
        parser.add_argument('-u', '--user', type=str, default='', help='A unique identifier representing your end-user.')
        args = parser.parse_args()

        # The flat `choices` lists above allow every size/count for every
        # model; reject per-model combinations up front.
        allowed_sizes = CLI._MODEL_SIZES[args.model]
        if args.size not in allowed_sizes:
            parser.error(
                f"size {args.size} is not supported by {args.model}; "
                f"choose one of: {', '.join(allowed_sizes)}"
            )
        max_n = CLI._MODEL_MAX_N[args.model]
        if args.n > max_n:
            parser.error(f"{args.model} supports at most {max_n} image(s) per request")

        result = api(
            args.prompt,
            args.model,
            args.n,
            args.quality,
            args.response_format,
            args.size,
            args.style,
            args.user
        )
        print(result)
class GUI:
    """Tkinter front end that collects image options and calls ``api``."""

    # Per-model constraints applied by update_gui_based_on_model.
    # "n_values" lists the image counts the spinbox offers. Every other
    # entry is (value to reset the menu to, entries to enable, entries to
    # disable); entries not listed keep their current state. "b64_json" is
    # disabled for both models because generate_images only uses item.url.
    _MODEL_CONFIGS = {
        "dall-e-2": {
            "n_values": range(1, 11),
            "quality": ("standard", ["standard"], ["hd"]),
            "size": ("1024x1024", ["256x256", "512x512", "1024x1024"], ["1024x1792", "1792x1024"]),
            "style": ("vivid", ["vivid"], ["natural"]),
            "response_format": ("url", ["url"], ["b64_json"]),
        },
        "dall-e-3": {
            "n_values": (1,),
            "quality": ("standard", ["hd"], []),
            "size": ("1024x1024", ["1024x1792", "1792x1024"], ["256x256", "512x512"]),
            "style": ("vivid", ["natural"], []),
            "response_format": ("url", ["url"], ["b64_json"]),
        },
    }
    # Config keys that map onto a `<key>_var` / `<key>_option_menu` pair.
    _OPTION_MENU_KEYS = ("quality", "size", "style", "response_format")

    def __init__(self, root):
        """Build all widgets inside the given Tk root window."""
        self.root = root
        self.setup_gui()

    def setup_gui(self):
        """Create the form widgets, wire their callbacks, and apply initial state."""
        self.root.title("OpenAI Image Generator")
        self.root.grid_columnconfigure(1, weight=1)
        self.prompt_text = self.create_text("Prompt:", 0)
        self.model_var, self.model_option_menu = self.create_option_menu("Model:", "dall-e-2", ["dall-e-2", "dall-e-3"], row=1)
        self.n_spinbox = self.create_spinbox("Number of Images:", 1, 10, "readonly", 2)
        self.quality_var, self.quality_option_menu = self.create_option_menu("Quality:", "standard", ["standard", "hd"], row=3)
        self.response_format_var, self.response_format_option_menu = self.create_option_menu("Response Format:", "url", ["url", "b64_json"], row=4)
        self.size_var, self.size_option_menu = self.create_option_menu("Size:", "1024x1024", ["256x256", "512x512", "1024x1024", "1024x1792", "1792x1024"], row=5)
        self.style_var, self.style_option_menu = self.create_option_menu("Style:", "vivid", ["vivid", "natural"], row=6)
        self.user_entry = self.create_entry("User:", 7)
        self.generate_button = self.create_button("Generate Image:", self.generate_images, 8)
        # Keep the Generate button's enabled state in sync with the prompt.
        self.prompt_text.bind("<KeyRelease>", self.check_prompt)
        # Re-apply per-model constraints whenever the model selection changes.
        self.model_var.trace_add("write", self.update_gui_based_on_model)
        self.check_prompt()
        self.update_gui_based_on_model()

    def create_text(self, text, row):
        """Add a label and a multi-line Text widget on *row*; return the Text."""
        tk.Label(self.root, text=text).grid(row=row, column=0, sticky='e')
        text_widget = tk.Text(self.root, height=10, width=50)
        text_widget.grid(row=row, column=1, pady=10, sticky='nsew', columnspan=2)
        return text_widget

    def create_option_menu(self, text, default, options, row):
        """Add a label and an OptionMenu on *row*; return ``(StringVar, OptionMenu)``."""
        tk.Label(self.root, text=text).grid(row=row, column=0, sticky='e')
        var = tk.StringVar(value=default)
        option_menu = tk.OptionMenu(self.root, var, *options)
        option_menu.grid(row=row, column=1, pady=10, sticky='w')
        return var, option_menu

    def create_spinbox(self, text, from_, to, state, row):
        """Add a label and a Spinbox spanning ``from_``..``to`` on *row*; return it."""
        tk.Label(self.root, text=text).grid(row=row, column=0, sticky='e')
        spinbox = tk.Spinbox(self.root, from_=from_, to=to, state=state, width=15)
        spinbox.grid(row=row, column=1, pady=10, sticky='w')
        return spinbox

    def create_entry(self, text, row):
        """Add a label and a single-line Entry on *row*; return the Entry."""
        tk.Label(self.root, text=text).grid(row=row, column=0, sticky='e')
        entry = tk.Entry(self.root)
        entry.grid(row=row, column=1, pady=10, sticky='w')
        return entry

    def create_button(self, text, command, row):
        """Add a Button running *command* across both columns of *row*; return it."""
        button = tk.Button(self.root, text=text, command=command)
        button.grid(row=row, column=0, pady=10, columnspan=2)
        return button

    def check_prompt(self, *event):
        """Enable the Generate button only when the prompt has non-blank text."""
        prompt_content = self.prompt_text.get("1.0", "end-1c").strip()
        state = tk.NORMAL if prompt_content else tk.DISABLED
        self.generate_button.config(state=state)

    def update_gui_based_on_model(self, *args):
        """Apply the selected model's constraints to the other inputs.

        Sets the spinbox's allowed image counts, resets each option menu to
        its per-model default, and enables/disables menu entries. ``*args``
        absorbs the arguments passed by the ``trace_add`` callback on
        ``model_var``.
        """
        config = self._MODEL_CONFIGS[self.model_var.get()]
        self.n_spinbox.config(values=tuple(config["n_values"]))
        for key in self._OPTION_MENU_KEYS:
            default, enabled, disabled = config[key]
            getattr(self, f"{key}_var").set(default)
            # A non-numeric Tk menu index is matched against entry labels.
            menu = getattr(self, f"{key}_option_menu")['menu']
            for option in enabled:
                menu.entryconfig(option, state="normal")
            for option in disabled:
                menu.entryconfig(option, state="disabled")

    def generate_images(self):
        """Call ``api`` with the form values and report the outcome.

        Each returned image URL is opened in the default web browser. A
        success dialog then shows the revised prompt(s) when the response
        carries any, otherwise the prompt that was sent. Any exception is
        shown in an error dialog.
        """
        try:
            prompt = self.prompt_text.get("1.0", "end-1c")
            result = api(
                prompt,
                self.model_var.get(),
                int(self.n_spinbox.get()),
                self.quality_var.get(),
                self.response_format_var.get(),
                self.size_var.get(),
                self.style_var.get(),
                self.user_entry.get()
            )
            urls = [item.url for item in result.data]
            revised_prompts = [
                item.revised_prompt
                for item in result.data
                if getattr(item, 'revised_prompt', None)
            ]
            for url in urls:
                webbrowser.open(url)
            if revised_prompts:
                all_revised_prompts = "\n\n".join(revised_prompts)
                messagebox.showinfo("Success", f"Image(s) generated successfully. \n\nRevised prompt(s):\n\n{all_revised_prompts}")
            else:
                messagebox.showinfo("Success", f"Image(s) generated successfully.\n\nPrompt:\n\n{prompt}")
        except Exception as e:
            # Pass an explicit string rather than relying on Tk's implicit
            # conversion of arbitrary Python objects.
            messagebox.showerror("Error", str(e))

    @staticmethod
    def run_gui():
        """Create the Tk root window, build the GUI in it, and run the event loop."""
        root = tk.Tk()
        GUI(root)
        root.mainloop()
def main():
    """Run the CLI when command-line arguments are present, else the GUI.

    Returns:
        0 on success, 1 if the selected mode raised an exception. The error
        is reported on stderr. ``SystemExit`` raised by ``argparse`` (for
        example for ``--help`` or invalid options) is not caught and
        propagates unchanged.
    """
    use_cli = len(sys.argv) > 1
    mode = "CLI" if use_cli else "GUI"
    try:
        if use_cli:
            CLI.run_cli()
        else:
            GUI.run_gui()
    except Exception as e:
        # Report on stderr and signal failure through the exit status so
        # scripts invoking this tool can detect errors.
        print(f"Failed to start {mode} mode: {e}", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())