def cross_entropy_loss(
inputs: torch.Tensor,
targets: torch.Tensor,
num_masks: float,
):
"""
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
Returns:
Loss tensor
"""
loss = F.cross_entropy(inputs, targets, reduction="none")
loss = loss.mean(dim=(1, 2)).sum() / (num_masks + 1e-8)
return loss
def dice_loss_multi_class(
inputs: torch.Tensor,
targets: torch.Tensor,
num_masks: float,
num_classes: int,
scale=1000, # 100000.0,
eps=1e-6,
):
"""
Compute the DICE loss for multi-class masks.
Args:
inputs: A float tensor of shape (batch_size, num_classes, height, width).
The predictions for each example.
targets: A long tensor of shape (batch_size, height, width). Stores the class
labels for each element in inputs.
num_classes: The number of classes.
"""
inputs = inputs.softmax(dim=1)
loss = 0
print("inputs.shape",inputs.shape)
for cls in range(num_classes):
input_cls = inputs[:, cls, ...]
target_cls = (targets == cls).float()
numerator = 2 * (input_cls * target_cls).sum()
denominator = (input_cls + target_cls).sum()
loss += 1 - (numerator + eps) / (denominator + eps)
loss = loss / num_classes
loss = loss.sum() / (num_masks + 1e-8)
return loss
sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True)
# low_res_masks = torch.take_along_dim(low_res_masks, sorted_ids[..., None, None], dim=1)[:, :1]
low_res_masks = torch.take_along_dim(low_res_masks, sorted_ids[..., None, None], dim=1)
pred_mask = self.postprocess_masks(
low_res_masks,
orig_hw=low_res_masks.shape,
)
pred_masks.append(pred_mask[:, 0])
gt_masks = masks_list
gt_mask_cpu = gt_masks.cpu()
num_classes = np.unique(gt_mask_cpu.numpy())
print(gt_masks.shape)
print(pred_mask.shape)
# pred_mask = pred_mask.unsqueeze(1)
print("num_classes",len(num_classes))
if inference:
return {
"pred_masks": pred_masks,
"gt_masks": gt_masks,
}
mask_bce_loss = 0
mask_dice_loss = 0
num_masks = 0
for batch_idx in range(1):
gt_mask = gt_masks[:, :, :, :]
pred_mask = pred_mask
assert (
gt_mask.shape[0] == pred_mask.shape[0]
), "gt_mask.shape: {}, pred_mask.shape: {}".format(
gt_mask.shape, pred_mask.shape
)
mask_bce_loss += (
cross_entropy_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
* gt_mask.shape[0]
)
# print(self.config.num_classes)
# exit(0)
mask_dice_loss += (
dice_loss_multi_class(pred_mask, gt_mask, num_masks=gt_mask.shape[0],num_classes=len(num_classes))
* gt_mask.shape[0]
)
num_masks += gt_mask.shape[0]
mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (1 + 1e-8)
mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (1 + 1e-8)
mask_loss = mask_bce_loss + mask_dice_loss
loss = mask_loss
###下面是我的dataloader代码
class CustomDataset(Dataset):
def __init__(self, root_dir, transform=None, mask_transform=None, text_transform=None):
"""
Args:
root_dir (str): 数据集的根目录路径。
transform (callable, optional): 应用于图像的变换。
mask_transform (callable, optional): 应用于掩码的变换。
text_transform (callable, optional): 应用于文本的变换。
"""
self.root_dir = root_dir
self.images_dir = os.path.join(root_dir, 'images')
self.masks_dir = os.path.join(root_dir, 'masks')
self.texts_dir = os.path.join(root_dir, 'texts')
self.transform = transform
self.mask_transform = mask_transform
self.text_transform = text_transform
# 获取所有图像文件的名称(不包括扩展名)
self.image_files = sorted([
os.path.splitext(f)[0] for f in os.listdir(self.images_dir)
if os.path.isfile(os.path.join(self.images_dir, f)) and f.lower().endswith('.jpg')
])
# 确保每个图像都有对应的掩码和文本
self.valid_files = []
for base_name in self.image_files:
mask_file = f"{base_name}.png" # 掩码文件后缀为 .png
text_file = f"{base_name}.txt" # 文本文件后缀为 .txt
mask_path = os.path.join(self.masks_dir, mask_file)
text_path = os.path.join(self.texts_dir, text_file)
if os.path.exists(mask_path) and os.path.exists(text_path):
self.valid_files.append(base_name)
else:
missing = []
if not os.path.exists(mask_path):
missing.append('mask')
if not os.path.exists(text_path):
missing.append('text')
missing_str = ' and '.join(missing)
print(f"Warning: Missing {missing_str} for {base_name}")
def __len__(self):
return len(self.valid_files)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
base_name = self.valid_files[idx]
# 加载图像
img_path = os.path.join(self.images_dir, base_name + '.jpg') # 图像后缀为 .jpg
image = Image.open(img_path).convert('RGB')
image_tensor = transforms.ToTensor()(image) # 将 PIL.Image 转换为 PyTorch 张量
if torch.isnan(image_tensor).any():
print(f"Warning: NaN found in image {base_name}")
print(111111111111111111111111111111)
exit(0)
if self.transform:
image = self.transform(image)
# 加载掩码
batch_masks_np = []
mask_path = os.path.join(self.masks_dir, base_name + '.png') # 掩码后缀为 .png
mask = Image.open(mask_path).convert('L') # 假设掩码是单通道
print("mask.shape",mask.size)
num_class = np.unique(mask)
mask = torch.float32(0,len(num_class),mask.size)
mask_one_hot = torch.nn.functional.one_hot(mask, num_classes=num_class) # 形状为 (height, width, num_classes)
# 调整维度顺序为 (batch_size, num_classes, height, width)
# mask_one_hot = mask_one_hot.permute(2, 0, 1).unsqueeze(0).float() # 添加 batch_size 维度,并转换为 float 类型
# batch_masks_np.append(mask)
# one_hot_masks_np = np.eye(num_class)[batch_masks_np]
# one_hot_masks_np = one_hot_masks_np.transpose(0, 3, 1, 2)
# one_hot_masks_tensor = torch.tensor(one_hot_masks_np, dtype=torch.float32)
# mask = one_hot_masks_tensor
# exit(0)
mask_tensor = transforms.ToTensor()(mask) # 将 PIL.Image 转换为 PyTorch 张量
if torch.isnan(mask_tensor).any():
print(f"Warning: NaN found in mask {base_name}")
print(222222222222222222222222222222)
exit(0)
if self.mask_transform:
print(88888888888888888888)
mask = self.mask_transform(mask)
else:
mask = transforms.ToTensor()(mask)
# 加载文本
text_path = os.path.join(self.texts_dir, base_name + '.txt')
with open(text_path, 'r', encoding='utf-8') as f:
text = f.read()
if not text:
print(f"Warning: Empty text for {base_name}")
print(3333333333333333333333333333333)
exit(0)
if self.text_transform:
text = self.text_transform(text)
sample = {
'image': image,
'mask': mask,
'text': text,
'name': base_name
}
# retun image,mask,text
return sample
问题:我的掩码图有多类最多4类(值为0,1,2,3,0为背景),目前代码是只能支持单类分割,如何对CustomDataset进行修改,以适应多类别的分割任务,例如batchsize为4,其中一张图4类都有,另一张可能只包含2类等情况,如何对gt_mask和pred_mask进行处理并进行损失计算