图像算法工程师 2025-04-10 09:12 采纳率: 26.7%
浏览 17

多类分割损失(gt_mask和pred_mask如何对应)问题

def cross_entropy_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        num_masks: float,
):
    """
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
    Returns:
        Loss tensor
    """
    loss = F.cross_entropy(inputs, targets, reduction="none")
    loss = loss.mean(dim=(1, 2)).sum() / (num_masks + 1e-8)
    return loss

def dice_loss_multi_class(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    num_classes: int,
    scale=1000,  # 100000.0,
    eps=1e-6,
):
    """
    Compute the DICE loss for multi-class masks.
    Args:
        inputs: A float tensor of shape (batch_size, num_classes, height, width).
                The predictions for each example.
        targets: A long tensor of shape (batch_size, height, width). Stores the class
                 labels for each element in inputs.
        num_classes: The number of classes.
    """
    inputs = inputs.softmax(dim=1)
    loss = 0
    print("inputs.shape",inputs.shape)

    for cls in range(num_classes):
        input_cls = inputs[:, cls, ...]
        target_cls = (targets == cls).float()
        numerator = 2 * (input_cls * target_cls).sum()
        denominator = (input_cls + target_cls).sum()
        loss += 1 - (numerator + eps) / (denominator + eps)
    loss = loss / num_classes
    loss = loss.sum() / (num_masks + 1e-8)
    return loss

sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True)
# low_res_masks = torch.take_along_dim(low_res_masks, sorted_ids[..., None, None], dim=1)[:, :1]
low_res_masks = torch.take_along_dim(low_res_masks, sorted_ids[..., None, None], dim=1)

