Commit 2519854

Restore track.py
1 parent 2142f2f commit 2519854

File tree

1 file changed: +370 -0 lines

tools/track.py

@@ -0,0 +1,370 @@
import argparse
import os
import sys
import os.path as osp
import cv2
import numpy as np
import torch

sys.path.append('.')

from loguru import logger

from yolox.data.data_augment import preproc
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess
from yolox.utils.visualize import plot_tracking

from tracker.tracking_utils.timer import Timer
from tracker.bot_sort import BoTSORT

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]

# Global timers shared by the detection and tracking loops
trackerTimer = Timer()
timer = Timer()


def make_parser():
    parser = argparse.ArgumentParser("BoT-SORT Tracks For Evaluation!")

    parser.add_argument("path", help="path to the dataset under evaluation; currently only MOT17 and MOT20 are supported.")
    parser.add_argument("--benchmark", dest="benchmark", type=str, default='MOT17', help="benchmark to evaluate: MOT17 | MOT20")
    parser.add_argument("--eval", dest="split_to_eval", type=str, default='test', help="split to evaluate: train | val | test")
    parser.add_argument("-f", "--exp_file", default=None, type=str, help="please input your experiment description file")
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("--default-parameters", dest="default_parameters", default=False, action="store_true", help="use the default parameters as in the paper")
    parser.add_argument("--save-frames", dest="save_frames", default=False, action="store_true", help="save sequences with tracks.")

    # Detector
    parser.add_argument("--device", default="gpu", type=str, help="device to run our model, can either be cpu or gpu")
    parser.add_argument("--conf", default=None, type=float, help="test conf")
    parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=None, type=int, help="test img size")
    parser.add_argument("--fp16", dest="fp16", default=False, action="store_true", help="Adopt mixed-precision evaluation.")
    parser.add_argument("--fuse", dest="fuse", default=False, action="store_true", help="Fuse conv and bn for testing.")

    # Tracking args
    parser.add_argument("--track_high_thresh", type=float, default=0.6, help="tracking confidence threshold")
    parser.add_argument("--track_low_thresh", default=0.1, type=float, help="lowest detection threshold valid for tracks")
    parser.add_argument("--new_track_thresh", default=0.7, type=float, help="new track thresh")
    parser.add_argument("--track_buffer", type=int, default=30, help="number of frames to keep lost tracks")
    parser.add_argument("--match_thresh", type=float, default=0.8, help="matching threshold for tracking")
    parser.add_argument("--aspect_ratio_thresh", type=float, default=1.6, help="threshold for filtering out boxes whose aspect ratio is above the given value.")
    parser.add_argument('--min_box_area', type=float, default=10, help='filter out tiny boxes')

    # CMC
    parser.add_argument("--cmc-method", default="file", type=str, help="cmc method: file (VidStab GMC) | sparseOptFlow | orb | ecc | none")

    # ReID
    parser.add_argument("--with-reid", dest="with_reid", default=False, action="store_true", help="use Re-ID flag.")
    parser.add_argument("--fast-reid-config", dest="fast_reid_config", default=r"fast_reid/configs/MOT17/sbs_S50.yml", type=str, help="reid config file path")
    parser.add_argument("--fast-reid-weights", dest="fast_reid_weights", default=r"pretrained/mot17_sbs_S50.pth", type=str, help="reid weights file path")
    parser.add_argument('--proximity_thresh', type=float, default=0.5, help='threshold for rejecting low overlap reid matches')
    parser.add_argument('--appearance_thresh', type=float, default=0.25, help='threshold for rejecting low appearance similarity reid matches')

    return parser


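# One plausible invocation, given the parser above (the dataset path is
# illustrative and assumes a MOTChallenge-style layout on disk):
#   python3 tools/track.py ./datasets/MOT17 --benchmark MOT17 --eval test \
#       --default-parameters --with-reid --fp16 --fuse

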
def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = osp.join(maindir, filename)
            ext = osp.splitext(apath)[1]
            if ext in IMAGE_EXT:
                image_names.append(apath)
    return image_names


def write_results(filename, results):
    # MOTChallenge text format: frame, id, x1, y1, w, h, score, -1, -1, -1
    save_format = '{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n'
    with open(filename, 'w') as f:
        for frame_id, tlwhs, track_ids, scores in results:
            for tlwh, track_id, score in zip(tlwhs, track_ids, scores):
                if track_id < 0:
                    continue
                x1, y1, w, h = tlwh
                line = save_format.format(frame=frame_id, id=track_id, x1=round(x1, 1), y1=round(y1, 1),
                                          w=round(w, 1), h=round(h, 1), s=round(score, 2))
                f.write(line)
    logger.info('save results to {}'.format(filename))


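# Example row produced by the format string above (values are illustrative):
#   1,1,1359.1,413.3,120.3,362.8,0.95,-1,-1,-1

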
class Predictor(object):
    def __init__(
            self,
            model,
            exp,
            device=torch.device("cpu"),
            fp16=False
    ):
        self.model = model
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16

        # ImageNet normalization statistics used by preproc
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def inference(self, img, timer):
        img_info = {"id": 0}
        if isinstance(img, str):
            img_info["file_name"] = osp.basename(img)
            img = cv2.imread(img)
        else:
            img_info["file_name"] = None

        if img is None:
            raise ValueError("Empty image: {}".format(img_info["file_name"]))

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
        img_info["ratio"] = ratio
        img = torch.from_numpy(img).unsqueeze(0).float().to(self.device)
        if self.fp16:
            img = img.half()  # to FP16

        with torch.no_grad():
            timer.tic()
            outputs = self.model(img)
            outputs = postprocess(outputs, self.num_classes, self.confthre, self.nmsthre)

        return outputs, img_info


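# Illustrative use of Predictor (model/exp as built in main() below; the
# image path is hypothetical):
#   predictor = Predictor(model, exp, device=torch.device("cuda"), fp16=False)
#   outputs, img_info = predictor.inference("img1/000001.jpg", timer)

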
def image_track(predictor, vis_folder, args):
    if osp.isdir(args.path):
        files = get_image_list(args.path)
    else:
        files = [args.path]
    files.sort()

    if args.ablation:
        # Ablation protocol: evaluate on the second half of each training sequence
        files = files[len(files) // 2 + 1:]

    num_frames = len(files)

    # Tracker
    tracker = BoTSORT(args, frame_rate=args.fps)

    results = []

    for frame_id, img_path in enumerate(files, 1):

        # Detect objects ('exp' is the module-level experiment config set in __main__)
        outputs, img_info = predictor.inference(img_path, timer)
        scale = min(exp.test_size[0] / float(img_info['height']), exp.test_size[1] / float(img_info['width']))

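        # Illustrative numbers for the scale above (sizes are assumptions, not
        # fixed by this script): with exp.test_size = (800, 1440) and a 1080x1920
        # frame, scale = min(800/1080, 1440/1920) ~= 0.741, so dividing the box
        # coordinates by scale maps detections back to original-image pixels.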
        if outputs[0] is not None:
            outputs = outputs[0].cpu().numpy()
            detections = outputs[:, :7]
            detections[:, :4] /= scale

            trackerTimer.tic()
            online_targets = tracker.update(detections, img_info["raw_img"])
            trackerTimer.toc()

            online_tlwhs = []
            online_ids = []
            online_scores = []
            for t in online_targets:
                tlwh = t.tlwh
                tid = t.track_id
                vertical = tlwh[2] / tlwh[3] > args.aspect_ratio_thresh
                if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                    online_tlwhs.append(tlwh)
                    online_ids.append(tid)
                    online_scores.append(t.score)

                    # save results
                    results.append(
                        f"{frame_id},{tid},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
                    )
            timer.toc()
            online_im = plot_tracking(
                img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id, fps=1. / timer.average_time
            )
        else:
            timer.toc()
            online_im = img_info['raw_img']

        if args.save_frames:
            save_folder = osp.join(vis_folder, args.name)
            os.makedirs(save_folder, exist_ok=True)
            cv2.imwrite(osp.join(save_folder, osp.basename(img_path)), online_im)

        if frame_id % 20 == 0:
            logger.info('Processing frame {}/{} ({:.2f} fps)'.format(frame_id, num_frames, 1. / max(1e-5, timer.average_time)))

    res_file = osp.join(vis_folder, args.name + ".txt")

    with open(res_file, 'w') as f:
        f.writelines(results)
    logger.info(f"save results to {res_file}")


def main(exp, args):
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    output_dir = osp.join(exp.output_dir, args.experiment_name)
    os.makedirs(output_dir, exist_ok=True)

    vis_folder = osp.join(output_dir, "track_results")
    os.makedirs(vis_folder, exist_ok=True)

    args.device = torch.device("cuda" if args.device == "gpu" else "cpu")

    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model().to(args.device)
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
    model.eval()

    if args.ckpt is None:
        ckpt_file = osp.join(output_dir, "best_ckpt.pth.tar")
    else:
        ckpt_file = args.ckpt
    logger.info("loading checkpoint")
    ckpt = torch.load(ckpt_file, map_location="cpu")

    # load the model state dict
    model.load_state_dict(ckpt["model"])
    logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.fp16:
        model = model.half()  # to FP16

    predictor = Predictor(model, exp, args.device, args.fp16)

    image_track(predictor, vis_folder, args)


if __name__ == "__main__":
    args = make_parser().parse_args()

    data_path = args.path
    fp16 = args.fp16
    device = args.device

    if args.benchmark == 'MOT20':
        train_seqs = [1, 2, 3, 5]
        test_seqs = [4, 6, 7, 8]
        seqs_ext = ['']
        MOT = 20
    elif args.benchmark == 'MOT17':
        train_seqs = [2, 4, 5, 9, 10, 11, 13]
        test_seqs = [1, 3, 6, 7, 8, 12, 14]
        seqs_ext = ['FRCNN', 'DPM', 'SDP']
        MOT = 17
    else:
        raise ValueError("Error: Unsupported benchmark: " + args.benchmark)

    ablation = False
    if args.split_to_eval == 'train':
        seqs = train_seqs
    elif args.split_to_eval == 'val':
        seqs = train_seqs
        ablation = True
    elif args.split_to_eval == 'test':
        seqs = test_seqs
    else:
        raise ValueError("Error: Unsupported split to evaluate: " + args.split_to_eval)

    mainTimer = Timer()
    mainTimer.tic()

    for ext in seqs_ext:
        for i in seqs:
            # Build sequence names such as 'MOT17-02-FRCNN' or 'MOT20-04'
            if i < 10:
                seq = 'MOT' + str(MOT) + '-0' + str(i)
            else:
                seq = 'MOT' + str(MOT) + '-' + str(i)

            if ext != '':
                seq += '-' + ext

            args.name = seq

            args.ablation = ablation
            args.mot20 = MOT == 20
            args.fps = 30
            args.device = device
            args.fp16 = fp16
            args.batch_size = 1
            args.trt = False

            split = 'train' if i in train_seqs else 'test'
            args.path = data_path + '/' + split + '/' + seq + '/' + 'img1'

            if args.default_parameters:

                if MOT == 20:  # MOT20
                    args.exp_file = r'./yolox/exps/example/mot/yolox_x_mix_mot20_ch.py'
                    args.ckpt = r'./pretrained/bytetrack_x_mot20.tar'
                    args.match_thresh = 0.7
                else:  # MOT17
                    if ablation:
                        args.exp_file = r'./yolox/exps/example/mot/yolox_x_ablation.py'
                        args.ckpt = r'./pretrained/bytetrack_ablation.pth.tar'
                    else:
                        args.exp_file = r'./yolox/exps/example/mot/yolox_x_mix_det.py'
                        args.ckpt = r'./pretrained/bytetrack_x_mot17.pth.tar'

                exp = get_exp(args.exp_file, args.name)

                args.track_high_thresh = 0.6
                args.track_low_thresh = 0.1
                args.track_buffer = 30

                # Per-sequence overrides for the paper's default parameters
                if seq == 'MOT17-05-FRCNN' or seq == 'MOT17-06-FRCNN':
                    args.track_buffer = 14
                elif seq == 'MOT17-13-FRCNN' or seq == 'MOT17-14-FRCNN':
                    args.track_buffer = 25
                else:
                    args.track_buffer = 30

                if seq == 'MOT17-01-FRCNN':
                    args.track_high_thresh = 0.65
                elif seq == 'MOT17-06-FRCNN':
                    args.track_high_thresh = 0.65
                elif seq == 'MOT17-12-FRCNN':
                    args.track_high_thresh = 0.7
                elif seq == 'MOT17-14-FRCNN':
                    args.track_high_thresh = 0.67
                elif seq in ['MOT20-06', 'MOT20-08']:
                    args.track_high_thresh = 0.3
                    exp.test_size = (736, 1920)

                args.new_track_thresh = args.track_high_thresh + 0.1
            else:
                exp = get_exp(args.exp_file, args.name)

            exp.test_conf = max(0.001, args.track_low_thresh - 0.01)
            main(exp, args)

    mainTimer.toc()
    print("TOTAL TIME END-to-END (with loading networks and images): ", mainTimer.total_time)
    print("TOTAL TIME (Detector + Tracker): " + str(timer.total_time) + ", FPS: " + str(1.0 / timer.average_time))
    print("TOTAL TIME (Tracker only): " + str(trackerTimer.total_time) + ", FPS: " + str(1.0 / trackerTimer.average_time))
