Training YOLOX on a Baishao (White Peony) Dataset


1. Build a Partial Training Set

Since the YOLOX model mainly works with datasets in COCO or VOC format, we will use the VOC format, one of those two, to build our partial dataset.

:::details VOC dataset format
![VOC dataset format](VOC.png)
:::
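For reference, the directory layout used below follows the standard VOCdevkit structure:

```
VOCdevkit/
└── VOC2007/
    ├── Annotations/      # XML annotation files
    ├── ImageSets/
    │   └── Main/         # train.txt, val.txt, trainval.txt, test.txt
    └── JPEGImages/       # the annotated images
```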

:::tip
On Windows, annotate the dataset with 精灵标注助手 or Labelme.

On Linux (Ubuntu/Debian), use Labelme.
:::

:::details The baishao images we used
![Baishao image set](Baishao.png)
:::

:::tip

First copy all the images to be annotated into a single folder.

Annotate them with 精灵标注助手 or Labelme.

Note: only one annotation class, 白芍 (baishao), is needed, and each instance only has to be marked with a rectangular box.

Export roughly 100 XML files in the end.

Following the VOC dataset format above, put the XML files into the Annotations folder,

and put the annotated images into the JPEGImages folder.

Inside ImageSets, create a Main folder containing test.txt, train.txt, trainval.txt, and val.txt.

Write a script (in Python or another language) that stores the file names from JPEGImages into these txt files; a minimal sketch is given below.

Note:

  • training-set image names go into train.txt

  • validation-set image names go into val.txt

  • names from both of those sets go into trainval.txt

  • test-set image names go into test.txt

:::

:::warning
Training and validation images are usually split roughly 70/30.
:::
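A minimal sketch of such a split script, assuming the VOCdevkit paths shown above and the 70/30 ratio (the exact paths and the reuse of val as test are assumptions to adapt):

```python
import os
import random

# assumed paths following the VOCdevkit layout above; adjust to your setup
jpeg_dir = "VOCdevkit/VOC2007/JPEGImages"
main_dir = "VOCdevkit/VOC2007/ImageSets/Main"
os.makedirs(main_dir, exist_ok=True)

# image file names without extensions
names = sorted(os.path.splitext(f)[0] for f in os.listdir(jpeg_dir))
random.seed(0)  # reproducible split
random.shuffle(names)

split = int(len(names) * 0.7)  # 70/30 train/val split
train, val = names[:split], names[split:]

def write_lines(filename, items):
    with open(os.path.join(main_dir, filename), "w") as f:
        f.write("\n".join(items) + "\n")

write_lines("train.txt", train)
write_lines("val.txt", val)
write_lines("trainval.txt", train + val)
write_lines("test.txt", val)  # here the val images double as the test set
```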
2. Train a Model for Labeling

:::tip
Set up the YOLOX environment:
:::

```bash
git clone git@github.com:Megvii-BaseDetection/YOLOX.git
cd YOLOX
pip3 install -v -e .  # or python3 setup.py develop
```
:::warning
Alternatively, you can set it up without git:
:::

Download the project zip and extract it to a directory of your choice:

https://github.com/Megvii-BaseDetection/YOLOX/archive/refs/heads/main.zip

Optionally create a virtual environment, then install the dependencies:

```bash
pip3 install -r <target path>/requirements.txt
```
:::tip
Modify the exps/example/yolox_voc/yolox_voc_s.py file in the project.
:::
:::details Roughly as follows

```python
# encoding: utf-8
import os

import torch
import torch.distributed as dist

from yolox.data import get_yolox_datadir
from yolox.exp import Exp as MyExp


class Exp(MyExp):
    def __init__(self):
        super(Exp, self).__init__()
        self.num_classes = 2
        self.depth = 0.33
        self.width = 0.50
        self.warmup_epochs = 1

        # ---------- transform config ------------ #
        self.mosaic_prob = 1.0
        self.mixup_prob = 1.0
        self.hsv_prob = 1.0
        self.flip_prob = 0.5

        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

    def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False):
        from yolox.data import (
            VOCDetection,
            TrainTransform,
            YoloBatchSampler,
            DataLoader,
            InfiniteSampler,
            MosaicDetection,
            worker_init_reset_seed,
        )
        from yolox.utils import (
            wait_for_the_master,
            get_local_rank,
        )
        local_rank = get_local_rank()

        with wait_for_the_master(local_rank):
            dataset = VOCDetection(
                data_dir="<dataset directory>",  # e.g. os.path.join(get_yolox_datadir(), "VOCdevkit")
                image_sets=[('2007', 'trainval')],  # add ('2012', 'trainval') as well if applicable
                img_size=self.input_size,
                preproc=TrainTransform(
                    max_labels=50,
                    flip_prob=self.flip_prob,
                    hsv_prob=self.hsv_prob),
                cache=cache_img,
            )

        dataset = MosaicDetection(
            dataset,
            mosaic=not no_aug,
            img_size=self.input_size,
            preproc=TrainTransform(
                max_labels=120,
                flip_prob=self.flip_prob,
                hsv_prob=self.hsv_prob),
            degrees=self.degrees,
            translate=self.translate,
            mosaic_scale=self.mosaic_scale,
            mixup_scale=self.mixup_scale,
            shear=self.shear,
            enable_mixup=self.enable_mixup,
            mosaic_prob=self.mosaic_prob,
            mixup_prob=self.mixup_prob,
        )

        self.dataset = dataset

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()

        sampler = InfiniteSampler(
            len(self.dataset), seed=self.seed if self.seed else 0
        )

        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            mosaic=not no_aug,
        )

        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
        dataloader_kwargs["batch_sampler"] = batch_sampler

        # Make sure each process has different random seed, especially for 'fork' method
        dataloader_kwargs["worker_init_fn"] = worker_init_reset_seed

        train_loader = DataLoader(self.dataset, **dataloader_kwargs)

        return train_loader

    def get_eval_loader(self, batch_size, is_distributed, testdev=False, legacy=False):
        from yolox.data import VOCDetection, ValTransform

        valdataset = VOCDetection(
            data_dir="D:\\Code\\XZR\\VOCdevkit",  # or os.path.join(get_yolox_datadir(), "VOCdevkit")
            image_sets=[('2007', 'test')],
            img_size=self.test_size,
            preproc=ValTransform(legacy=legacy),
        )

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()
            sampler = torch.utils.data.distributed.DistributedSampler(
                valdataset, shuffle=False
            )
        else:
            sampler = torch.utils.data.SequentialSampler(valdataset)

        dataloader_kwargs = {
            "num_workers": self.data_num_workers,
            "pin_memory": True,
            "sampler": sampler,
        }
        dataloader_kwargs["batch_size"] = batch_size
        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)

        return val_loader

    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
        from yolox.evaluators import VOCEvaluator

        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev, legacy)
        evaluator = VOCEvaluator(
            dataloader=val_loader,
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
        )
        return evaluator
```

:::

:::tip
Modify yolox/data/datasets/voc_classes.py in the project; keep num_classes in the exp file consistent with the number of entries listed here.
:::
:::details Roughly as follows

```python
VOC_CLASSES = (
    "baishao",
)
```

:::
:::tip

Run the training:

```bash
python ./tools/train.py -f <experiment file, i.e. yolox_voc_s.py (relative or absolute path)> -e <number of epochs; 300+ recommended (too few hurts more than too many; judge by how training goes)> --fp16 -c <pretrained weights file, downloaded from the official repo>
```

Here --fp16 enables mixed-precision training.
:::
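For example, assuming the cloned repository layout and a pretrained yolox_s.pth in the project root (the -d/-b device and batch-size flags are standard YOLOX train.py options; the values here are illustrative):

```bash
python ./tools/train.py -f exps/example/yolox_voc/yolox_voc_s.py -d 1 -b 8 --fp16 -c ./yolox_s.pth
```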

Pretrained weights download: the official checkpoints (e.g. yolox_s.pth) are linked from the README of the YOLOX repository.

:::details Example training output
At startup, training prints the experiment configuration; for this project it looked roughly like this:

```

╒══════════════════╤═════════════════════════════════════╕
│ keys │ values │
╞══════════════════╪═════════════════════════════════════╡
│ seed │ None │
├──────────────────┼─────────────────────────────────────┤
│ output_dir │ './YOLOX_outputs' │
├──────────────────┼─────────────────────────────────────┤
│ print_interval │ 10 │
├──────────────────┼─────────────────────────────────────┤
│ eval_interval │ 10 │
├──────────────────┼─────────────────────────────────────┤
│ num_classes │ 2 │
├──────────────────┼─────────────────────────────────────┤
│ depth │ 0.33 │
├──────────────────┼─────────────────────────────────────┤
│ width │ 0.25 │
├──────────────────┼─────────────────────────────────────┤
│ act │ 'silu' │
├──────────────────┼─────────────────────────────────────┤
│ data_num_workers │ 16 │
├──────────────────┼─────────────────────────────────────┤
│ input_size │ (320, 320) │
├──────────────────┼─────────────────────────────────────┤
│ multiscale_range │ 5 │
├──────────────────┼─────────────────────────────────────┤
│ data_dir │ '/media/mulong/新加卷/VOC2007/coco' │
├──────────────────┼─────────────────────────────────────┤
│ train_ann │ 'instances_train2017.json' │
├──────────────────┼─────────────────────────────────────┤
│ val_ann │ 'instances_val2017.json' │
├──────────────────┼─────────────────────────────────────┤
│ test_ann │ 'instances_test2017.json' │
├──────────────────┼─────────────────────────────────────┤
│ mosaic_prob │ 0.5 │
├──────────────────┼─────────────────────────────────────┤
│ mixup_prob │ 1.0 │
├──────────────────┼─────────────────────────────────────┤
│ hsv_prob │ 1.0 │
├──────────────────┼─────────────────────────────────────┤
│ flip_prob │ 0.5 │
├──────────────────┼─────────────────────────────────────┤
│ degrees │ 200.0 │
├──────────────────┼─────────────────────────────────────┤
│ translate │ 0.1 │
├──────────────────┼─────────────────────────────────────┤
│ mosaic_scale │ (0.1, 2) │
├──────────────────┼─────────────────────────────────────┤
│ mixup_scale │ (0.5, 2) │
├──────────────────┼─────────────────────────────────────┤
│ shear │ 2.0 │
├──────────────────┼─────────────────────────────────────┤
│ enable_mixup │ True │
├──────────────────┼─────────────────────────────────────┤
│ warmup_epochs │ 5 │
├──────────────────┼─────────────────────────────────────┤
│ max_epoch │ 300 │
├──────────────────┼─────────────────────────────────────┤
│ warmup_lr │ 0 │
├──────────────────┼─────────────────────────────────────┤
│ basic_lr_per_img │ 0.00015625 │
├──────────────────┼─────────────────────────────────────┤
│ scheduler │ 'yoloxwarmcos' │
├──────────────────┼─────────────────────────────────────┤
│ no_aug_epochs │ 15 │
├──────────────────┼─────────────────────────────────────┤
│ min_lr_ratio │ 0.05 │
├──────────────────┼─────────────────────────────────────┤
│ ema │ True │
├──────────────────┼─────────────────────────────────────┤
│ weight_decay │ 0.0005 │
├──────────────────┼─────────────────────────────────────┤
│ momentum │ 0.9 │
├──────────────────┼─────────────────────────────────────┤
│ exp_name │ 'yolox_nano_baishao_all' │
├──────────────────┼─────────────────────────────────────┤
│ test_size │ (320, 320) │
├──────────────────┼─────────────────────────────────────┤
│ test_conf │ 0.01 │
├──────────────────┼─────────────────────────────────────┤
│ nmsthre │ 0.65 │
╘══════════════════╧═════════════════════════════════════╛
```

Final per-class results:

| class | AP     | AR     |
|-------|--------|--------|
| f     | 98.405 | 99.072 |
| t     | 96.840 | 97.893 |
:::
:::tip
Training should leave weight files such as

best_ckpt.pth

last_epoch_ckpt.pth

latest_ckpt.pth

in the output directory.
:::

:::warning
Note: the experiment description file and demo.py may still be missing some imports, or may hard-code a non-VOC dataset inside certain functions; fix these yourself as needed.
:::
3. Build the Complete Dataset

:::tip
Create a new demo.py with the following content.
:::
:::details demo.py

```python
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
import os
import time
import uuid
from pathlib import Path

import numpy as np
import torchvision
from loguru import logger

import cv2
import torch

from yolox.data.data_augment import ValTransform
from yolox.data.datasets import COCO_CLASSES
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, vis

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]


def make_parser():
    parser = argparse.ArgumentParser("YOLOX Demo!")
    parser.add_argument(
        "demo", default="image", help="demo type, eg. image, video and webcam"
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    parser.add_argument(
        "--path", default="./assets/dog.jpg", help="path to images or video"
    )
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        help="whether to save the inference result of image/video",
    )

    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="pls input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
    parser.add_argument(
        "--device",
        default="cpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    parser.add_argument("--conf", default=0.3, type=float, help="test conf")
    parser.add_argument("--nms", default=0.3, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=None, type=int, help="test img size")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopting mix precision evaluating.",
    )
    parser.add_argument(
        "--legacy",
        dest="legacy",
        default=False,
        action="store_true",
        help="To be compatible with older versions",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Using TensorRT model for testing.",
    )
    parser.add_argument(
        "--crop",
        dest="crop",
        default=False,
        action="store_true",
        help="whether to crop the inference result of image",
    )
    return parser


def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            ext = os.path.splitext(apath)[1]
            if ext.lower() in IMAGE_EXT:
                image_names.append(apath)
    return image_names


def scale_bbox(bboxes, scale_size=1.5, size=None):
    # type: (Any, float, Any) -> Union[Optional[ndarray], Any]
    # bboxes = torch.cat(bboxes).view(-1, 4)
    if isinstance(scale_size, (int, float, complex)):
        scale_size = torch.tensor(scale_size)
    scale_size = torch.sqrt(scale_size)
    bboxes_tmp = bboxes.new(bboxes.shape)

    bboxes_tmp[3] = (bboxes[3] + bboxes[1]) / 2 + scale_size * (
        (bboxes[3] - bboxes[1]) / 2
    )
    bboxes_tmp[1] = (bboxes[3] + bboxes[1]) / 2 - scale_size * (
        (bboxes[3] - bboxes[1]) / 2
    )
    bboxes_tmp[2] = (bboxes[2] + bboxes[0]) / 2 + scale_size * (
        (bboxes[2] - bboxes[0]) / 2
    )
    bboxes_tmp[0] = (bboxes[2] + bboxes[0]) / 2 - scale_size * (
        (bboxes[2] - bboxes[0]) / 2
    )
    return torchvision.ops.clip_boxes_to_image(bboxes_tmp, size)


class Predictor(object):
    def __init__(
        self,
        model,
        exp,
        cls_names=COCO_CLASSES,
        trt_file=None,
        decoder=None,
        device="cpu",
        fp16=False,
        legacy=False,
    ):
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16
        self.preproc = ValTransform(legacy=legacy)
        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            self.model(x)
            self.model = model_trt

    def inference(self, img):
        img_info = {"id": 0}
        if isinstance(img, str):
            img_info["file_name"] = os.path.basename(img)
            img = cv2.imread(img)
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
        img_info["ratio"] = ratio

        img, _ = self.preproc(img, None, self.test_size)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.float()
        if self.device == "gpu":
            img = img.cuda()
            if self.fp16:
                img = img.half()  # to FP16

        with torch.no_grad():
            t0 = time.time()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs,
                self.num_classes,
                self.confthre,
                self.nmsthre,
                class_agnostic=True,
            )
            logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info

    def visual(self, output, img_info, cls_conf=0.35):
        global voc_write, crop
        ratio = img_info["ratio"]
        img = img_info["raw_img"]
        if output is None:
            return img
        output = output.cpu()

        bboxes = output[:, 0:4]

        # preprocessing: resize
        bboxes /= ratio

        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]
        if crop:
            frame = img.copy()
        for i in range(len(bboxes)):
            box = bboxes[i].numpy()
            score = scores[i]
            if score < cls_conf:
                continue
            if crop:
                pad = 100  # 20
                ymin = np.int0([box[1] - np.random.randint(2, pad), 0]).max()
                xmin = np.int0([box[0] - np.random.randint(2, pad), 0]).max()
                ymax = (
                    np.round([box[3] + np.random.randint(2, pad), img_info["height"]])
                    .min()
                    .astype(int)
                )
                xmax = (
                    np.round([box[2] + np.random.randint(2, pad), img_info["width"]])
                    .min()
                    .astype(int)
                )
                crop_box = box - [xmin, ymin, xmin, ymin]

                Path(f"crop/未硫熏白芍_{pad}").mkdir(parents=True, exist_ok=True)
                # file_name = f"crop/{self.cls_names[int(cls[i])]}/{uuid.uuid4().int}"
                file_name = f"crop/未硫熏白芍_{pad}/{uuid.uuid4().int}"
                cv2.imwrite(f"{file_name}.jpg", frame[ymin:ymax, xmin:xmax])
                voc_crop_write = Writer(f"{file_name}.jpg", xmax - xmin, ymax - ymin)
                voc_crop_write.addObject(self.cls_names[int(cls[i])], *crop_box)
                voc_crop_write.save(f"{file_name}.xml")

            voc_write.addObject(
                self.cls_names[int(cls[i])], box[0], box[1], box[2], box[3]
            )

        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
        return vis_res


def image_demo(predictor, vis_folder, path, current_time, save_result):
    global voc_write, crop, crop_frames

    if os.path.isdir(path):
        files = get_image_list(path)
    else:
        files = [path]
    files.sort()
    for image_name in files:
        outputs, img_info = predictor.inference(image_name)
        # pascal_voc_writer's Writer expects (path, width, height)
        voc_write = Writer(image_name, img_info["width"], img_info["height"])
        result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
        if save_result:
            save_folder = os.path.join(
                vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
            )
            os.makedirs(save_folder, exist_ok=True)
            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
            logger.info("Saving detection result in {}".format(save_file_name))
            cv2.imwrite(save_file_name, result_image)

            # if crop:
            #     save_folder1 = Path(save_folder)/"crop"
            #     save_folder1.mkdir(exist_ok=True)
            #     save_file_name1 = save_folder1/f"{uuid.uuid4().int}"
            #     for i in crop_frames:
            #         cv2.imwrite(f"{save_file_name1}.jpg", i)
            #         voc_write.save(f"{save_file_name1}.xml")
            # else:
            voc_write.save(f"{save_folder}/{Path(save_file_name).stem}.xml")

        # ch = cv2.waitKey(0)
        # if ch == 27 or ch == ord("q") or ch == ord("Q"):
        #     break


def imageflow_demo(predictor, vis_folder, current_time, args):
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    save_folder = os.path.join(
        vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    )
    os.makedirs(save_folder, exist_ok=True)
    if args.demo == "video":
        save_path = os.path.join(save_folder, args.path.split("/")[-1])
    else:
        save_path = os.path.join(save_folder, "camera.mp4")
    logger.info(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    while True:
        ret_val, frame = cap.read()
        if ret_val:
            outputs, img_info = predictor.inference(frame)
            result_frame = predictor.visual(outputs[0], img_info, predictor.confthre)
            if args.save_result:
                vid_writer.write(result_frame)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
        else:
            break


def main(exp, args):
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    vis_folder = None
    if args.save_result:
        vis_folder = os.path.join(file_name, "vis_res")
        os.makedirs(vis_folder, exist_ok=True)

    if args.trt:
        args.device = "gpu"

    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    if args.device == "gpu":
        model.cuda()
        if args.fp16:
            model.half()  # to FP16
    model.eval()

    if not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        ckpt = torch.load(ckpt_file, map_location="cpu")
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.trt:
        assert not args.fuse, "TensorRT model is not support model fusing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(
            trt_file
        ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")
    else:
        trt_file = None
        decoder = None

    predictor = Predictor(
        model,
        exp,
        COCO_CLASSES,
        trt_file,
        decoder,
        args.device,
        args.fp16,
        args.legacy,
    )
    current_time = time.localtime()
    if args.demo == "image":
        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
    elif args.demo == "video" or args.demo == "webcam":
        imageflow_demo(predictor, vis_folder, current_time, args)


if __name__ == "__main__":
    global voc_write, crop, crop_frames
    from pascal_voc_writer import Writer

    with logger.catch():
        args = make_parser().parse_args()
        exp = get_exp(args.exp_file, args.name)
        crop = args.crop
        main(exp, args)
```

:::
:::tip
Fill in demo.py's arguments by analogy with train.py.

Note that the checkpoint should now be the .pth weights file obtained from the earlier training.

Before the final training, go back to the two folders of images collected earlier, the unfumigated (未熏) and fumigated (已熏) sets.

Depending on which class you are labeling, adjust the path variables inside the visual function of demo.py, then run demo.py to produce the required XML files.

Proceeding as in step one, mix the two classes of images together and build the complete dataset. A sample invocation follows below.
:::
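A typical invocation might look like the following (all flags come from make_parser in the demo.py above; the paths and threshold are illustrative):

```bash
python demo.py image \
    -f exps/example/yolox_voc/yolox_voc_s.py \
    -c YOLOX_outputs/yolox_voc_s/best_ckpt.pth \
    --path ./collected_images \
    --conf 0.4 \
    --device gpu \
    --save_result \
    --crop
```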
:::warning
The modified demo.py contains the code that uses OpenCV and the detected box coordinates to write the XML files, and it also demonstrates basic file creation and writing.

These patterns will come up again in later projects (e.g. API endpoints), so they are worth studying; a minimal sketch follows below.
:::
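As a minimal sketch of the pascal_voc_writer pattern used in demo.py (the file names here are hypothetical):

```python
from pascal_voc_writer import Writer

# Writer(path, width, height) begins a Pascal VOC annotation for one image
writer = Writer("images/sample.jpg", 800, 600)

# addObject(name, xmin, ymin, xmax, ymax) records one labeled box
writer.addObject("baishao", 120, 80, 420, 360)

# save() serializes the annotation to an XML file
writer.save("annotations/sample.xml")
```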
4. Train the Final Model

:::tip
Use the same procedure as in step two.
:::