200字范文,内容丰富有趣,生活中的好帮手!
200字范文 > MMrotate自定义数据集训练与验证格式转换脚本

MMrotate自定义数据集训练与验证格式转换脚本

时间:2021-02-17 11:53:48

相关推荐

MMrotate自定义数据集训练与验证格式转换脚本

数据集准备

数据集格式

文件夹格式:Data/ #保存Dota数据集的目录

Train #存放images和labelTxt的文件夹

images #存放所有训练集图片的文件夹

labelTxt #存放所有训练集txt标注文件的文件夹

labelTxt 中的 txt 文件可通过转换脚本 ro_xml2txt.py 生成:该脚本将 roLabelImg 标注的 xml 文件转换成 DOTA 格式的 txt 文件。

其中 --xml_dir 为待转换的 xml 文件所在目录。

--output_dir为转换后的数据集存放路径。

修改 mmrotate/configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_ROL_le90.py 和 mmrotate/mmrotate/datasets/rolabel.py 两个文件。

三、训练

训练命令格式:

# 单 GPU 训练

python tools/train.py ${CONFIG_FILE} [optional arguments]

# 多 GPU 训练

bash tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]

说明:

config_file:模型配置文件的路径

gpu_num:使用 GPU 的数量

--work-dir:设置存放训练生成文件的路径

--resume-from:设置恢复训练的模型检查点文件的路径

--no-validate(不建议):设置训练时不验证模型

--seed:设置随机种子,便于复现结果

这里以oriented_rcnn为例,cd 到yuml_web目录下,运行命令:

python mmrotate/tools/train.py \

mmrotate/configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_ROL_le90.py

即可开始训练模型。其中训练产生的所有日志文件都保存在work_dir中。

验证

# 单 GPU 测试

python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \

[--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]

# 多 GPU 测试

bash tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} \

[--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}]

config_file:模型配置文件的路径

checkpoint_file:模型检查点文件的路径

gpu_num:使用的 GPU 数量

--out:设置输出 pkl 测试结果文件的路径

--work-dir:设置存放 json 日志文件的路径

--eval:设置度量指标(voc:mAP, recall | coco:bbox, segm, proposal);注意本文的 ROLDataset 只支持 mAP,即使用 --eval mAP

--show:设置显示有预测框的测试集图像

--show-dir:设置存放有预测框的测试集图像的路径

--show-score-thr:设置显示预测框的阈值,默认值为 0.3

--fuse-conv-bn: 设置融合卷积层和批归一化层,能够稍微提升推理速度

这里以oriented_rcnn为例,建议将work_dir中需要验证的pth模型文件复制到yuml_web/checkpoints/下,然后cd 到yuml_web目录下,运行命令:

python mmrotate/tools/test.py mmrotate/configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_ROL_le90.py \

checkpoints/pth模型文件名 --show

注:验证的数据集为test.txt中的图片名

转换脚本:

# -*- coding: utf-8 -*-
# @Time : /5/2 15:42
# @Author : Bob.Xu
# @Site : Convert roLabelImg xml annotations into txt files
# @File : xml2txt.py
# @Software: PyCharm
"""Convert roLabelImg xml annotations into DOTA-format txt label files.

Every ``<object>`` in an xml file becomes one line of
``<output_dir>/<xml stem>.txt``::

    x1 y1 x2 y2 x3 y3 x4 y4 class_name difficult

``<bndbox>`` objects are written as their four corners (top-left, top-right,
bottom-right, bottom-left); ``<robndbox>`` objects (cx, cy, w, h, angle) are
expanded into their four rotated corners in the same order.
"""
import argparse
import glob
import math
import os
import xml.etree.ElementTree as ET


def rotatePoint(xc, yc, xp, yp, theta):
    """Rotate the point ``(xp, yp)`` about the centre ``(xc, yc)``.

    Args:
        xc: x coordinate of the rotation centre.
        yc: y coordinate of the rotation centre.
        xp: x coordinate of the point to rotate (a coordinate, not a length).
        yp: y coordinate of the point to rotate.
        theta: rotation angle in radians (passed to ``math.cos``/``math.sin``).

    Returns:
        tuple[str, str]: the rotated x and y, already converted to strings so
        they can be joined straight into an output line.
    """
    xoff = xp - xc
    yoff = yp - yc
    cos_theta = math.cos(theta)
    sin_theta = math.sin(theta)
    res_x = cos_theta * xoff + sin_theta * yoff
    res_y = -sin_theta * xoff + cos_theta * yoff
    return str(xc + res_x), str(yc + res_y)


