200字范文,内容丰富有趣,生活中的好帮手!
200字范文 > OpenPCDet 自定义数据集训练

OpenPCDet 自定义数据集训练

时间:2019-10-16 13:12:46

相关推荐

OpenPCDet 自定义数据集训练

目录

0、目标:

1、数据的预处理

2、修改数据处理部分的代码

2.1 复制对数据集进行处理的文件

2.2 对kitti_lidar_dataset.py进行修改

2.2.1 头文件修改

2.2.2 数据集对象名称修改

2.2.3get_info函数修改

2.2.4.yaml文件修改

2.2.5 运行

3、修改数据集加载

3.1 去掉测试

3.2 修改__getitem__函数

3.3 前后连起来

3.4 .yaml文件修改

3.5运行

·本文还存在错误,对点云并未进行坐标转化,选择性阅读

0、目标:

本文立足于pointpillars算法的训练,这里通过处理kitti数据集展示对自定义数据集的训练方法。

在源代码中对pointpillars的训练需要很多的数据(不晓得咋直接训练可以进入这篇博客OpenPCDet 在KITTI 训练PointPillar_辉e的博客-CSDN博客_openpcdet训练kitti)这里尤其是calib,我们对点云进行目标检测的训练,不需要啥坐标转换的,所以我这里想去除这个文件夹,只依靠velodyne和label来进行训练

1、数据的预处理

这里写了一个代码进行数据的预处理,其目的主要是对label的第12、13、14位进行处理,因为kitti数据集中这个标注的意思是在相机坐标系下其标注框的位置(x ,y ,z),而我们在使用过程中需要获得雷达坐标系下的标注,所以在这里进行预先的转化。

1、该代码写在tools文件夹中,kitti数据集在data文件夹中

2、运行下面的py文件会建立一个文件夹data/kitti/training/new_label_2,并将处理过然后产生的txt文件放入其中

3、运行完代码后将new_label_2名字改为label_2(原谅我是懒蛋,如果不改这个地方,会有很多其他地方要改)

