CrossEntropyLoss Label Smoothing code

빅데희터 2024. 1. 30. 13:20

01. Computing per-class weights

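import torch

# Number of samples per class (counts of 0s and 1s in y_train)
count_class_0 = 10000
count_class_1 = 400

# Total number of samples
total_count = count_class_0 + count_class_1

# Per-class weight = total / (num_classes * class_count): the rarer class gets the larger weight
weight_class_0 = total_count / (2 * count_class_0)  # 0.52
weight_class_1 = total_count / (2 * count_class_1)  # 13.0

# Build the weight tensor and move it to the training device
# (define device here if it is not already defined elsewhere in your script)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = torch.tensor([weight_class_0, weight_class_1])
weights = weights.to(device)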
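If the labels are available as a tensor, the counts can be computed instead of hard-coded. The sketch below assumes `y_train` is a 1-D integer tensor of class indices; the helper `balanced_class_weights` and the toy `y_train` are hypothetical, built only to reproduce the 10000/400 split above. The formula total / (num_classes × count) is the same "balanced" heuristic used by scikit-learn's `compute_class_weight(class_weight="balanced")`, giving 10400 / (2 × 10000) = 0.52 for class 0 and 10400 / (2 × 400) = 13.0 for class 1.

```python
import torch


def balanced_class_weights(y_train: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: total / (num_classes * count) for each class."""
    # Samples per class; minlength gives every class an entry even if absent from y_train
    counts = torch.bincount(y_train, minlength=num_classes).float()
    # Note: a class with zero samples would get an infinite weight here
    return counts.sum() / (num_classes * counts)


# Toy labels reproducing the 10000 / 400 split above (illustrative only)
y_train = torch.cat([
    torch.zeros(10000, dtype=torch.long),
    torch.ones(400, dtype=torch.long),
])

weights = balanced_class_weights(y_train, num_classes=2)
print(weights)
```

The result can be moved to the device with `.to(device)` exactly as in the block above.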


02. Defining the label smoothing loss function
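For reference, the loss defined below can be written as follows. The symbols are introduced here for illustration: $N$ samples, $C$ classes, smoothing $\varepsilon$, true label $y_n$, softmax probability $p_{n,c}$, and class weight $w_c$ (taken as $1$ when no weight is passed).

$$
q_{n,c} =
\begin{cases}
1 - \varepsilon & \text{if } c = y_n \\[4pt]
\dfrac{\varepsilon}{C - 1} & \text{otherwise}
\end{cases}
\qquad
\mathcal{L} = -\frac{1}{N} \sum_{n=1}^{N} \sum_{c=1}^{C} w_c \, q_{n,c} \log p_{n,c}
$$

With $C = 2$ and $\varepsilon = 0.2$, the smoothed target is $0.8$ for the true class and $0.2$ for the other one.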

import torch
import torch.nn as nn


class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.0, dim=-1, weight=None):
        """Cross-entropy against a smoothed target distribution.

        smoothing == 0:     plain one-hot targets (ordinary cross-entropy)
        0 < smoothing < 1:  the true class gets 1 - smoothing and the remaining
                            smoothing mass is split evenly over the other classes
        weight (optional):  per-class weight tensor of shape (classes,)
        """
        super().__init__()
        assert 0 <= smoothing < 1, "smoothing must be in [0, 1)"
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.weight = weight
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        # pred: raw logits of shape (N, classes); target: class indices of shape (N,), dtype long
        pred = pred.log_softmax(dim=self.dim)

        if self.weight is not None:
            # Scale every class's log-probability by its class weight (non-target terms included).
            # Unlike nn.CrossEntropyLoss(weight=...), the mean is not renormalized by the weight sum.
            pred = pred * self.weight.unsqueeze(0)

        # Build the smoothed target distribution (no gradients needed)
        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)

        # Per-sample cross-entropy, averaged over the batch
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))


loss_func = LabelSmoothingLoss(classes=2, smoothing=0.2, weight=weights)
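A quick sanity check can help confirm the loss behaves as expected. The sketch below uses random dummy logits (all data here is illustrative, not from the original post) and checks two things: with `smoothing=0` and no weights the loss should equal `F.cross_entropy`, and PyTorch's built-in `nn.CrossEntropyLoss(label_smoothing=...)` (available since PyTorch 1.10) matches this loss only after rescaling the smoothing value. The built-in mixes ε/C into every class, including the true one, while this class gives the true class 1 - ε and splits ε over the other C - 1 classes, so the two agree when `smoothing = ε (C - 1) / C`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingLoss(nn.Module):
    # Compact copy of the loss above so this snippet runs on its own
    def __init__(self, classes, smoothing=0.0, dim=-1, weight=None):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.weight = weight
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        if self.weight is not None:
            pred = pred * self.weight.unsqueeze(0)
        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))


torch.manual_seed(0)
num_classes = 2
logits = torch.randn(8, num_classes)          # dummy model outputs (N = 8)
target = torch.randint(0, num_classes, (8,))  # dummy labels, dtype long

# 1) smoothing=0 and no weight: should reduce to ordinary cross-entropy
plain = LabelSmoothingLoss(classes=num_classes, smoothing=0.0)(logits, target)
reference = F.cross_entropy(logits, target)

# 2) built-in label_smoothing=eps corresponds to smoothing = eps * (C - 1) / C here
eps = 0.2
custom = LabelSmoothingLoss(
    classes=num_classes, smoothing=eps * (num_classes - 1) / num_classes
)(logits, target)
builtin = nn.CrossEntropyLoss(label_smoothing=eps)(logits, target)

print(torch.allclose(plain, reference), torch.allclose(custom, builtin))
```

The two smoothing values are therefore not interchangeable: with C = 2, `smoothing=0.2` in this class corresponds to `label_smoothing=0.4` in the built-in loss (both give targets of 0.8 and 0.2). The class weighting used here also differs from the built-in `weight` argument, so the equivalence above only holds without weights.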


References

https://stackoverflow.com/questions/55681502/label-smoothing-in-pytorch
