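"""Download images from a Google Image search.

Collects image URLs by driving Firefox through Selenium, scrolling the
results page until enough thumbnails are gathered, then downloads and
resizes each image to 256x256 in a multiprocessing pool.

Example invocation (hypothetical query and label, for illustration):

    python image_scrapper.py "golden retriever" --count 200 --label dogs
"""
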
import os
import re
import time
import argparse
import requests
import io
import hashlib
import itertools
import base64
from PIL import Image
from multiprocessing import Pool
from selenium import webdriver
# Selenium 4 moved driver paths into Service objects and locators into By;
# the find_element_by_* helpers used by older versions no longer exist.
from selenium.webdriver.common.by import By
from selenium.webdriver.firefox.service import Service


argument_parser = argparse.ArgumentParser(description='Download images using Google image search')
argument_parser.add_argument('query', metavar='query', type=str, help='The query to download images for')
argument_parser.add_argument('--count', metavar='count', default=100, type=int, help='How many images to fetch')
argument_parser.add_argument('--label', metavar='label', type=str, required=True,
                             help='The directory in which to store the images (images/<label>)')


def ensure_directory(path):
    # makedirs with exist_ok avoids the check-then-create race and builds parent directories too.
    os.makedirs(path, exist_ok=True)


def largest_file(dir_path):
    """Return the largest number embedded in a filename in dir_path, or 0."""
    def parse_num(filename):
        match = re.search(r'\d+', filename)  # raw string avoids an invalid-escape warning
        if match:
            return int(match.group(0))
        return None

    files = os.listdir(dir_path)
    # Drop filenames without a number; default to 0 when nothing matches.
    return max(filter(None, map(parse_num, files)), default=0)


def fetch_image_urls(query, images_to_download):
    image_urls = set()
    search_url = "https://www.google.com/search?safe=off&site=&tbm=isch&source=hp&q={q}&oq={q}&gs_l=img"
    # Assumes geckodriver.exe sits next to the script; Selenium >= 4.6 can also
    # locate a driver on its own if this path is omitted.
    browser = webdriver.Firefox(service=Service(executable_path='geckodriver.exe'))
    browser.maximize_window()
    browser.implicitly_wait(5)
    browser.get(search_url.format(q=query))

    def scroll_to_bottom():
        browser.execute_script("window.scrollTo(0, document.body.scrollHeight);")
        time.sleep(2)

    image_count = len(image_urls)
    while image_count < images_to_download:
        print("Found:", len(image_urls), "images")
        scroll_to_bottom()
        # NOTE: these CSS selectors target Google's result markup at the time of
        # writing and will need updating whenever the page layout changes.
        for img in browser.find_elements(By.CSS_SELECTOR, "img.rg_ic"):
            src = img.get_attribute('src')
            if src:  # skip thumbnails that have not loaded a src yet
                image_urls.add(src)
        delta = len(image_urls) - image_count
        image_count = len(image_urls)
        if delta == 0:
            print("Can't find more images")
            break
        # find_elements returns an empty list (rather than raising) when the
        # "show more results" button is absent.
        if browser.find_elements(By.CSS_SELECTOR, "#smb.ksb"):
            browser.execute_script("document.querySelector('#smb.ksb').click();")
            scroll_to_bottom()
    browser.quit()
    return image_urls


def persist_image(dir_image_src):
    label_directory, image_src = dir_image_src
    size = (256, 256)
    try:
        image_content = requests.get(image_src, timeout=10).content
    except requests.exceptions.InvalidSchema:
        # The "URL" is probably a base64-encoded data URI; decode it directly.
        image_data = re.sub(r'^data:image/.+;base64,', '', image_src)
        image_content = base64.b64decode(image_data)
    except Exception as e:
        print("could not read", e, image_src)
        return False
    try:
        image = Image.open(io.BytesIO(image_content)).convert('RGB')
    except Exception as e:
        # Guard against non-image responses (e.g. HTML error pages).
        print("could not decode image", e, image_src)
        return False
    resized = image.resize(size)
    # Name the file by its content hash so duplicate downloads overwrite each other.
    file_path = os.path.join(label_directory, hashlib.sha1(image_content).hexdigest() + ".jpg")
    with open(file_path, 'wb') as f:
        resized.save(f, "JPEG", quality=100)
    return True


if __name__ == '__main__':
    args = argument_parser.parse_args()
    ensure_directory('./images/')
    query_directory = os.path.join('./images', args.label)
    ensure_directory(query_directory)
    image_urls = fetch_image_urls(args.query, args.count)
    values = list(zip(itertools.cycle([query_directory]), image_urls))
    print("image count", len(image_urls))
    # A context manager makes sure the worker pool is cleaned up after the map.
    with Pool(12) as pool:
        results = pool.map(persist_image, values)
    print("Images downloaded:", len([r for r in results if r]))