def _object_to_dota_line(obj):
    """Return the DOTA txt line for one ``<object>`` element.

    Returns None (object skipped) when the element has neither a
    ``<bndbox>`` nor a ``<robndbox>`` child.
    """
    bndbox = obj.find('bndbox')
    robndbox = obj.find('robndbox')
    if bndbox is not None:
        # Children are read by position: xmin, ymin, xmax, ymax.
        xmin, ymin, xmax, ymax = (bndbox[k].text for k in range(4))
        corners = [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]
    elif robndbox is not None:
        # Children are read by position: cx, cy, w, h, angle (radians).
        cx, cy, w, h, angle = (float(robndbox[k].text) for k in range(5))
        corners = []
        for px, py in ((cx - w / 2, cy - h / 2), (cx + w / 2, cy - h / 2),
                       (cx + w / 2, cy + h / 2), (cx - w / 2, cy + h / 2)):
            corners.extend(rotatePoint(cx, cy, px, py, -angle))
    else:
        return None
    name = obj.find('name').text
    difficult = obj.find('difficult')
    # A missing <difficult> tag is treated as "not difficult" instead of
    # aborting the whole conversion with an AttributeError.
    difficult_text = difficult.text if difficult is not None else '0'
    return ' '.join(corners + [name, difficult_text])


def convert_xml(xml_path, output_dir):
    """Convert one roLabelImg xml file into ``<output_dir>/<stem>.txt``.

    Returns:
        str: path of the txt file that was written.
    """
    stem = os.path.splitext(os.path.basename(xml_path))[0]
    txt_path = os.path.join(output_dir, stem + '.txt')
    root = ET.parse(xml_path).getroot()
    with open(txt_path, 'w', encoding='utf-8') as f:
        for obj in root.iter('object'):
            line = _object_to_dota_line(obj)
            if line is not None:
                f.write(line + '\n')
    return txt_path


