-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
130 lines (102 loc) · 4.16 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import sys
import json
from flask import Flask, render_template, request, redirect, url_for
import psycopg2
import torch
#from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import random
# CLIP model and processor used by compute_similarity().
# Loading is currently disabled (it needs the `transformers` package and may
# download model weights on first use). Until it is enabled these stay None,
# so compute_similarity() can report a clear error instead of a NameError.
# To enable, install `transformers` and replace the placeholders with:
#     from transformers import CLIPProcessor, CLIPModel
#     clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
#     clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = None
clip_model = None

app = Flask(__name__)
# Prompts offered to the user on the paint page; generate_random_prompt()
# picks one of these at random for each GET of "/".
example_prompts = [
    "Upside Down dinosaur",
    "Rabbits are cooking carrots",
    "flying pandas",
    "monkey man",
    "elephants with wings",
]
# Best-effort startup check that the database is reachable. A failure is only
# reported, never fatal: each route opens its own connection per request.
try:
    conn = psycopg2.connect(database="paint", user="technocrat", host="localhost", port="5432", password='1234')
    # psycopg2.connect() raises on failure, so reaching this line means success.
    print("Database connection established successfully")
    conn.close()
except Exception as error:  # top-level boundary: report any failure, keep starting up
    print("Error connecting to the database:", error)
def generate_random_prompt():
    """Pick one entry of ``example_prompts`` uniformly at random.

    Returns:
        The chosen prompt string.
    """
    chosen = random.choice(example_prompts)
    return chosen
def preprocess_image(image_path, size=(224, 224)):
    """Open an image and resize it to the model's input resolution.

    Args:
        image_path: Filesystem path or readable binary file object (for
            example an uploaded file) accepted by ``PIL.Image.open``.
        size: Target ``(width, height)``. Defaults to 224x224, the input size
            of the CLIP ViT-B/32 checkpoint referenced in this module.

    Returns:
        A new ``PIL.Image.Image`` of the requested size.
    """
    # The context manager closes any file PIL opened itself, even if resize()
    # raises. resize() returns an independent image, so the result remains
    # valid after the `with` block exits.
    with Image.open(image_path) as image:
        return image.resize(size)
def compute_similarity(drawing_image, prompt, processor=None, model=None):
    """Return the CLIP cosine similarity between a drawing and a text prompt.

    Args:
        drawing_image: Image to score, in any form the CLIP processor accepts
            (e.g. the PIL image returned by ``preprocess_image``).
        prompt: A single text prompt to compare the drawing against.
        processor: CLIP processor to use. Defaults to the module-level
            ``clip_processor``.
        model: CLIP model to use. Defaults to the module-level ``clip_model``.

    Returns:
        The cosine similarity of the image and text embeddings as a float in
        [-1, 1].

    Raises:
        RuntimeError: If no processor/model was passed and the module-level
            CLIP objects are not loaded.
    """
    # Look the module-level objects up lazily so a missing model produces a
    # clear error rather than a NameError/TypeError.
    if processor is None:
        processor = globals().get("clip_processor")
    if model is None:
        model = globals().get("clip_model")
    if processor is None or model is None:
        raise RuntimeError(
            "CLIP model/processor not loaded; set clip_processor and clip_model "
            "or pass processor= and model= explicitly"
        )

    # One processor call both tokenizes the text and preprocesses the image.
    inputs = processor(text=prompt, images=drawing_image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # CLIPModel outputs carry projected embeddings for both modalities in the
    # shared embedding space, which is what must be compared.
    image_embeds = outputs.image_embeds
    text_embeds = outputs.text_embeds
    similarity = torch.nn.functional.cosine_similarity(image_embeds, text_embeds, dim=-1)
    # .item() assumes one image and one prompt (batch of 1).
    return similarity.item()
@app.route('/', methods=['GET', 'POST'])
def paintapp():
    """Serve the paint canvas, or store a submitted drawing.

    GET (and HEAD) renders ``paint.html`` with a randomly chosen prompt.
    POST reads the ``save_fname`` and ``save_image`` form fields, inserts them
    into the ``files`` table and redirects to the ``save`` view.
    """
    if request.method != 'POST':
        # GET, plus HEAD, which Werkzeug routes to GET views automatically.
        prompt = generate_random_prompt()
        return render_template("paint.html", prompt=prompt)

    filename = request.form['save_fname']
    canvas_image = request.form['save_image']
    conn = psycopg2.connect(database="paint", user="technocrat", host="localhost", password='1234', port='5432')
    try:
        # `with conn` commits on success and rolls back on error, but it does
        # not close the connection; `finally` does that on every path.
        with conn, conn.cursor() as cur:
            cur.execute("INSERT INTO files (name, canvas_image) VALUES (%s, %s)", (filename, canvas_image))
    finally:
        conn.close()
    return redirect(url_for('save'))
@app.route('/save', methods=['GET', 'POST'])
def save():
    """Render ``save.html`` listing every stored drawing.

    The template receives ``files``: the ``(id, name, canvas_image)`` rows of
    the ``files`` table.
    """
    conn = psycopg2.connect(database="paint", user="technocrat", host="localhost", port='5432', password='1234')
    try:
        # Cursor and connection are released even if the query fails.
        with conn.cursor() as cur:
            cur.execute("SELECT id, name, canvas_image from files")
            files = cur.fetchall()
    finally:
        conn.close()
    return render_template("save.html", files=files)
@app.route('/search', methods=['GET', 'POST'])
def search():
    """Render the search page; on POST include the stored files.

    GET (and HEAD) renders ``search.html`` with no data. POST reads the
    ``fname`` form field and renders ``search.html`` with every
    ``(id, name, canvas_image)`` row of ``files`` plus ``filename``.
    """
    if request.method != 'POST':
        # GET, plus HEAD, which Werkzeug routes to GET views automatically.
        return render_template("search.html")

    filename = request.form['fname']
    # NOTE(review): every row (including canvas_image) is fetched and
    # `filename` is passed to the template, presumably to filter there.
    # A SQL WHERE clause would avoid loading all images; confirm against
    # search.html before changing.
    conn = psycopg2.connect(database="paint", user="technocrat", host="localhost", port='5432', password='1234')
    try:
        with conn.cursor() as cur:
            cur.execute("select id, name, canvas_image from files")
            files = cur.fetchall()
    finally:
        conn.close()
    return render_template("search.html", files=files, filename=filename)
@app.route('/compute_similarity', methods=['POST'])
def compute_and_display_similarity():
    """Score an uploaded drawing against a prompt and show the result.

    Expects a multipart POST carrying a ``drawing`` file and a ``prompt``
    form field, and renders ``similarity.html`` with ``similarity_score``.
    """
    uploaded = request.files['drawing']
    prompt_text = request.form['prompt']
    score = compute_similarity(preprocess_image(uploaded), prompt_text)
    return render_template("similarity.html", similarity_score=score)
if __name__ == '__main__':
    # Start the built-in server on port 3000, listening on all interfaces.
    app.run(port=3000, host='0.0.0.0')