🤖딥러닝
CrossEntropyLoss Label Smoothing code
빅데희터
2024. 1. 30. 13:20
반응형
01. 클래스별 가중치(weight) 계산
# torch is used below, before the file's main import block, so import it here.
import torch

# Number of samples of each class in y_train (label 0 and label 1).
count_class_0 = 10000
count_class_1 = 400

# Total number of samples.
total_count = count_class_0 + count_class_1

# "Balanced" class weights: total / (n_classes * class_count), with n_classes = 2.
# A class with fewer samples gets a proportionally larger weight; with equal
# counts every weight would be 1.0.
weight_class_0 = total_count / (2 * count_class_0)
weight_class_1 = total_count / (2 * count_class_1)

# Weight tensor indexed by class id (position 0 -> label 0, position 1 -> label 1).
# NOTE(review): `device` is not defined in this snippet; it must be set before
# this runs, e.g. torch.device("cuda" if torch.cuda.is_available() else "cpu").
weights = torch.tensor([weight_class_0, weight_class_1])
weights = weights.to(device)
02. Label Smoothing Loss function 정의
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.modules.loss import _WeightedLoss
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy loss with label smoothing and optional per-class weights.

    The target distribution puts ``1 - smoothing`` on the true class and spreads
    ``smoothing`` uniformly over the other ``classes - 1`` classes. With
    ``smoothing == 0`` this is ordinary one-hot cross-entropy; with
    ``0 < smoothing < 1`` it is the smoothed variant. The loss is the
    cross-entropy between that distribution and ``log_softmax(pred)``, summed
    over the class axis and averaged over every remaining position.

    Args:
        classes: Number of classes (size of ``pred`` along ``dim``); must be >= 2.
        smoothing: Amount of target mass moved off the true class, in ``[0, 1)``.
        dim: Class axis of ``pred`` (negative values count from the end).
        weight: Optional 1-D tensor of length ``classes``. Each class's
            log-probability is multiplied by its weight before summing.
            NOTE: unlike ``nn.CrossEntropyLoss(weight=...)``, the final mean is
            a plain mean and is not normalised by the summed target weights.

    Raises:
        ValueError: If ``smoothing`` is outside ``[0, 1)`` or ``classes < 2``.
    """

    def __init__(self, classes, smoothing=0.0, dim=-1, weight=None):
        super().__init__()
        # Validate with real exceptions: ``assert`` is stripped under ``python -O``.
        if not 0 <= smoothing < 1:
            raise ValueError(f"smoothing must be in [0, 1), got {smoothing}")
        if classes < 2:
            # smoothing is divided by (classes - 1) when building the target.
            raise ValueError(f"classes must be >= 2, got {classes}")
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.weight = weight
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean label-smoothed cross-entropy.

        Args:
            pred: Unnormalised scores with the class axis at ``self.dim``.
            target: Integer (int64) class indices shaped like ``pred`` without
                the class axis.

        Returns:
            A scalar tensor.
        """
        # One normalised axis is used for log_softmax, the weight broadcast,
        # scatter_ and the final sum, so they can never disagree.
        dim = self.dim % pred.dim()
        log_probs = pred.log_softmax(dim=dim)

        if self.weight is not None:
            # Shape the weight as (1, ..., C, ..., 1) so it scales only the class
            # axis, and match pred's device/dtype to avoid errors or upcasting.
            shape = [1] * log_probs.dim()
            shape[dim] = -1
            weight = self.weight.to(device=log_probs.device, dtype=log_probs.dtype)
            log_probs = log_probs * weight.reshape(shape)

        with torch.no_grad():
            # Smoothed target: smoothing / (C - 1) everywhere, confidence at the
            # true class index.
            true_dist = torch.full_like(log_probs, self.smoothing / (self.cls - 1))
            true_dist.scatter_(dim, target.unsqueeze(dim), self.confidence)

        return torch.mean(torch.sum(-true_dist * log_probs, dim=dim))
# Binary, class-weighted loss: with smoothing=0.2 and 2 classes the target
# distribution is 0.8 on the true class and 0.2 on the other one.
loss_func = LabelSmoothingLoss(
    weight=weights,
    smoothing=0.2,
    classes=2,
)
참고문헌
https://stackoverflow.com/questions/55681502/label-smoothing-in-pytorch
반응형