import numpy as npfrom pathlib import Pathimport osdef get_calib_from_file(calib_file):with open(calib_file) as f:lines = f.readlines()obj = lines[2].strip().split(' ')[1:]P2 = np.array(obj, dtype=np.float32)obj = lines[3].strip().split(' ')[1:]P3 = np.array(obj, dtype=np.float32)obj = lines[4].strip().split(' ')[1:]R0 = np.array(obj, dtype=np.float32)obj = lines[5].strip().split(' ')[1:]Tr_velo_to_cam = np.array(obj, dtype=np.float32)return {'P2': P2.reshape(3, 4),'P3': P3.reshape(3, 4),'R0': R0.reshape(3, 3),'Tr_velo2cam': Tr_velo_to_cam.reshape(3, 4)}class Calibration(object):def __init__(self, calib_file):if not isinstance(calib_file, dict):calib = get_calib_from_file(calib_file)else:calib = calib_fileself.P2 = calib['P2'] # 3 x 4self.R0 = calib['R0'] # 3 x 3self.V2C = calib['Tr_velo2cam'] # 3 x 4# Camera intrinsics and extrinsicsself.cu = self.P2[0, 2]self.cv = self.P2[1, 2]self.fu = self.P2[0, 0]self.fv = self.P2[1, 1]self.tx = self.P2[0, 3] / (-self.fu)self.ty = self.P2[1, 3] / (-self.fv)def cart_to_hom(self, pts):""":param pts: (N, 3 or 2):return pts_hom: (N, 4 or 3)"""pts_hom = np.hstack((pts, np.ones((pts.shape[0], 1), dtype=np.float32)))return pts_hom#对R0_rect进行拓展,然后与Tr_velo_to_cam进行相乘求相反数后再求逆 R0_rect * Tr_velo_to_cam * y=x(y是雷达,x是照相机)def rect_to_lidar(self, pts_rect):""":param pts_lidar: (N, 3):return pts_rect: (N, 3)"""pts_rect_hom = self.cart_to_hom(pts_rect) # (N, 4)R0_ext = np.hstack((self.R0, np.zeros((3, 1), dtype=np.float32))) # (3, 4)R0_ext = np.vstack((R0_ext, np.zeros((1, 4), dtype=np.float32))) # (4, 4)R0_ext[3, 3] = 1V2C_ext = np.vstack((self.V2C, np.zeros((1, 4), dtype=np.float32))) # (4, 4)V2C_ext[3, 3] = 1pts_lidar = np.dot(pts_rect_hom, np.linalg.inv(np.dot(R0_ext, V2C_ext).T))return pts_lidar[:, 0:3]class Object3d(object):def __init__(self, line):label = line.strip().split(' ')self.top=np.array([])for i in range(0,11):self.top=np.append(self.top,label[i])self.loc = np.array((float(label[11]), float(label[12]), float(label[13])), dtype=np.float32)self.last=np.array([label[14]])def get_calib(root_split_path, idx):calib_file = root_split_path / 'calib' / ('%s.txt' % idx)assert calib_file.exists()return Calibration(calib_file)def get_objects_from_label(label_file):with open(label_file, 'r') as f:lines = f.readlines()objects = [Object3d(line) for line in lines]return objectsdef get_label(root_split_path, idx):label_file = root_split_path / 'label_2' / ('%s.txt' % idx)assert label_file.exists()return get_objects_from_label(label_file)def write_new_libel(root_split_path, idx, save_num):new_libel_file=root_split_path / 'new_label_2' / ('%s.txt' % idx)with open(new_libel_file, "a")as f:f.write(str(save_num[0]))for i in range(1,save_num.shape[0]):f.write(' '+str(save_num[i]))f.write('\r\n')#去掉文件最后的换行符def del_n(root_split_path,idx):new_libel_file=root_split_path / 'new_label_2' / ('%s.txt' % idx)file_object = open(new_libel_file, "rb+")file_object.seek(-2,2)file_object.truncate()file_object.close()def get_allfile(path): # 获取所有文件all_file = []files =sorted(os.listdir(path))for f in files : #listdir返回文件中所有目录#f_name = os.path.join(path, f)#f_name=os.path.basename(f_name)#去掉路径f=os.path.splitext(f)[0]#去掉文件名后缀all_file.append(f)return all_filedef clean_file(root_split_path,idx):new_libel_file=root_split_path / 'new_label_2' / ('%s.txt' % idx)file_object = open(new_libel_file, "w")file_object.close()def mkdir_new_label_2(root_split_path):new_libel_2=root_split_path / 'new_label_2'if os.path.exists(new_libel_2) is False:print("-------mkdir%s-------"%new_libel_2) os.mkdir(new_libel_2)root_split_path=Path('../data/kitti/training')mkdir_new_label_2(root_split_path)all_file=get_allfile(root_split_path/'label_2') #tickets要获取文件夹名print("-------All name loaded-------")#print(all_file)for file_idx in all_file:clean_file(root_split_path,file_idx)print("This is the %s.txt"%file_idx)calib=get_calib(root_split_path,file_idx)obj_list=get_label(root_split_path,file_idx)annotations = {}for obj in obj_list:annotations['location'] = np.concatenate([obj.loc.reshape(1, 3)], axis=0)#print(annotations['location'])loc_lidar = calib.rect_to_lidar(annotations['location'])loc_lidar=loc_lidar.reshape(-1)#print("top",obj.top[0])temp=np.concatenate([obj.top,loc_lidar,obj.last],axis=0)#print("concatenate",temp)write_new_libel(root_split_path, file_idx, temp)#del_n(root_split_path, file_idx)

2、修改数据处理部分的代码

OpenPCDet中首先对数据进行了一波预处理,我们仿照着写一下,这一步主要是对pcdet/datasets这个文件夹进行处理

