일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | ||||||
2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | 25 | 26 | 27 | 28 |
- nibabel
- MRI
- TabNet
- 코드오류
- Phase recognition
- 확산텐서영상
- decorater
- monai
- 확산강조영상
- tabular
- 유전역학
- words encoding
- paper review
- Surgical video analysis
- nfiti
- 모수적 모델
- 비모수적 모델
- PYTHON
- precision #정밀도 #민감도 #sensitivity #특이도 #specifisity #F1 score #dice score #confusion matrix #recall #PR-AUC #ROC-AUC #PR curve #ROC curve #NPV #PPV
- nlp
- genetic epidemiology
- parametric model
- 파이썬
- MICCAI
- non-parametric model
- deep learning #segmentation #sementic #pytorch #UNETR #transformer #UNET #3D #3D medical image
- parer review
- parrec
- 데코레이터
- TeCNO
- Today
- Total
KimbgAI
[ML] Dice loss & Dice Score with monai, pytorch 본문
Dice는 특이하게 loss function으로도 쓰이지만, metric으로도 사용된다.
(그 이유는 segmentation이라는 task의 특수성 때문인데, pixel별로 class를 예측하기 때문에 metric score의 변화가 상당히 연속적이라 대부분의 구간에서 미분값이 유의미하기 때문이다.)
(classification task에서 accuracy를 가지고 loss function으로 사용하지 않는 이유와 반대로 비슷하게 생각하면 된다.)
한편, 보통 Dice loss 는 1 - dice metric 으로 정의된다.
보시다시피 F1 score와 같다. 민감도(sensitivity)와 정밀도(precision)의 조화평균이다.
monai에서 dice loss와 score를 편하게 계산할 수 있다.
어떻게 작동하는지, 잘 작동되는지 간단하게 확인해보았다.
Dice loss
import torch
from monai.losses.dice import DiceLoss, one_hot
B, C, H, W = 7, 5, 32, 32
output = torch.rand(B, C, H, W) # [0, 1)
target_idx = torch.randint(low=0, high=C, size=(B, H, W)).long()
target = one_hot(target_idx[:, None, ...], num_classes=C)
print('output :', output.shape)
print('target_idx :', target_idx.shape) # not one hot
print('target :', target.shape) # one hot
dice = DiceLoss(reduction='mean')
loss = dice(output, target)
print(loss)
dice = DiceLoss(include_background=False, reduction='mean')
loss = dice(output, target)
print(loss)
output : torch.Size([7, 5, 32, 32])
target_idx : torch.Size([7, 32, 32])
target : torch.Size([7, 5, 32, 32])
tensor(0.7141)
tensor(0.7150)
계산하기 위해서는 output과 target shape이 같아야한다.
간편하게 background를 loss에 포함여부를 설정할 수 있다.
Dice score
import torch
from monai.losses.dice import DiceLoss, one_hot
from monai.metrics import DiceMetric
B, C, H, W = 7, 5, 32, 32
output = torch.rand(B, C, H, W) # [0, 1)
target_idx = torch.randint(low=0, high=C, size=(B, H, W)).long()
target = one_hot(target_idx[:, None, ...], num_classes=C)
print('output :', output.shape)
print('target_idx :', target_idx.shape) # not one hot
print('target :', target.shape) # one hot
print('='*50)
dice = DiceMetric(include_background=True, reduction='mean')
print(dice(output, target))
print(dice.aggregate())
print(dice(output, target).mean())
print('='*50)
dice = DiceMetric(include_background=False, reduction='mean')
print(dice(output, target))
print(dice.aggregate())
print(dice(output, target).mean())
output : torch.Size([7, 5, 32, 32])
target_idx : torch.Size([7, 32, 32])
target : torch.Size([7, 5, 32, 32])
==================================================
tensor([[0.2950, 0.2662, 0.2904, 0.2866, 0.2943],
[0.2888, 0.2813, 0.2727, 0.2645, 0.2868],
[0.2831, 0.2667, 0.2772, 0.3089, 0.2676],
[0.2575, 0.2709, 0.3029, 0.2869, 0.2900],
[0.2921, 0.2689, 0.2754, 0.2876, 0.2775],
[0.2885, 0.2669, 0.2779, 0.2746, 0.3036],
[0.2798, 0.2904, 0.2803, 0.2867, 0.2997]])
tensor([0.2825])
tensor(0.2825)
==================================================
tensor([[0.2662, 0.2904, 0.2866, 0.2943],
[0.2813, 0.2727, 0.2645, 0.2868],
[0.2667, 0.2772, 0.3089, 0.2676],
[0.2709, 0.3029, 0.2869, 0.2900],
[0.2689, 0.2754, 0.2876, 0.2775],
[0.2669, 0.2779, 0.2746, 0.3036],
[0.2904, 0.2803, 0.2867, 0.2997]])
tensor([0.2823])
tensor(0.2823)
단순히 dice(output, target) 의 결과는 텐서형태로 나타나고,
이는 아웃풋의 각 채널(클래스)에 해당하는 dice score에 해당한다.
전체 클래스의 평균적인 dice score를 보기 위해서는 aggregate 또는 mean 해주면 된다.
보는 것과 같이 include_background 여부에 따라 score가 달라지는 걸 볼 수 있다.
또, background에 해당되는 0번째 index 없이 계산된 것을 볼 수 있다.
끝!
'machine learning' 카테고리의 다른 글
[ML] Data augmentation for 3D medical image (3) | 2022.11.17 |
---|---|
[ML] ViT(20.10); Vision Transformer 코드 구현 및 설명 with pytorch (2) | 2022.11.10 |
[ML][pytorch] torch.nn 과 torch.nn.Functional 의 차이 (0) | 2022.11.08 |
FLOPs란? 딥러닝 연산량에 대해서.. (2) | 2022.11.03 |
[ML] VNet(16.06) 요약 및 코드 구현 (pytorch) (0) | 2022.11.01 |