import torch
from fastdup.sentry import fastdup_capture_exception
from fastdup.definitions import MISSING_LABEL
from fastdup.galleries import fastdup_imread
import cv2
from tqdm import tqdm

# Cache one captioning pipeline per device so repeated calls reuse the loaded model.
device_to_captioner = {}

def init_captioning(model_name='automatic', device='cpu', batch_size=8, max_new_tokens=20,
                    use_float_16=True):
    '''
    This function generates captions for a given set of images, and takes the following arguments:
        - filenames: the list of images passed to the function
        ...
            - BLIP: 'blip'
        - batch_size: the size of image batches to caption (default: 8)
        - device: whether to use a GPU (default: -1, CPU only; set to 0 for GPU)
        - max_new_tokens: set the number of allowed tokens
    '''

    global device_to_captioner
    # use GPU if device is specified
    if device == 'gpu':
        device = 0
    elif device == 'cpu':
        device = -1
        use_float_16 = False  # half precision is not reliably supported on CPU
    else:
        assert False, f"Incompatible device name entered {device}. Available device names are gpu and cpu."
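    # (transformers convention: device=-1 selects the CPU, device=0 the first CUDA device)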

    # confirm necessary dependencies are installed, and import them
    try:
        from transformers import pipeline
        from transformers.utils import logging
        logging.set_verbosity(50)  # 50 == CRITICAL: silence transformers log output

    except Exception as e:
        fastdup_capture_exception("Auto generate labels", e)
        print("Auto captioning requires an installation of the following libraries:\n")
        print("    huggingface transformers\n    pytorch\n")
        print("    to install, use `pip3 install transformers torch`")
        raise

    # dictionary of captioning models
    models = {
        'automatic': "nlpconnect/vit-gpt2-image-captioning",
        'vitgpt2': "nlpconnect/vit-gpt2-image-captioning",
        'blip-2': "Salesforce/blip2-opt-2.7b",
        'blip': "Salesforce/blip-image-captioning-large"
    }
    assert model_name in models.keys(), f"Unknown captioning model {model_name}, allowed models are {models.keys()}"

    model = models[model_name]
    has_gpu = torch.cuda.is_available()
    captioner = pipeline("image-to-text", model=model, device=device if has_gpu else "cpu",
                         max_new_tokens=max_new_tokens,
                         torch_dtype=torch.float16 if use_float_16 else torch.float32)
    device_to_captioner[device] = captioner
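    # Note: for each image, the image-to-text pipeline returns a list of dicts of the
    # form [{'generated_text': '...'}]; generate_labels() below joins these fields.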

    return captioner

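# Design note: init_captioning() memoizes the pipeline in device_to_captioner, so model
# weights load once per device and later generate_labels() calls skip initialization.
# A minimal sketch of direct use (the image path is a hypothetical placeholder):
#
#   captioner = init_captioning(model_name='blip', device='cpu')
#   preds = captioner(["/path/to/image.jpg"])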

def generate_labels(filenames, model_name='automatic', device='cpu', batch_size=8, max_new_tokens=20, use_float_16=True):
    global device_to_captioner
    # reuse the cached captioner for this device, creating it on first use
    if device not in device_to_captioner:
        captioner = init_captioning(model_name, device, batch_size, max_new_tokens, use_float_16)
    else:
        captioner = device_to_captioner[device]

    captions = []
    # generate captions
    try:
        for i in tqdm(range(0, len(filenames), batch_size)):
            chunk = filenames[i:i + batch_size]
            try:
                for pred in captioner(chunk, batch_size=batch_size):
                    # join per-image predictions with a space for 'blip', otherwise concatenate directly
                    charstring = '' if model_name != 'blip' else ' '
                    caption = charstring.join([d['generated_text'] for d in pred])
                    # Split the sentence into words
                    words = caption.split()
                    # Filter out words containing '#' (likely tokenizer artifacts such as '##')
                    filtered_words = [word for word in words if '#' not in word]
                    # Join the filtered words back into a sentence
                    caption = ' '.join(filtered_words)
                    caption = caption.strip()
                    captions.append(caption)
            except Exception as ex:
                print("Failed to caption chunk", chunk[:5], ex)
                captions.extend([MISSING_LABEL] * len(chunk))

    except Exception as e:
        fastdup_capture_exception("Auto caption image", e)
        return [MISSING_LABEL] * len(filenames)

    return captions

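# A minimal usage sketch (assumes 'image_paths' is a hypothetical list of local image files):
#
#   from fastdup.captions import generate_labels
#   captions = generate_labels(image_paths, model_name='vitgpt2', device='cpu')
#   for path, caption in zip(image_paths, captions):
#       print(path, '->', caption)
#
# On failure the function degrades gracefully: failed chunks are filled with MISSING_LABEL.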

def generate_vqa_labels(filenames, text, kwargs):
    # confirm necessary dependencies are installed, and import them
    # ... (body elided in this excerpt) ...


def generate_age_labels(filenames, kwargs):
    # ... (body elided in this excerpt) ...
        fastdup_capture_exception("Age label", e)
        return [MISSING_LABEL] * len(filenames)


if __name__ == "__main__":
    # smoke test: caption a folder of two known unit-test images with the BLIP model
    import fastdup
    from fastdup.captions import generate_labels
    import os
    file = "/Users/dannybickson/visual_database/cxx/unittests/two_images/"
    files = os.listdir(file)
    files = [os.path.join(file, f) for f in files]
    ret = generate_labels(files, model_name='blip')
    assert len(ret) == 2
    print(ret)
    for r in ret:
        assert "shelf" in r or "shelves" in r or "store" in r