2.1 复制对数据集进行处理的文件

把pcdet/datasets/kitti文件夹复制并改名为pcdet/datasets/kitti_lidar,然后把pcdet/utils/object3d_kitti.py复制为pcdet/utils/object3d_kitti_lidar.py

2.2 对kitti_lidar_dataset.py进行修改

pcdet/datasets/kitti_lidar/kitti_lidar_dataset.py

2.2.1 头文件修改

这一行修改最后的object3d_kitti为object3d_kitti_lidar

from ...utils import box_utils, calibration_kitti, common_utils, object3d_kitti_lidar

2.2.2 数据集对象名称修改

头文件下面一行修改为(原类名为KittiDataset)

class KittiLidarDataset(DatasetTemplate):

2.2.3get_info函数修改

这里其他的地方不要改,直接到这个函数,然后替换为下面的代码

def get_infos(self, num_workers=4, has_label=True, count_inside_pts=True, sample_id_list=None):import concurrent.futures as futuresdef process_single_scene(sample_idx):print('%s sample_idx: %s' % (self.split, sample_idx))info = {}pc_info = {'num_features': 4, 'lidar_idx': sample_idx}info['point_cloud'] = pc_infoif has_label:obj_list = self.get_label(sample_idx)annotations = {}annotations['name'] = np.array([obj.cls_type for obj in obj_list])annotations['truncated'] = np.array([obj.truncation for obj in obj_list])annotations['occluded'] = np.array([obj.occlusion for obj in obj_list])annotations['alpha'] = np.array([obj.alpha for obj in obj_list])annotations['bbox'] = np.concatenate([obj.box2d.reshape(1, 4) for obj in obj_list], axis=0)annotations['dimensions'] = np.array([[obj.l, obj.h, obj.w] for obj in obj_list]) # lhw(camera) formatannotations['location'] = np.concatenate([obj.loc.reshape(1, 3) for obj in obj_list], axis=0)annotations['rotation_y'] = np.array([obj.ry for obj in obj_list])annotations['score'] = np.array([obj.score for obj in obj_list])annotations['difficulty'] = np.array([obj.level for obj in obj_list], np.int32)num_objects = len([obj.cls_type for obj in obj_list if obj.cls_type != 'DontCare'])num_gt = len(annotations['name'])index = list(range(num_objects)) + [-1] * (num_gt - num_objects)annotations['index'] = np.array(index, dtype=np.int32)loc = annotations['location'][:num_objects]dims = annotations['dimensions'][:num_objects]rots = annotations['rotation_y'][:num_objects]#loc_lidar = calib.rect_to_lidar(loc)#获得一个变换矩阵loc_lidar=locl, h, w = dims[:, 0:1], dims[:, 1:2], dims[:, 2:3]loc_lidar[:, 2] += h[:, 0] / 2gt_boxes_lidar = np.concatenate([loc_lidar, l, w, h, -(np.pi / 2 + rots[..., np.newaxis])], axis=1)annotations['gt_boxes_lidar'] = gt_boxes_lidarinfo['annos'] = annotationsreturn infosample_id_list = sample_id_list if sample_id_list is not None else self.sample_id_listwith futures.ThreadPoolExecutor(num_workers) as executor:infos = executor.map(process_single_scene, sample_id_list)return list(infos)

2.2.4.yaml文件修改

老规矩,先cv,将tools/cfgs/dataset_configs/kitti_dataset.yaml复制为tools/cfgs/dataset_configs/kitti_lidar.yaml。然后修改一下第一行,修改为

DATASET: 'KittiLidarDataset'

2.2.5 运行

终端输入

python -m pcdet.datasets.kitti_lidar.kitti_lidar_dataset create_kitti_infos tools/cfgs/dataset_configs/kitti_lidar.yaml

结果展示:

然后我们的pkl文件就存放在data/kitti里面啦

3、修改数据集加载

