Skip to content

Commit 5605542

Browse files
committedJan 16, 2021
Detect active learning
1 parent 4c910cc commit 5605542

File tree

3,012 files changed

+254
-206
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

3,012 files changed

+254
-206
lines changed
 

‎AL_detect.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import time
2+
import torch
3+
from models.experimental import attempt_load
4+
from utils.datasets import LoadImages
5+
from utils.general import check_img_size, non_max_suppression, scale_coords, xyxy2xywh, set_logging
6+
from utils.torch_utils import select_device, time_synchronized
7+
import activelearning.config as config
8+
import time
9+
10+
def AL_detect(opt):
11+
# Detect ảnh
12+
# File weight model, nguồn dữ liệu detect, Kích cỡ ảnh sử dụng
13+
weights, source, imgsz = opt.weights, opt.source, opt.img_size
14+
15+
# khởi tạo
16+
device = select_device(opt.device)
17+
half = device.type != 'cpu'
18+
19+
# Load model
20+
model = attempt_load(weights=weights, map_location=device)
21+
imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size
22+
if half:
23+
model.half() # to FP16
24+
dataset = LoadImages(source, img_size=imgsz)
25+
26+
# Get names and colors
27+
names = model.module.names if hasattr(model, 'module') else model.names
28+
29+
# Run inference
30+
t0 = time.time()
31+
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
32+
_ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
33+
34+
result = {}
35+
# duyệt tất cả các ảnh
36+
for path, img, im0s, vid_cap in dataset:
37+
img = torch.from_numpy(img).to(device)
38+
img = img.half() if half else img.float() # uint8 to fp16/32
39+
# chuuẩn hóa hình ảnh
40+
img /= 255.0 # 0 - 255 to 0.0 - 1.0
41+
if img.ndimension() == 3:
42+
img = img.unsqueeze(0)
43+
44+
# Inference
45+
t1 = time_synchronized()
46+
pred = model(img, augment=opt.augment)[0]
47+
48+
# Apply NMS
49+
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
50+
t2 = time_synchronized()
51+
52+
# Process detections
53+
result[path] = []
54+
55+
for i, det in enumerate(pred): # detections per image
56+
p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
57+
s += '%gx%g ' % img.shape[2:] # print string
58+
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
59+
60+
if len(det):
61+
# Rescale boxes from img_size to im0 size
62+
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
63+
64+
# Print results
65+
for c in det[:, -1].unique():
66+
n = (det[:, -1] == c).sum() # detections per class
67+
s += f'{n} {names[int(c)]}s, ' # add to string
68+
69+
# Lưu thông tin về box vào 1 file
70+
for *xyxy, conf, cls in reversed(det):
71+
with open(config.info_predict_path, 'a') as f:
72+
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh)
73+
x,y,w,h = xywh
74+
data = {"class": cls.item(), "box": [x,y,w,h], "conf": conf.item()}
75+
result[path].append(data)
76+
# Print time (inference + NMS)
77+
print(f'{s}Done. ({t2 - t1:.3f}s)')
78+
print(f'Done. ({time.time() - t0:.3f}s)')
79+
return result
80+
81+

‎AL_run.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,15 @@ def __init__(self, model, select_function):
3232
self.model = model
3333
self.select_function = select_function
3434
self.num_select = config.num_select
35+
self.type = 'sum' # 'avg' , 'max', 'sum'
3536

3637
def run(self):
37-
self.queried = 0
38-
# chưa thỏa mãn điều kiện dừng thì tiếp tục lặp lại active learning
39-
while self.queried < config.max_queried:
40-
# Xoá file dự đoán cũ
41-
if os.path.exists(config.info_predict_path):
42-
os.remove(config.info_predict_path)
38+
# số truy vấn
39+
queried = 0
40+
# nếu chưa đủ số truy vấn thì tiếp tục truy vấn tiếp
41+
while queried < config.max_queried:
4342
# Dự đoán các ảnh trong tập unlabeled
44-
self.model.detect()
45-
# Kết quả sau khi dự đoán được lưu ở config.info_predict_path
43+
result = self.model.detect()
4644
# Tổng hợp kết quả
4745
probas = {str(file.split('/')[-1]): 0.0 for file in glob.glob(config.source + '/*')}
4846
num_object = probas.copy()

0 commit comments

Comments
 (0)
Please sign in to comment.