pred_mask = self.postprocess_masks(
                low_res_masks,
                orig_hw=low_res_masks.shape,
            )
            pred_masks.append(pred_mask[:, 0])
        gt_masks = masks_list
        gt_mask_cpu = gt_masks.cpu()
        num_classes = np.unique(gt_mask_cpu.numpy())
        print(gt_masks.shape)
        print(pred_mask.shape)
        # pred_mask = pred_mask.unsqueeze(1)
        print("num_classes",len(num_classes))
        if inference:
            return {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
            }

        mask_bce_loss = 0
        mask_dice_loss = 0
        num_masks = 0
        for batch_idx in range(1):
            gt_mask = gt_masks[:, :, :, :]
            pred_mask = pred_mask

            assert (
                    gt_mask.shape[0] == pred_mask.shape[0]
            ), "gt_mask.shape: {}, pred_mask.shape: {}".format(
                gt_mask.shape, pred_mask.shape
            )
            mask_bce_loss += (
                    cross_entropy_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
                    * gt_mask.shape[0]
            )
            # print(self.config.num_classes)
            # exit(0)
            mask_dice_loss += (
                    dice_loss_multi_class(pred_mask, gt_mask, num_masks=gt_mask.shape[0],num_classes=len(num_classes))
                    * gt_mask.shape[0]
            )
            num_masks += gt_mask.shape[0]

        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (1 + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (1 + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss
        loss = mask_loss
###下面是我的dataloader代码
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None, mask_transform=None, text_transform=None):
        """
        Args:
            root_dir (str): 数据集的根目录路径。
            transform (callable, optional): 应用于图像的变换。
            mask_transform (callable, optional): 应用于掩码的变换。
            text_transform (callable, optional): 应用于文本的变换。
        """
        self.root_dir = root_dir
        self.images_dir = os.path.join(root_dir, 'images')
        self.masks_dir = os.path.join(root_dir, 'masks')
        self.texts_dir = os.path.join(root_dir, 'texts')
        self.transform = transform
        self.mask_transform = mask_transform
        self.text_transform = text_transform

        # 获取所有图像文件的名称(不包括扩展名)
        self.image_files = sorted([
            os.path.splitext(f)[0] for f in os.listdir(self.images_dir)
            if os.path.isfile(os.path.join(self.images_dir, f)) and f.lower().endswith('.jpg')
        ])

        # 确保每个图像都有对应的掩码和文本
        self.valid_files = []
        for base_name in self.image_files:
            mask_file = f"{base_name}.png"  # 掩码文件后缀为 .png
            text_file = f"{base_name}.txt"  # 文本文件后缀为 .txt

            mask_path = os.path.join(self.masks_dir, mask_file)
            text_path = os.path.join(self.texts_dir, text_file)

            if os.path.exists(mask_path) and os.path.exists(text_path):
                self.valid_files.append(base_name)
            else:
                missing = []
                if not os.path.exists(mask_path):
                    missing.append('mask')
                if not os.path.exists(text_path):
                    missing.append('text')
                missing_str = ' and '.join(missing)
                print(f"Warning: Missing {missing_str} for {base_name}")

    def __len__(self):
        return len(self.valid_files)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        base_name = self.valid_files[idx]

        # 加载图像
        img_path = os.path.join(self.images_dir, base_name + '.jpg')  # 图像后缀为 .jpg
        image = Image.open(img_path).convert('RGB')
        image_tensor = transforms.ToTensor()(image)  # 将 PIL.Image 转换为 PyTorch 张量
        if torch.isnan(image_tensor).any():
            print(f"Warning: NaN found in image {base_name}")
            print(111111111111111111111111111111)
            exit(0)
        if self.transform:
            image = self.transform(image)

        # 加载掩码
        batch_masks_np = []
        mask_path = os.path.join(self.masks_dir, base_name + '.png')  # 掩码后缀为 .png
        mask = Image.open(mask_path).convert('L')  # 假设掩码是单通道
        print("mask.shape",mask.size)
        num_class = np.unique(mask)
        mask = torch.float32(0,len(num_class),mask.size)
        mask_one_hot = torch.nn.functional.one_hot(mask, num_classes=num_class)  # 形状为 (height, width, num_classes)
        # 调整维度顺序为 (batch_size, num_classes, height, width)
        # mask_one_hot = mask_one_hot.permute(2, 0, 1).unsqueeze(0).float()  # 添加 batch_size 维度,并转换为 float 类型
        # batch_masks_np.append(mask)
        # one_hot_masks_np = np.eye(num_class)[batch_masks_np]
        # one_hot_masks_np = one_hot_masks_np.transpose(0, 3, 1, 2)
        # one_hot_masks_tensor = torch.tensor(one_hot_masks_np, dtype=torch.float32)
        # mask = one_hot_masks_tensor
        # exit(0)
        mask_tensor = transforms.ToTensor()(mask)  # 将 PIL.Image 转换为 PyTorch 张量
        if torch.isnan(mask_tensor).any():
            print(f"Warning: NaN found in mask {base_name}")
            print(222222222222222222222222222222)
            exit(0)
        if self.mask_transform:
            print(88888888888888888888)
            mask = self.mask_transform(mask)
        else:
            mask = transforms.ToTensor()(mask)

        # 加载文本
        text_path = os.path.join(self.texts_dir, base_name + '.txt')
        with open(text_path, 'r', encoding='utf-8') as f:
            text = f.read()
        if not text:
            print(f"Warning: Empty text for {base_name}")
            print(3333333333333333333333333333333)
            exit(0)
        if self.text_transform:
            text = self.text_transform(text)

        sample = {
            'image': image,
            'mask': mask,
            'text': text,
            'name': base_name
        }
        # retun image,mask,text

        return sample

问题:我的掩码图有多类最多4类(值为0,1,2,3,0为背景),目前代码是只能支持单类分割,如何对CustomDataset进行修改,以适应多类别的分割任务,例如batchsize为4,其中一张图4类都有,另一张可能只包含2类等情况,如何对gt_mask和pred_mask进行处理并进行损失计算

  • 写回答

5条回答 默认 最新

  • 阿里嘎多学长 2025-04-10 09:12
    关注

    阿里嘎多学长整理AIGC生成,因移动端显示问题导致当前答案未能完全显示,请使用PC端查看更加详细的解答过程

    问题解答

    您的问题是关于多类分割损失函数(Cross-Entropy Loss)的实现,特别是关于gt_maskpred_mask的对应关系。

    在 PyTorch 中,cross_entropy_loss 函数用于计算分类损失。假设您已经定义了 inputstargets 两个张量,其中 inputs 是模型的输出,targets 是真实标签。

    gt_maskpred_mask 是用于 mask true labels 和 predictions 的 mask 张量。它们的作用是将不相关的类别设置为 0,使得损失函数只计算相关类别的损失。

    在计算损失时,您需要将 gt_maskpred_maskinputstargets 张量相乘,然后将结果传递给 cross_entropy_loss 函数。

    以下是一个示例代码:

    import torch
    import torch.nn as nn
    
    # 假设 inputs 和 targets 是 torch.Tensor
    inputs = torch.randn(1, 3, 224, 224)
    targets = torch.randint(0, 3, (1, 224, 224))
    
    # 创建 gt_mask 和 pred_mask
    gt_mask = torch.zeros_like(targets)
    gt_mask[targets == 0] = 1
    pred_mask = torch.sigmoid(inputs)
    
    # 计算损失
    loss = nn.CrossEntropyLoss()(inputs * pred_mask, targets * gt_mask)
    

    在上面的代码中,我们首先创建了 gt_maskpred_mask 张量,然后将它们和 inputstargets 张量相乘。最后,我们将结果传递给 CrossEntropyLoss 函数计算损失。

    注意:在实际实现中,您需要根据您的模型和数据集的具体情况进行调整。

    评论

报告相同问题?

问题事件

  • 创建了问题 4月10日