Python offline data augmentation (expanding a dataset)
Published: 2019-05-20


The augmentation methods covered are:
(1) Pixel-level augmentation
    1. brightness changes
    2. noise injection
(2) Image-level (geometric) augmentation
    3. cropping (bbox must be adjusted)
    4. translation (bbox must be adjusted)
    5. mirroring (bbox must be adjusted)
    6. rotation (bbox must be adjusted)
    7. occlusion / random erasing

The tricky part: implementing brightness, noise, cropping, translation, mirroring, and erasing was fairly straightforward, but adjusting the labels for rotation turned out to be troublesome. It was eventually solved by improving the bbox-correction step, as shown below.

Initial approach:

# X_MIN = min(X1, X2, X3, X4)
# X_MAX = max(X1, X2, X3, X4)
# Y_MIN = min(Y1, Y2, Y3, Y4)
# Y_MAX = max(Y1, Y2, Y3, Y4)

Improved approach:

NEW_X1 = (X1 + X3) / 2
NEW_X2 = (X2 + X3) / 2
NEW_X3 = (X2 + X4) / 2
NEW_X4 = (X1 + X4) / 2
NEW_Y1 = (Y1 + Y3) / 2
NEW_Y2 = (Y2 + Y3) / 2
NEW_Y3 = (Y2 + Y4) / 2
NEW_Y4 = (Y1 + Y4) / 2
X_MIN = min(NEW_X1, NEW_X2, NEW_X3, NEW_X4)
X_MAX = max(NEW_X1, NEW_X2, NEW_X3, NEW_X4)
Y_MIN = min(NEW_Y1, NEW_Y2, NEW_Y3, NEW_Y4)
Y_MAX = max(NEW_Y1, NEW_Y2, NEW_Y3, NEW_Y4)

Result: after rotation, the adjusted labels fit the detection targets well.
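To make the idea concrete, here is a minimal, self-contained sketch of the corner-rotation plus midpoint-averaging trick, separate from the full class below. The function name rotate_bbox_midpoint and the example numbers are illustrative only; the sign convention matches the rotate_img_bbox method in the complete code.

import math

def rotate_bbox_midpoint(bbox, angle_deg, center):
    """Rotate an axis-aligned bbox and return a tightened axis-aligned bbox.

    bbox:      [x_min, y_min, x_max, y_max]
    angle_deg: rotation angle in degrees
    center:    (cx, cy) rotation center, usually the image center
    """
    a, b = center
    angle = -angle_deg * math.pi / 180.0  # same sign convention as the full code
    x1, y1, x2, y2 = bbox
    # the four corners of the original box, in the order used by the full code
    corners = [(x1, y1), (x2, y2), (x1, y2), (x2, y1)]
    rotated = []
    for (x, y) in corners:
        X = (x - a) * math.cos(angle) - (y - b) * math.sin(angle) + a
        Y = (x - a) * math.sin(angle) + (y - b) * math.cos(angle) + b
        rotated.append((X, Y))
    (X1, Y1), (X2, Y2), (X3, Y3), (X4, Y4) = rotated
    # midpoints of the rotated edges: these lie inside the hull of the corners,
    # so the resulting box is tighter than the raw min/max over the corners
    xs = [(X1 + X3) / 2, (X2 + X3) / 2, (X2 + X4) / 2, (X1 + X4) / 2]
    ys = [(Y1 + Y3) / 2, (Y2 + Y3) / 2, (Y2 + Y4) / 2, (Y1 + Y4) / 2]
    return [min(xs), min(ys), max(xs), max(ys)]

# example: a 100x60 box rotated 30 degrees around (200, 150)
print(rotate_bbox_midpoint([150, 120, 250, 180], 30, (200, 150)))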

The complete code is below:

# -*- coding=utf-8 -*-
# -----------------------------------------------------------------
# Description:
#     data augmentation for object detection
# Augmentation methods:
#     (1) pixel-level augmentation
#         1. brightness changes
#         2. noise injection
#     (2) image-level augmentation
#         3. cropping (bbox must be adjusted)
#         4. translation (bbox must be adjusted)
#         5. mirroring (bbox must be adjusted)
#         6. rotation (bbox must be adjusted)
#         7. occlusion / random erasing
# Note:
#     random.seed(): the same seed always produces the same random sequence!
import time
import random
import cv2
import os
import copy
import math
import numpy as np
from PIL import Image
from skimage.util import random_noise
from skimage import exposure
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import ElementTree, Element
import torch


def show_pic(img, bboxes=None):
    '''
    Inputs:
        img:    image array
        bboxes: all bounding boxes of the image, a list of [x_min, y_min, x_max, y_max]
    '''
    cv2.imwrite('./1.jpg', img)
    img = cv2.imread('./1.jpg')
    for i in range(len(bboxes)):
        bbox = bboxes[i]
        x_min = bbox[0]
        y_min = bbox[1]
        x_max = bbox[2]
        y_max = bbox[3]
        cv2.rectangle(img, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), thickness=3)
    cv2.namedWindow('pic', 0)           # 0 allows the preview window to be resized
    cv2.moveWindow('pic', 0, 0)         # move the window to the top-left corner of the screen
    cv2.resizeWindow('pic', 1200, 800)  # size of the preview window
    cv2.imshow('pic', img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


# All images are handled as cv2-style numpy arrays.
class DataAugmentForObjectDetection():
    def __init__(self, rotation_rate=0.5, max_rotation_angle=5,
                 crop_rate=0.5, shift_rate=0.5, change_light_rate=0.5,
                 add_noise_rate=0.5, flip_rate=0.5,
                 erase_rate=0.5, erase_length=10, erase_holes=1, erase_threshold=0.5):
        self.rotation_rate = rotation_rate
        self.max_rotation_angle = max_rotation_angle
        self.crop_rate = crop_rate
        self.shift_rate = shift_rate
        self.change_light_rate = change_light_rate
        self.add_noise_rate = add_noise_rate
        self.flip_rate = flip_rate
        self.erase_rate = erase_rate
        self.erase_length = erase_length
        self.erase_holes = erase_holes
        self.erase_threshold = erase_threshold

    # ---------- noise ----------
    def sp_noise(self, image, prob):
        '''
        Add salt-and-pepper noise.
        prob: proportion of noisy pixels
        '''
        output = np.zeros(image.shape, np.uint8)
        thres = 1 - prob
        for i in range(image.shape[0]):
            for j in range(image.shape[1]):
                rdn = random.random()
                if rdn < prob:
                    output[i][j] = 0
                elif rdn > thres:
                    output[i][j] = 255
                else:
                    output[i][j] = image[i][j]
        return output

    def gasuss_noise(self, image, mean=0, var=0.00001):
        '''
        Add Gaussian noise.
        mean: mean of the noise
        var:  variance of the noise
        '''
        image = np.array(image / 255, dtype=float)
        noise = np.random.normal(mean, var ** 0.5, image.shape)
        out = image + noise
        if out.min() < 0:
            low_clip = -1.
        else:
            low_clip = 0.
        out = np.clip(out, low_clip, 1.0)
        out = np.uint8(out * 255)
        # cv2.imshow("gasuss", out)
        return out

    def addNoise(self, img):
        '''
        Input:  image array
        Output: noisy image array; random_noise returns values in [0, 1], so multiply by 255
        '''
        # random.seed(int(time.time()))
        # return random_noise(img, mode='gaussian', seed=int(time.time()), clip=True) * 255
        img = (random_noise(img, mode='gaussian', clip=True, mean=0, var=0.00001) * 255).astype(np.uint8)
        return img

    # ---------- brightness ----------
    def changeLight(self, img):
        # random.seed(int(time.time()))
        flag = random.uniform(0.5, 1.5)  # gamma > 1 darkens the image, gamma < 1 brightens it
        img = np.abs(img)
        return exposure.adjust_gamma(img, flag)

    # ---------- crop ----------
    def crop_img_bboxes(self, img, bboxes):
        '''
        Crop the image; the crop must still contain every box.
        Inputs:
            img:    image array
            bboxes: list of [x_min, y_min, x_max, y_max, label] with numeric coordinates
        Outputs:
            crop_img:    cropped image array
            crop_bboxes: bbox coordinates after cropping
        '''
        # ---------------------- crop the image ----------------------
        width = img.shape[1]
        height = img.shape[0]
        x_min = width
        x_max = 0
        y_min = height
        y_max = 0
        # smallest rectangle that contains every box
        for bbox in bboxes:
            x_min = min(x_min, bbox[0])
            y_min = min(y_min, bbox[1])
            x_max = max(x_max, bbox[2])
            y_max = max(y_max, bbox[3])
        d_to_left = x_min             # distance from that rectangle to the left border
        d_to_right = width - x_max    # distance to the right border
        d_to_top = y_min              # distance to the top border
        d_to_bottom = height - y_max  # distance to the bottom border
        # randomly expand that rectangle
        crop_x_min = int(x_min - random.uniform(0, d_to_left))
        crop_y_min = int(y_min - random.uniform(0, d_to_top))
        crop_x_max = int(x_max + random.uniform(0, d_to_right))
        crop_y_max = int(y_max + random.uniform(0, d_to_bottom))
        # alternative: expand at least half-way so the crop is not too tight
        # crop_x_min = int(x_min - random.uniform(d_to_left // 2, d_to_left))
        # crop_y_min = int(y_min - random.uniform(d_to_top // 2, d_to_top))
        # crop_x_max = int(x_max + random.uniform(d_to_right // 2, d_to_right))
        # crop_y_max = int(y_max + random.uniform(d_to_bottom // 2, d_to_bottom))
        # stay inside the image
        crop_x_min = max(0, crop_x_min)
        crop_y_min = max(0, crop_y_min)
        crop_x_max = min(width, crop_x_max)
        crop_y_max = min(height, crop_y_max)
        crop_img = img[crop_y_min:crop_y_max, crop_x_min:crop_x_max]
        # ---------------------- crop the bboxes ----------------------
        crop_bboxes = list()
        for bbox in bboxes:
            crop_bboxes.append([bbox[0] - crop_x_min, bbox[1] - crop_y_min,
                                bbox[2] - crop_x_min, bbox[3] - crop_y_min, bbox[4]])
        return crop_img, crop_bboxes

    # ---------- translation ----------
    def shift_pic_bboxes(self, img, bboxes):
        '''
        Translate the image; the result must still contain every box.
        Inputs:
            img:    image array
            bboxes: list of [x_min, y_min, x_max, y_max, label]
        Outputs:
            shift_img:    translated image array
            shift_bboxes: bbox coordinates after translation
        '''
        # ---------------------- translate the image ----------------------
        width = img.shape[1]
        height = img.shape[0]
        x_min = width
        x_max = 0
        y_min = height
        y_max = 0
        for bbox in bboxes:
            x_min = min(x_min, bbox[0])
            y_min = min(y_min, bbox[1])
            x_max = max(x_max, bbox[2])
            y_max = max(y_max, bbox[3])
        # x_min/y_min are the smallest and x_max/y_max the largest coordinates over all boxes
        d_to_left = x_min             # maximum shift to the left
        d_to_right = width - x_max    # maximum shift to the right
        d_to_top = y_min              # maximum shift upwards
        d_to_bottom = height - y_max  # maximum shift downwards
        x = random.uniform(-(d_to_left - 1) / 3, (d_to_right - 1) / 3)
        y = random.uniform(-(d_to_top - 1) / 3, (d_to_bottom - 1) / 3)
        # x: horizontal shift in pixels (positive = right, negative = left)
        # y: vertical shift in pixels (positive = down, negative = up)
        M = np.float32([[1, 0, x], [0, 1, y]])
        shift_img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
        # ---------------------- translate the bboxes ----------------------
        shift_bboxes = list()
        for bbox in bboxes:
            shift_bboxes.append([bbox[0] + x, bbox[1] + y, bbox[2] + x, bbox[3] + y, bbox[4]])
        return shift_img, shift_bboxes

    # ---------- mirroring ----------
    def filp_pic_bboxes(self, img, bboxes):
        '''
        Inputs:
            img:    image array
            bboxes: list of [x_min, y_min, x_max, y_max, label]
        Outputs:
            flip_img:    flipped image array
            flip_bboxes: bbox coordinates after flipping
        '''
        # ---------------------- flip the image ----------------------
        flip_img = copy.deepcopy(img)
        if random.random() < 0.5:  # 50% horizontal flip, 50% vertical flip
            horizon = True
        else:
            horizon = False
        height, width, _ = img.shape
        if horizon:
            flip_img = cv2.flip(flip_img, 1)  # 1 = horizontal, 0 = vertical, -1 = both
        else:
            flip_img = cv2.flip(flip_img, 0)
        # ---------------------- adjust the bboxes ----------------------
        flip_bboxes = list()
        for box in bboxes:
            x_min = box[0]
            y_min = box[1]
            x_max = box[2]
            y_max = box[3]
            if horizon:
                flip_bboxes.append([width - x_max, y_min, width - x_min, y_max, box[4]])
            else:
                flip_bboxes.append([x_min, height - y_max, x_max, height - y_min, box[4]])
        return flip_img, flip_bboxes

    # ---------- rotation ----------
    def rotate_img_bbox(self, img, bboxes, piangle=5, scale=1.):
        '''
        Rotate the image and recompute the bboxes. Note that corners of the image
        may be cut off because the canvas size is kept unchanged.
        Inputs:
            img:     image array, (h, w, c)
            bboxes:  list of [x_min, y_min, x_max, y_max, label]
            piangle: rotation angle in degrees
            scale:   accepted but not applied (the rotation matrix uses scale 1)
        Outputs:
            rotated_img: rotated image array
            rot_bboxes:  bbox coordinates after rotation
        '''
        # ---------------------- rotate the image ----------------------
        angle = -piangle * math.pi / 180.0
        rows, cols = img.shape[:2]
        a, b = cols / 2, rows / 2
        M = cv2.getRotationMatrix2D((a, b), piangle, 1)
        # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        rotated_img = cv2.warpAffine(img, M, (cols, rows))  # keep the original canvas size
        # ---------------------- correct the bbox coordinates ----------------------
        # Rotate the four corners of each bbox into the new coordinate system, then
        # tighten the axis-aligned box using the midpoints of the rotated edges.
        rot_bboxes = list()
        for bbox in bboxes:
            x1 = float(bbox[0]) - 1
            y1 = float(bbox[1]) - 1
            x2 = float(bbox[2]) - 1
            y2 = float(bbox[3]) - 1
            x3 = x1
            y3 = y2
            x4 = x2
            y4 = y1
            X1 = (x1 - a) * math.cos(angle) - (y1 - b) * math.sin(angle) + a
            Y1 = (x1 - a) * math.sin(angle) + (y1 - b) * math.cos(angle) + b
            X2 = (x2 - a) * math.cos(angle) - (y2 - b) * math.sin(angle) + a
            Y2 = (x2 - a) * math.sin(angle) + (y2 - b) * math.cos(angle) + b
            X3 = (x3 - a) * math.cos(angle) - (y3 - b) * math.sin(angle) + a
            Y3 = (x3 - a) * math.sin(angle) + (y3 - b) * math.cos(angle) + b
            X4 = (x4 - a) * math.cos(angle) - (y4 - b) * math.sin(angle) + a
            Y4 = (x4 - a) * math.sin(angle) + (y4 - b) * math.cos(angle) + b
            NEW_X1 = (X1 + X3) / 2
            NEW_X2 = (X2 + X3) / 2
            NEW_X3 = (X2 + X4) / 2
            NEW_X4 = (X1 + X4) / 2
            NEW_Y1 = (Y1 + Y3) / 2
            NEW_Y2 = (Y2 + Y3) / 2
            NEW_Y3 = (Y2 + Y4) / 2
            NEW_Y4 = (Y1 + Y4) / 2
            # original (too loose) version:
            # X_MIN = min(X1, X2, X3, X4)
            # X_MAX = max(X1, X2, X3, X4)
            # Y_MIN = min(Y1, Y2, Y3, Y4)
            # Y_MAX = max(Y1, Y2, Y3, Y4)
            X_MIN = min(NEW_X1, NEW_X2, NEW_X3, NEW_X4)
            X_MAX = max(NEW_X1, NEW_X2, NEW_X3, NEW_X4)
            Y_MIN = min(NEW_Y1, NEW_Y2, NEW_Y3, NEW_Y4)
            Y_MAX = max(NEW_Y1, NEW_Y2, NEW_Y3, NEW_Y4)
            rot_bboxes.append([X_MIN, Y_MIN, X_MAX, Y_MAX, bbox[4]])
        return rotated_img, rot_bboxes

    # ---------- occlusion / random erasing ----------
    def erase(self, img, bboxes, length=100, n_holes=1, threshold=0.5):
        '''
        Randomly mask out one or more square patches from an image.
        Args:
            img:           a 3D numpy array, (h, w, c)
            bboxes:        box coordinates
            length (int):  side length (in pixels) of each square patch
            n_holes (int): number of patches to cut out of the image
            threshold:     maximum allowed overlap between a patch and any box
        '''
        def cal_iou(boxA, boxB):
            '''Return the overlap of boxA with boxB, normalized by the area of boxB.'''
            # coordinates of the intersection rectangle
            xA = max(boxA[0], boxB[0])
            yA = max(boxA[1], boxB[1])
            xB = min(boxA[2], boxB[2])
            yB = min(boxA[3], boxB[3])
            if xB <= xA or yB <= yA:
                return 0.0
            # intersection area
            interArea = (xB - xA + 1) * (yB - yA + 1)
            # areas of both boxes
            boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
            boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
            # iou = interArea / float(boxAArea + boxBArea - interArea)
            iou = interArea / float(boxBArea)
            return iou

        # get h and w
        if img.ndim == 3:
            h, w, c = img.shape
        else:
            _, h, w, c = img.shape
        mask = np.ones((h, w, c), np.float32)
        for n in range(n_holes):
            overlap = True  # does the patch overlap any box too much?
            while overlap:
                y = np.random.randint(h)
                x = np.random.randint(w)
                # np.clip(a, a_min, a_max) limits values to the range [a_min, a_max]
                x1 = np.clip(x - length // 2, 0, w)
                x2 = np.clip(x + length // 2, 0, w)
                y1 = np.clip(y - length // 2, 0, h)
                y2 = np.clip(y + length // 2, 0, h)
                overlap = False
                for box in bboxes:
                    if cal_iou([x1, y1, x2, y2], box) > threshold:
                        overlap = True
                        break
            mask[y1:y2, x1:x2, :] = 0.
        # mask = np.expand_dims(mask, axis=0)
        erase_img = (img * mask).astype(img.dtype)  # cast back to the input dtype
        return erase_img, bboxes

    def dataAugment(self, img, bboxes):
        '''
        Apply a random combination of augmentations.
        Inputs:
            img:    image array
            bboxes: all box coordinates of the image
        Outputs:
            img:    augmented image
            bboxes: boxes matching the augmented image
        '''
        # start from a negative counter so several augmentation passes are applied
        change_num = random.sample(range(-6, 0), 1)[0]
        print('------------------ start data augmentation -------------------')
        while change_num < 1:  # at least one augmentation must take effect
            if random.random() < self.add_noise_rate:  # noise
                print('noise')
                change_num += 1
                img = self.gasuss_noise(img)
                print(img.shape)
            if random.random() > self.change_light_rate:  # brightness
                print('brightness')
                change_num += 1
                img = self.changeLight(img)
                print(img.shape)
            if random.random() < self.shift_rate:  # translation
                print('shift')
                change_num += 1
                img, bboxes = self.shift_pic_bboxes(img, bboxes)
                print(img.shape)
            if random.random() < self.flip_rate:  # mirroring
                print('flip')
                change_num += 1
                img, bboxes = self.filp_pic_bboxes(img, bboxes)
                print(img.shape)
            if random.random() > self.rotation_rate:  # rotation
                print('rotate')
                change_num += 1
                # angle = random.uniform(-self.max_rotation_angle, self.max_rotation_angle)
                angle = random.sample(range(-180, 180), 1)[0]
                # angle = -45
                print(angle)
                scale = random.uniform(0.7, 0.8)
                img, bboxes = self.rotate_img_bbox(img, bboxes, angle, scale)
            if random.random() < self.erase_rate:  # occlusion / erasing
                print('erase')
                change_num += 1
                erase_length = random.sample(range(0, 50), 1)[0]
                erase_holes = random.sample(range(0, 5), 1)[0]
                img, bboxes = self.erase(img, bboxes, length=erase_length, n_holes=erase_holes,
                                         threshold=self.erase_threshold)
                print(img.shape)
            print('\n')
        print('------------------ data augmentation finished -------------------')
        print(bboxes)
        return img, bboxes


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to normalized [x, y, w, h],
    # where xy1 = top-left and xy2 = bottom-right. The image size is hard-coded to 608x608.
    dw = 1. / 608
    dh = 1. / 608
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    print(x.shape)
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2 - 1) * dw  # x center
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2 - 1) * dh  # y center
    y[:, 2] = (x[:, 2] - x[:, 0]) * dw  # width
    y[:, 3] = (x[:, 3] - x[:, 1]) * dh  # height
    return y


if __name__ == '__main__':
    import shutil
    from DOC import *

    need_aug_num = 10
    dataAug = DataAugmentForObjectDetection()
    source_pic_root_path = r'C:\Users\JKY\Desktop\***'  # your image folder
    source_xml_root_path = r'C:\Users\JKY\Desktop\***'  # your xml folder
    save_images = r'C:\Users\JKY\Desktop\DataAugmentation_ForObjectDetect-master\dataAugmentation-images'  # image save path
    save_labels = r'C:\Users\JKY\Desktop\DataAugmentation_ForObjectDetect-master\dataAugmentation-label'   # label save path
    for parent, _, files in os.walk(source_pic_root_path):
        for file in files:
            cnt = 0
            while cnt < need_aug_num:
                try:
                    pic_path = os.path.join(parent, file)
                    xml_path = os.path.join(source_xml_root_path, file[:-4] + '.xml')
                    if not os.path.exists(xml_path):
                        cnt += 1
                        continue
                    coords = parse_xml(xml_path)  # box info, format [[x_min, y_min, x_max, y_max, name]]
                    # coords = [coord[:4] for coord in coords]
                    print(pic_path)
                    # img = cv2.imread(pic_path)
                    img = Image.open(pic_path)
                    img = np.array(img)
                    # show_pic(img, coords)  # original image
                    auged_img, auged_bboxes = dataAug.dataAugment(img, coords)
                    img = cv2.cvtColor(auged_img, cv2.COLOR_RGB2BGR)
                    cnt += 1
                    name = str(int(time.time() * 1e5))
                    cv2.imwrite(os.path.join(save_images, name + '.jpg'), img)
                    print('image save success')
                    txt_path = os.path.join(save_labels, name + '.txt')
                    auged_bboxes = np.float32(np.array(auged_bboxes))
                    xywh = xyxy2xywh(auged_bboxes).astype(str).tolist()
                    # build YOLO-style lines: "class_id x_center y_center width height"
                    res_bboxes = []
                    for i in range(len(xywh)):
                        index = xywh[i][-1].index('.')
                        xywhs = [xywh[i][-1][:index]] + xywh[i][:4]  # strip the decimals from the class id
                        print(xywhs)
                        res_bboxes.append(xywhs)
                    res = []
                    print(res_bboxes)
                    for listi in res_bboxes:
                        stri = ' '.join(listi)
                        stri += '\n'
                        res.append(stri)
                    print(res_bboxes)
                    with open(txt_path, 'w') as f:  # do not shadow the loop variable `file`
                        f.writelines(res)
                    # show_pic(auged_img, auged_bboxes)  # augmented image
                except:
                    cnt += 1
                    continue
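One thing the script relies on but does not show is parse_xml, imported from a DOC module that is not included in the post. Below is a minimal sketch of what such a helper could look like for Pascal VOC-style annotation files, under the assumption that the <name> field holds a numeric class id (which is what the label-writing code above expects). This is an illustrative stand-in, not the author's original helper.

import xml.etree.ElementTree as ET

def parse_xml(xml_path):
    '''Return [[x_min, y_min, x_max, y_max, name], ...] for every <object> in the file.'''
    tree = ET.parse(xml_path)
    root = tree.getroot()
    coords = []
    for obj in root.iter('object'):
        name = obj.find('name').text          # assumed to be a numeric class id
        box = obj.find('bndbox')
        x_min = float(box.find('xmin').text)
        y_min = float(box.find('ymin').text)
        x_max = float(box.find('xmax').text)
        y_max = float(box.find('ymax').text)
        coords.append([x_min, y_min, x_max, y_max, name])
    return coords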

Reposted from: http://huagj.baihongyu.com/
