YOLOv5-7.0版本的算法中,数据增强方式采取了Mosaic方法,并且在Mosaic被采用的时候才可以调用mixup数据增强方法。那怎样可以用FMix算法替换mixup,以便使YOLOv5算法的数据增强效果更强大?
6条回答 默认 最新
- 「已注销」 2023-03-08 21:38关注
“Devil组”引证GPT后的撰写:
import os
import glob

import cv2
import numpy as np
import torch
import albumentations as A  # NOTE(review): imported by the original answer but never used below


def fmix(x, y, alpha=1.0, decay_power=3.0, shape=(416, 416), max_soft=0.3, reformulate=False):
    """Blend a batch with a shuffled copy of itself, FMix-style.

    Args:
        x: batch of images, shape (batch, ...).
        y: batch of labels; when ``reformulate`` is True this is assumed to be
           (batch, n, 5) with xywh in columns 0:4 — TODO confirm with caller.
        alpha: Beta-distribution parameter for sampling the mixing coefficient.
        decay_power: kept for interface compatibility; not used by this
            simplified mask sampler (real FMix uses it for the frequency decay).
        shape: (H, W) of the spatial mask.
        max_soft: upper bound used when rescaling the soft mask.
        reformulate: if True, also blend the bounding-box coordinates using the
            spatial mask via ``bbox_mashup``.

    Returns:
        (mixed images, mixed labels, lam).
    """
    lam, mask = sample_mask(alpha, decay_power, shape, max_soft)
    # Pair each sample with a random other sample from the same batch.
    # Bug fix: was hard-coded .cuda(), which crashed on CPU-only runs;
    # generating the permutation on x's device works everywhere.
    index = torch.randperm(x.size(0), device=x.device)
    xf = lam * x + (1 - lam) * x[index, :]
    yf = lam * y + (1 - lam) * y[index, :]
    if reformulate:
        # NOTE(review): mixes box coords spatially; requires mask to broadcast
        # against the (n, 4) coordinate slices — confirm shapes before relying on it.
        yf[:, 0:4] = bbox_mashup(yf[:, 0:4], y[index, :, 0:4], mask)
    return xf, yf, lam


def sample_mask(alpha, decay_power, shape, max_soft):
    """Sample a mixing coefficient ``lam`` and a soft (H, W) spatial mask.

    A random lam-fraction of pixels is switched on, the binary mask is
    Gaussian-blurred, and the result is rescaled into [0.1, max_soft].
    ``decay_power`` is accepted for interface compatibility but unused.

    Returns:
        (lam, soft mask as a float ndarray of shape ``shape``).
    """
    H, W = shape
    lam = np.random.beta(alpha, alpha)
    mask = np.zeros((H, W))
    # Switch on int(H*W*lam) distinct pixel positions.
    indices = np.random.choice(np.arange(H * W), int(H * W * lam), replace=False)
    mask[np.unravel_index(indices, (H, W))] = 1
    frequency = cv2.GaussianBlur(mask, (21, 21), 0)
    # Bug fix: when lam is close to 0 (or 1) the blurred mask is constant and
    # max - min == 0, so the original normalization divided by zero.
    span = frequency.max() - frequency.min()
    if span > 0:
        frequency = (frequency - frequency.min()) / span
    else:
        frequency = np.zeros_like(frequency)
    frequency = (max_soft - 0.1) * frequency + 0.1
    return lam, frequency


def bbox_mashup(src_bbox, dst_bbox, mask):
    """Blend two sets of xywh boxes element-wise with ``mask`` as the weight.

    NOTE: mutates and returns ``src_bbox`` in place.
    """
    src_bbox[:, :2] = mask * src_bbox[:, :2] + (1 - mask) * dst_bbox[:, :2]
    src_bbox[:, 2:4] = mask * src_bbox[:, 2:4] + (1 - mask) * dst_bbox[:, 2:4]
    return src_bbox


class YOLOv5Dataset(torch.utils.data.Dataset):
    """Detection dataset with optional mosaic / mixup / fmix augmentation flags.

    ``data`` items may be: a directory path (all *.jpg inside are used, with no
    label file), a single image path, or an (image_path, label_path) pair.
    Labels use the YOLO text format: ``class_id cx cy w h`` per line.
    """

    def __init__(self, data, img_size=416, transform=None, mosaic=False, mixup=False, fmix=False):
        self.img_files = []
        self.label_files = []
        self.img_size = img_size
        self.transform = transform
        self.mosaic = mosaic
        self.mixup = mixup
        self.fmix = fmix
        for d in data:
            if isinstance(d, str):
                if os.path.isdir(d):
                    # Directory entries contribute images only; label_files can
                    # therefore be shorter than img_files (see _get_item guard).
                    self.img_files += glob.glob(os.path.join(d, '*.jpg'))
                else:
                    self.img_files.append(d)
            else:
                self.img_files.append(d[0])
                self.label_files.append(d[1])

    def __len__(self):
        # Bug fix: torch DataLoader requires __len__ on map-style datasets;
        # the original class omitted it.
        return len(self.img_files)

    def __getitem__(self, index):
        # NOTE(review): _mixup, _mosaic and _fmix are referenced but were never
        # defined in the original answer — they must be implemented before
        # these flags are enabled.
        if self.mixup:
            img, label = self._mixup(index)
        elif self.mosaic:
            img, label = self._mosaic(index)
        elif self.fmix:
            img, label = self._fmix(index)
        else:
            img, label = self._get_item(index)
        if self.transform:
            img, label = self.transform(img, label)
        return img, label

    def _get_item(self, index):
        """Load one RGB image and its label array.

        Returns:
            (img, label) where label is an (n, 5) ndarray of
            [cx, cy, w, h, class_id]; a single [-1]*5 row when no labels exist.
        """
        img_path = self.img_files[index]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w, _ = img.shape
        # Bug fix: label_files may be shorter than img_files (directory-globbed
        # images have no label entry); the original raised IndexError here.
        label_path = self.label_files[index] if index < len(self.label_files) else ''
        label = []
        if label_path and os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    line = line.strip()
                    if line:
                        # Bug fix: original unpacked into x, y, w, h, shadowing
                        # the image h/w read above; use distinct names.
                        class_id, cx, cy, bw, bh = line.split()
                        label.append([float(cx), float(cy), float(bw), float(bh), int(class_id)])
        if len(label) == 0:
            # Sentinel row so downstream code always sees an (n, 5) array.
            label.append([-1, -1, -1, -1, -1])
        return img, np.array(label)


# Removed the original trailing fragment ``if fmix: x, y, lam = fmix(x, y)``:
# it referenced undefined module-level names x/y and shadowed the fmix function,
# so it raised NameError on import. Call fmix() from the training loop instead.
本回答被题主选为最佳回答 , 对您是否有帮助呢?解决 无用评论 打赏 举报
悬赏问题
- ¥20 mysql架构,按照姓名分表
- ¥15 MATLAB实现区间[a,b]上的Gauss-Legendre积分
- ¥15 Macbookpro 连接热点正常上网,连接不了Wi-Fi。
- ¥15 delphi webbrowser组件网页下拉菜单自动选择问题
- ¥15 linux驱动,linux应用,多线程
- ¥20 我要一个分身加定位两个功能的安卓app
- ¥15 基于FOC驱动器,如何实现卡丁车下坡无阻力的遛坡的效果
- ¥15 IAR程序莫名变量多重定义
- ¥15 (标签-UDP|关键词-client)
- ¥15 关于库卡officelite无法与虚拟机通讯的问题