3.0 复制数据集加载的文件

把pcdet/datasets/kitti_lidar文件夹复制并改名为pcdet/datasets/kitti_lidar,里面的文件相应改名

3.1 去掉测试

tools/train.py这个文件夹内部,去掉测试的代码,我们少修改一点

"""logger.info('**********************Start evaluation %s/%s(%s)**********************' %(cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))test_set, test_loader, sampler = build_dataloader(dataset_cfg=cfg.DATA_CONFIG,class_names=cfg.CLASS_NAMES,batch_size=args.batch_size,dist=dist_train, workers=args.workers, logger=logger, training=False)eval_output_dir = output_dir / 'eval' / 'eval_with_train'eval_output_dir.mkdir(parents=True, exist_ok=True)args.start_epoch = max(args.epochs - args.num_epochs_to_eval, 0) # Only evaluate the last args.num_epochs_to_eval epochsrepeat_eval_ckpt(model.module if dist_train else model,test_loader, args, eval_output_dir, logger, ckpt_dir,dist_test=dist_train)logger.info('**********************End evaluation %s/%s(%s)**********************' %(cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))"""

3.2 修改__getitem__函数

pcdet/datasets/kitti_lidar/kitti_lidar_dataset.py,修改数据加载的文件,这里主要把图像和calib的加载去掉,然后把我们新的数据集文件(label_2)导入

def __getitem__(self, index):# index = 4if self._merge_all_iters_to_one_epoch:index = index % len(self.kitti_infos)info = copy.deepcopy(self.kitti_infos[index])sample_idx = info['point_cloud']['lidar_idx']#img_shape = info['image']['image_shape']#calib = self.get_calib(sample_idx)get_item_list = self.dataset_cfg.get('GET_ITEM_LIST', ['points'])input_dict = {'frame_id': sample_idx,#'calib': calib,}if 'annos' in info:annos = info['annos']annos = common_utils.drop_info_with_name(annos, name='DontCare')loc, dims, rots = annos['location'], annos['dimensions'], annos['rotation_y']gt_names = annos['name']#gt_boxes_camera = np.concatenate([loc, dims, rots[..., np.newaxis]], axis=1).astype(np.float32)gt_boxes_lidar = annos['gt_boxes_lidar']input_dict.update({'gt_names': gt_names,'gt_boxes': gt_boxes_lidar})#if "gt_boxes2d" in get_item_list:# input_dict['gt_boxes2d'] = annos["bbox"]road_plane = self.get_road_plane(sample_idx)if road_plane is not None:input_dict['road_plane'] = road_planeif "points" in get_item_list:points = self.get_lidar(sample_idx)input_dict['points'] = pointsdata_dict = self.prepare_data(data_dict=input_dict)#data_dict['image_shape'] = img_shapereturn data_dict

3.3 前后连起来

pcdet/datasets/__init__.py,将前面的部分和后面的部分连起来

#头文件中加入,我们2.2.2from .kitti_lidar.kitti_lidar_dataset import KittiLidarDataset__all__ = {'DatasetTemplate': DatasetTemplate,'KittiDataset': KittiDataset,'KittiLidarDataset':KittiLidarDataset,#相应的这里也加入'NuScenesDataset': NuScenesDataset,'WaymoDataset': WaymoDataset,'PandasetDataset': PandasetDataset,'LyftDataset': LyftDataset}

3.4 .yaml文件修改

将tools/cfgs/kitti_models/pointpillar.yaml复制到tools/cfgs/kitti_lidar_models/pointpillar.yaml,kitti_lidar_models这个文件夹自己建立

其中修改_BASE_CONFIG_

DATA_CONFIG: _BASE_CONFIG_: cfgs/dataset_configs/kitti_lidar.yaml

3.5运行

cd toolspython train.py --cfg_file=cfgs/kitti_lidar_models/pointpillar.yaml --batch_size=3 --epochs=100

运行结果:

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。