KimbgAI

[ML] Dice loss & Dice Score with monai, pytorch 본문

machine learning

[ML] Dice loss & Dice Score with monai, pytorch

KimbgAI 2022. 11. 8. 16:50
반응형

Dice는 특이하게 loss function으로도 쓰이지만, metric으로도 사용된다.
(그 이유는 segmentation이라는 task의 특수성 때문인데, pixel별로 class를 예측하기 때문에 metric score의 변화가 상당히 연속적이라 대부분의 구간에서 미분값이 유의미하기 때문이다.)
(classification task에서 accuracy를 가지고 loss function으로 사용하지 않는 이유와 반대로 비슷하게 생각하면 된다.)

한편, 보통 Dice loss 는 1 - dice metric 으로 정의된다.

Dice 지표의 계산방법 (Dice는 사람이름에서 따왔다고 한다)

보시다시피 F1 score와 같다. 민감도(sensitivity)와 정밀도(precision)의 조화평균이다.

monai에서 dice loss와 score를 편하게 계산할 수 있다.

어떻게 작동하는지, 잘 작동되는지 간단하게 확인해보았다.

Dice loss

import torch
from monai.losses.dice import DiceLoss, one_hot

B, C, H, W = 7, 5, 32, 32
output = torch.rand(B, C, H, W) # [0, 1)
target_idx = torch.randint(low=0, high=C, size=(B, H, W)).long()
target = one_hot(target_idx[:, None, ...], num_classes=C)

print('output :', output.shape) 
print('target_idx :', target_idx.shape) # not one hot
print('target :', target.shape) # one hot

dice = DiceLoss(reduction='mean')
loss = dice(output, target)
print(loss)

dice = DiceLoss(include_background=False, reduction='mean')
loss = dice(output, target)
print(loss)
output : torch.Size([7, 5, 32, 32])
target_idx : torch.Size([7, 32, 32])
target : torch.Size([7, 5, 32, 32])
tensor(0.7141)
tensor(0.7150)

계산하기 위해서는 output과 target shape이 같아야한다.
간편하게 background를 loss에 포함여부를 설정할 수 있다.


Dice score

import torch
from monai.losses.dice import DiceLoss, one_hot
from monai.metrics import DiceMetric

B, C, H, W = 7, 5, 32, 32
output = torch.rand(B, C, H, W) # [0, 1)
target_idx = torch.randint(low=0, high=C, size=(B, H, W)).long()
target = one_hot(target_idx[:, None, ...], num_classes=C)

print('output :', output.shape) 
print('target_idx :', target_idx.shape) # not one hot
print('target :', target.shape) # one hot
print('='*50)

dice = DiceMetric(include_background=True, reduction='mean')
print(dice(output, target))
print(dice.aggregate())
print(dice(output, target).mean())
print('='*50)

dice = DiceMetric(include_background=False, reduction='mean')
print(dice(output, target))
print(dice.aggregate())
print(dice(output, target).mean())
output : torch.Size([7, 5, 32, 32])
target_idx : torch.Size([7, 32, 32])
target : torch.Size([7, 5, 32, 32])
==================================================
tensor([[0.2950, 0.2662, 0.2904, 0.2866, 0.2943],
        [0.2888, 0.2813, 0.2727, 0.2645, 0.2868],
        [0.2831, 0.2667, 0.2772, 0.3089, 0.2676],
        [0.2575, 0.2709, 0.3029, 0.2869, 0.2900],
        [0.2921, 0.2689, 0.2754, 0.2876, 0.2775],
        [0.2885, 0.2669, 0.2779, 0.2746, 0.3036],
        [0.2798, 0.2904, 0.2803, 0.2867, 0.2997]])
tensor([0.2825])
tensor(0.2825)
==================================================
tensor([[0.2662, 0.2904, 0.2866, 0.2943],
        [0.2813, 0.2727, 0.2645, 0.2868],
        [0.2667, 0.2772, 0.3089, 0.2676],
        [0.2709, 0.3029, 0.2869, 0.2900],
        [0.2689, 0.2754, 0.2876, 0.2775],
        [0.2669, 0.2779, 0.2746, 0.3036],
        [0.2904, 0.2803, 0.2867, 0.2997]])
tensor([0.2823])
tensor(0.2823)

단순히 dice(output, target) 의 결과는 텐서형태로 나타나고,
이는 아웃풋의 각 채널(클래스)에 해당하는 dice score에 해당한다.
전체 클래스의 평균적인 dice score를 보기 위해서는 aggregate 또는 mean 해주면 된다.

보는 것과 같이 include_background 여부에 따라 score가 달라지는 걸 볼 수 있다.
또, background에 해당되는 0번째 index 없이 계산된 것을 볼 수 있다.



끝!

반응형
Comments