def main(argv=None):
    """Command-line entry point; ``argv`` defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--xml_dir', type=str,
                        default='yaodai_data/Annotations',
                        help='Directory of images and xml.')
    parser.add_argument('--output_dir', type=str,
                        default='yaodai_data/labelTxt/',
                        help='Directory of output.')
    args = parser.parse_args(argv)

    os.makedirs(args.output_dir, exist_ok=True)
    xml_files = sorted(glob.glob(os.path.join(args.xml_dir, '*.xml')))
    for n, xml_path in enumerate(xml_files, start=1):
        name = os.path.basename(xml_path)
        print("正在处理第{}个xml文件,名称为{}".format(n, name))
        # Written into --output_dir (previously it was silently ignored and
        # the txt landed next to the xml file).
        convert_xml(xml_path, args.output_dir)


if __name__ == '__main__':
    main()

替换:

/mmrotate/mmrotate/datasets/__init__.py

# Copyright (c) OpenMMLab. All rights reserved.

from .builder import build_dataset # noqa: F401, F403

from .dota import DOTADataset # noqa: F401, F403

from .hrsc import HRSCDataset # noqa: F401, F403

from .pipelines import * # noqa: F401, F403

from .sar import SARDataset # noqa: F401, F403

from .rolabel import ROLDataset # noqa: F401, F403

# Public names exported by `from mmrotate.datasets import *`; ROLDataset is
# the custom dataset imported from .rolabel above.
__all__ = ['SARDataset', 'DOTADataset', 'build_dataset', 'HRSCDataset', 'ROLDataset']

"/mmrotate/mmrotate/datasets/rolabel.py"

# Copyright (c) OpenMMLab. All rights reserved.

import glob

import os

import os.path as osp

import re

import tempfile

import time

import zipfile

from collections import defaultdict

from functools import partial

import mmcv

import numpy as np

import torch

from mmcv.ops import nms_rotated

from mmdet.datasets.custom import CustomDataset

from mmrotate.core import obb2poly_np, poly2obb_np

from mmrotate.core.evaluation import eval_rbbox_map

from .builder import ROTATED_DATASETS

@ROTATED_DATASETS.register_module()
class ROLDataset(CustomDataset):
    """Rotated-box dataset for roLabelImg annotations converted to DOTA txt.

    Each annotation file is ``<img_id>.txt`` with lines of the form
    ``x1 y1 x2 y2 x3 y3 x4 y4 class_name difficulty`` (the format written by
    the xml2txt conversion script). The matching image is
    ``<img_id><image_type>``.

    Class names are read, one per line, from ``class_dir`` when this module
    is imported.

    Args:
        ann_file (str): Folder containing the annotation txt files. If it
            holds no ``*.txt`` files it is scanned for ``*<image_type>``
            images instead (test phase) and every image gets empty
            annotations.
        pipeline (list[dict]): Processing pipeline.
        version (str, optional): Angle representation passed to
            ``poly2obb_np`` / ``obb2poly_np``. Defaults to 'oc'.
        difficulty (int, optional): Objects whose difficulty value is greater
            than this threshold are dropped. Defaults to 100.
        image_type (str, optional): Image file extension including the
            leading dot. Defaults to '.png'.
    """

    # NOTE(review): this path is resolved against the current working
    # directory at import time; importing the module fails if it is missing.
    class_dir = 'yuml_web/data/class.txt'

    # 'utf-8-sig' also strips a BOM (e.g. written by Windows Notepad) that
    # would otherwise be glued onto the first class name; blank lines are
    # skipped so they do not become phantom '' classes.
    with open(class_dir, 'r', encoding='utf-8-sig') as _class_file:
        CLASSES = [name.strip() for name in _class_file if name.strip()]
    del _class_file

    def __init__(self,
                 ann_file,
                 pipeline,
                 version='oc',
                 difficulty=100,
                 image_type='.png',
                 **kwargs):
        self.version = version
        self.difficulty = difficulty
        self.image_type = image_type
        super(ROLDataset, self).__init__(ann_file, pipeline, **kwargs)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.data_infos)

    def load_annotations(self, ann_folder):
        """Build the per-image info dicts from an annotation folder.

        Args:
            ann_folder (str): Folder with one ``<img_id>.txt`` per image, or
                (test phase) a folder of ``*<image_type>`` images.

        Returns:
            list[dict]: One dict per image with keys ``filename`` and
            ``ann``. Images whose txt file is empty are left out.
        """
        # Labels are 0-based (mmdet v2 convention).
        cls_map = {c: i for i, c in enumerate(self.CLASSES)}
        ann_files = glob.glob(ann_folder + '/*.txt')
        data_infos = []
        if not ann_files:  # test phase: no labels, enumerate images instead
            ann_files = glob.glob(ann_folder + '/*' + self.image_type)
            for ann_file in ann_files:
                img_id = osp.splitext(osp.basename(ann_file))[0]
                data_info = {}
                data_info['filename'] = img_id + self.image_type
                data_info['ann'] = {}
                data_info['ann']['bboxes'] = []
                data_info['ann']['labels'] = []
                data_infos.append(data_info)
        else:
            for ann_file in ann_files:
                data_info = self._load_ann_file(ann_file, cls_map)
                if data_info is not None:
                    data_infos.append(data_info)

        # splitext instead of [:-4] so extensions like '.jpeg' also work.
        self.img_ids = [
            osp.splitext(info['filename'])[0] for info in data_infos
        ]
        return data_infos

    def _load_ann_file(self, ann_file, cls_map):
        """Parse one DOTA-format txt file; return None if the file is empty."""
        if osp.getsize(ann_file) == 0:
            return None
        img_id = osp.splitext(osp.basename(ann_file))[0]

        gt_bboxes, gt_labels, gt_polygons = [], [], []
        # Never filled (difficult objects are dropped, not ignored); kept so
        # the 'ann' dict has the same keys as DOTADataset.
        gt_bboxes_ignore, gt_labels_ignore, gt_polygons_ignore = [], [], []

        # The conversion script writes utf-8 (class names may be Chinese);
        # relying on the platform default encoding breaks on Windows.
        with open(ann_file, encoding='utf-8') as f:
            for line in f:
                bbox_info = line.split()
                # Skip blank or header lines (e.g. DOTA 'imagesource:'/'gsd:')
                # that lack 8 coordinates + class name + difficulty.
                if len(bbox_info) < 10:
                    continue
                poly = np.array(bbox_info[:8], dtype=np.float32)
                try:
                    x, y, w, h, a = poly2obb_np(poly, self.version)
                except Exception:  # degenerate polygon: skip this object
                    continue
                cls_name = bbox_info[8]
                difficulty = int(bbox_info[9])
                if cls_name not in cls_map:
                    raise KeyError(
                        f'Unknown class {cls_name!r} in {ann_file}; '
                        f'expected one of {self.CLASSES}')
                label = cls_map[cls_name]
                if difficulty <= self.difficulty:
                    gt_bboxes.append([x, y, w, h, a])
                    gt_labels.append(label)
                    gt_polygons.append(poly)

        bboxes, labels, polygons = self._pack(gt_bboxes, gt_labels,
                                              gt_polygons)
        bboxes_ignore, labels_ignore, polygons_ignore = self._pack(
            gt_bboxes_ignore, gt_labels_ignore, gt_polygons_ignore)
        return {
            'filename': img_id + self.image_type,
            'ann': {
                'bboxes': bboxes,
                'labels': labels,
                'polygons': polygons,
                'bboxes_ignore': bboxes_ignore,
                'labels_ignore': labels_ignore,
                'polygons_ignore': polygons_ignore,
            },
        }

    @staticmethod
    def _pack(bboxes, labels, polygons):
        """Convert box/label/polygon lists to arrays (shaped empties if none)."""
        if bboxes:
            return (np.array(bboxes, dtype=np.float32),
                    np.array(labels, dtype=np.int64),
                    np.array(polygons, dtype=np.float32))
        return (np.zeros((0, 5), dtype=np.float32),
                np.array([], dtype=np.int64),
                np.zeros((0, 8), dtype=np.float32))

    def _filter_imgs(self):
        """Return indices of images that have at least one ground-truth box."""
        return [
            i for i, data_info in enumerate(self.data_infos)
            if data_info['ann']['labels'].size > 0
        ]

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        All set to 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None,
                 nproc=4):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Only 'mAP'
                is supported.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Unused; kept for interface
                compatibility. Default: (100, 300, 1000).
            iou_thr (float): IoU threshold; must be a float for mAP.
                Default: 0.5.
            scale_ranges (list[tuple] | None): Scale ranges for evaluating
                mAP. Default: None.
            nproc (int): Processes used for computing TP and FP.
                Default: 4.

        Returns:
            dict: ``{'mAP': mean_ap}``.

        Raises:
            KeyError: if ``metric`` is not 'mAP'.
        """
        # os.cpu_count() may return None when the count is undeterminable.
        nproc = min(nproc, os.cpu_count() or 1)
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = {}
        if metric == 'mAP':
            assert isinstance(iou_thr, float)
            mean_ap, _ = eval_rbbox_map(
                results,
                annotations,
                scale_ranges=scale_ranges,
                iou_thr=iou_thr,
                dataset=self.CLASSES,
                logger=logger,
                nproc=nproc)
            eval_results['mAP'] = mean_ap
        else:
            raise NotImplementedError

        return eval_results

    def merge_det(self, results, nproc=4):
        """Merge patch-level detections into full-image detections.

        Image ids of the form ``<name>__...__<x>___<y>`` are treated as
        patches whose boxes are shifted by ``(x, y)``; ids without that
        suffix are treated as full images (offset 0, 0).

        Args:
            results (list): Testing results of the dataset.
            nproc (int): number of process. Default: 4.

        Returns:
            zip: pairs ``(ids, per-class detections)`` from ``_merge_func``.
        """
        collector = defaultdict(list)
        # This was a scrape-garbled ``pile(...)`` (NameError); compile once.
        offset_pattern = re.compile(r'__\d+___\d+')
        for idx in range(len(self)):
            result = results[idx]
            img_id = self.img_ids[idx]
            oriname = img_id.split('__')[0]
            match = offset_pattern.search(img_id)
            if match is None:
                # Not a split patch: boxes are already in image coordinates.
                x, y = 0, 0
            else:
                offsets = re.findall(r'\d+', match.group())
                x, y = int(offsets[0]), int(offsets[1])
            new_result = []
            for i, dets in enumerate(result):
                bboxes, scores = dets[:, :-1], dets[:, [-1]]
                ori_bboxes = bboxes.copy()
                ori_bboxes[..., :2] = ori_bboxes[..., :2] + np.array(
                    [x, y], dtype=np.float32)
                labels = np.zeros((bboxes.shape[0], 1)) + i
                new_result.append(
                    np.concatenate([labels, ori_bboxes, scores], axis=1))

            new_result = np.concatenate(new_result, axis=0)
            collector[oriname].append(new_result)

        merge_func = partial(_merge_func, CLASSES=self.CLASSES, iou_thr=0.1)
        if nproc <= 1:
            print('Single processing')
            merged_results = mmcv.track_iter_progress(
                (map(merge_func, collector.items()), len(collector)))
        else:
            print('Multiple processing')
            merged_results = mmcv.track_parallel_progress(
                merge_func, list(collector.items()), nproc)

        return zip(*merged_results)

    def _results2submission(self, id_list, dets_list, out_folder=None):
        """Generate the submission of full images.

        Args:
            id_list (list): Id of images.
            dets_list (list): Detection results of per class.
            out_folder (str, optional): Folder of submission; must not exist.

        Returns:
            list[str]: paths of the per-class ``Task1_<class>.txt`` files.
        """
        from contextlib import ExitStack

        if osp.exists(out_folder):
            raise ValueError(f'The out_folder should be a non-exist path, '
                             f'but {out_folder} is existing')
        os.makedirs(out_folder)

        files = [
            osp.join(out_folder, 'Task1_' + cls + '.txt')
            for cls in self.CLASSES
        ]
        # ExitStack closes every per-class file even if a write fails.
        with ExitStack() as stack:
            file_objs = [stack.enter_context(open(f, 'w')) for f in files]
            for img_id, dets_per_cls in zip(id_list, dets_list):
                for f, dets in zip(file_objs, dets_per_cls):
                    if dets.size == 0:
                        continue
                    bboxes = obb2poly_np(dets, self.version)
                    for bbox in bboxes:
                        txt_element = [img_id, str(bbox[-1])
                                       ] + [f'{p:.2f}' for p in bbox[:-1]]
                        f.write(' '.join(txt_element) + '\n')

        target_name = osp.split(out_folder)[-1]
        with zipfile.ZipFile(
                osp.join(out_folder, target_name + '.zip'), 'w',
                zipfile.ZIP_DEFLATED) as t:
            for f in files:
                t.write(f, osp.split(f)[-1])

        return files

    def format_results(self, results, submission_dir=None, nproc=4, **kwargs):
        """Format the results to submission text (standard format for DOTA
        evaluation).

        Args:
            results (list): Testing results of the dataset.
            submission_dir (str, optional): The folder that contains
                submission files. If not specified, a folder inside a new
                temporary directory is used. Default: None.
            nproc (int, optional): number of process.

        Returns:
            tuple:
                - result_files (list[str]): the submission file paths.
                - tmp_dir (tempfile.TemporaryDirectory | None): the temporary
                  directory created when submission_dir is not specified;
                  the caller should call ``tmp_dir.cleanup()``.
        """
        nproc = min(nproc, os.cpu_count() or 1)
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            f'The length of results is not equal to '
            f'the dataset len: {len(results)} != {len(self)}')
        if submission_dir is None:
            # Previously tmp_dir was never bound here (UnboundLocalError)
            # and a TemporaryDirectory object was used as a path.
            tmp_dir = tempfile.TemporaryDirectory()
            # _results2submission requires a path that does not exist yet.
            submission_dir = osp.join(tmp_dir.name, 'submission')
        else:
            tmp_dir = None

        print('\nMerging patch bboxes into full image!!!')
        start_time = time.time()
        id_list, dets_list = self.merge_det(results, nproc)
        stop_time = time.time()
        print(f'Used time: {(stop_time - start_time):.1f} s')

        result_files = self._results2submission(id_list, dets_list,
                                                submission_dir)

        return result_files, tmp_dir

def _merge_func(info, CLASSES, iou_thr):

"""Merging patch bboxes into full image.

Args:

CLASSES (list): Label category.

iou_thr (float): Threshold of IoU.

"""

img_id, label_dets = info

label_dets = np.concatenate(label_dets, axis=0)

labels, dets = label_dets[:, 0], label_dets[:, 1:]

big_img_results = []

for i in range(len(CLASSES)):

if len(dets[labels == i]) == 0:

big_img_results.append(dets[labels == i])

else:

try:

cls_dets = torch.from_numpy(dets[labels == i]).cuda()

except: # noqa: E722

cls_dets = torch.from_numpy(dets[labels == i])

nms_dets, keep_inds = nms_rotated(cls_dets[:, :5], cls_dets[:, -1],

iou_thr)

big_img_results.append(nms_dets.cpu().numpy())

return img_id, big_img_results

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。