-
Notifications
You must be signed in to change notification settings - Fork 0
/
quick_vilt_gui.py
107 lines (82 loc) · 3.33 KB
/
quick_vilt_gui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import tkinter as tk
from functools import lru_cache
from io import BytesIO
from tkinter import filedialog
from tkinter import messagebox

import requests
from PIL import Image, ImageTk
from transformers import ViltProcessor, ViltForQuestionAnswering
def browse_image(event=None):
    """Let the user pick an image file and put its path into the image entry.

    Clears the previous answer and refreshes or hides the thumbnail according
    to the "Preview Image" checkbox. Cancelling the dialog leaves the current
    selection untouched.

    Args:
        event: Unused Tk event. It is accepted because ``bind`` passes an
            event object to its callback.
    """
    global image_path
    # Patterns are given as a list. A single "*.png;*.jpg" string is only
    # split by the native Windows dialog; elsewhere Tk treats it as one glob.
    path = filedialog.askopenfilename(
        filetypes=[("Image Files", ("*.png", "*.jpg", "*.jpeg", "*.gif", "*.bmp"))]
    )
    if not path:
        # Dialog was cancelled (empty result): keep the current selection.
        return
    image_path = path
    predicted_answer.set("")  # the old answer belonged to the old image
    entry_image.delete(0, tk.END)
    entry_image.insert(0, image_path)
    if preview_var.get():
        display_thumbnail(image_path)
    elif thumbnail_label:
        thumbnail_label.pack_forget()
def preview_checkbox_changed():
    """Show or hide the thumbnail when the "Preview Image" box is toggled.

    The path is read from the image entry, the same place run_prediction
    reads it from. This means toggling the box before anything was browsed
    no longer fails on an unset global, and typed paths are previewed too.
    """
    if not preview_var.get():
        if thumbnail_label:
            thumbnail_label.pack_forget()
        return
    path = entry_image.get()
    # display_thumbnail opens the path directly with PIL, so only existing
    # local files can be previewed. URLs, which load_image accepts, are skipped.
    if os.path.isfile(path):
        display_thumbnail(path)
def display_thumbnail(image_path):
    """Show a thumbnail (at most 100x100 px) of the local image file *image_path*.

    Any previously shown thumbnail label is destroyed and replaced by a new
    label packed into ``root``. An empty path does nothing.

    Args:
        image_path: Path of a local image file readable by PIL.
    """
    global thumbnail_label
    if not image_path:
        return
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    # The ``with`` block closes the file once the pixels have been copied
    # into the Tk photo image.
    with Image.open(image_path) as image:
        image.thumbnail((100, 100), Image.LANCZOS)
        photo = ImageTk.PhotoImage(image)
    if thumbnail_label:
        thumbnail_label.destroy()
    thumbnail_label = tk.Label(root, image=photo)
    # Keep a Python reference to the photo; otherwise it can be garbage
    # collected and the label would show a blank image.
    thumbnail_label.image = photo
    thumbnail_label.pack(pady=10)
def load_image(image, timeout=10):
    """Load a PIL image from a local file path or, failing that, a URL.

    Args:
        image: Path of an existing local file; any other string is fetched
            with an HTTP GET.
        timeout: Seconds to wait for the HTTP server before giving up, so an
            unresponsive host cannot freeze the GUI indefinitely.

    Returns:
        A fully decoded ``PIL.Image.Image`` backed by in-memory bytes.

    Raises:
        ValueError: If the file or URL cannot be read, the server returns an
            HTTP error status, or the data cannot be decoded as an image.
    """
    try:
        if os.path.isfile(image):
            with open(image, "rb") as file:
                data = file.read()
        else:
            # response.content is the decoded body; response.raw would ignore
            # Content-Encoding such as gzip. raise_for_status stops an HTML
            # error page from being handed to PIL.
            with requests.get(image, timeout=timeout) as response:
                response.raise_for_status()
                data = response.content
        img = Image.open(BytesIO(data))
        # Image.open is lazy. Decoding here means corrupt data fails inside
        # this try block instead of later in the caller.
        img.load()
        return img
    except Exception as e:
        raise ValueError(f"Unable to load image: {e}") from e
# Hugging Face checkpoint used for visual question answering.
VILT_MODEL_NAME = "dandelin/vilt-b32-finetuned-vqa"


@lru_cache(maxsize=1)
def _load_vilt(model_name=VILT_MODEL_NAME):
    """Return ``(processor, model)`` for *model_name*, loading them only once."""
    processor = ViltProcessor.from_pretrained(model_name)
    model = ViltForQuestionAnswering.from_pretrained(model_name)
    return processor, model


def run_prediction(event=None):
    """Answer the question entry's text about the image named in the image entry.

    Loads the image (local path or URL), runs the ViLT VQA model and writes
    the highest-scoring answer label into ``predicted_answer``. The question
    and answer are also printed to stdout. Image-loading failures are shown
    in an error dialog instead of escaping into Tk's callback handler.

    Args:
        event: Unused Tk event. It is present so the function works both with
            ``bind`` and as a button ``command``.
    """
    import torch  # heavy import, deferred until a prediction is requested

    predicted_answer.set("")
    try:
        image = load_image(entry_image.get()).convert("RGB")
    except (ValueError, OSError) as e:
        messagebox.showerror("Vilt VQA GUI", str(e))
        return
    question = entry_question.get()
    print(f"Question: {question}")
    # The model is loaded on the first call and reused afterwards.
    processor, model = _load_vilt()
    encoding = processor(image, question, return_tensors="pt")
    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        outputs = model(**encoding)
    idx = outputs.logits.argmax(-1).item()
    answer = model.config.id2label[idx]
    predicted_answer.set(answer)
    print(f"Predicted answer: {answer}")
# Module-level state shared with the callbacks above.
thumbnail_label = None  # Label currently showing the thumbnail, if any
image_path = ""  # last file chosen via Browse; defined up front so callbacks can read it

root = tk.Tk()
root.title("Vilt VQA GUI")

tk.Label(root, text="Image Path:").pack(pady=5)
entry_image = tk.Entry(root, width=50)
entry_image.pack(pady=5)
browse_button = tk.Button(root, text="Browse", command=browse_image)
browse_button.pack(pady=5)
# bind() passes an event object. The lambda drops it so the binding works
# whether or not browse_image accepts an argument.
browse_button.bind('<Return>', lambda _event: browse_image())

preview_var = tk.BooleanVar()
tk.Checkbutton(root, text="Preview Image", command=preview_checkbox_changed, variable=preview_var).pack(pady=5)

tk.Label(root, text="Question:").pack(pady=5)
entry_question = tk.Entry(root, width=50)
entry_question.pack(pady=5)
run_button = tk.Button(root, text="Run Prediction", command=run_prediction)
run_button.pack(pady=10)
run_button.bind('<Return>', run_prediction)

predicted_answer = tk.StringVar()
tk.Label(root, text="Predicted answer:").pack(pady=5)
# Read-only display. Its text comes only from the predicted_answer variable,
# since insert() has no effect on a read-only Entry.
predicted_answer_widget = tk.Entry(root, state='readonly', width=30, textvariable=predicted_answer)
predicted_answer_widget.pack(pady=10)

root.mainloop()