KimbgAI

[오류해결] monai metrics 중 DiceMetric의 y, y_pred 인식 오류 본문

machine learning

[오류해결] monai metrics 중 DiceMetric의 y, y_pred 인식 오류

KimbgAI 2024. 3. 27. 13:09
반응형

도무지 이해가 안가는 경우였다.

 

monia의 DiceMetric를 사용하던 중 y에 label을 넣고 y_pred에 model의 output을 넣는데 계산이 이상하게 나왔던 것.

더 이상했던 것은 y에 model의 output을 넣고, y_pred에 label을 넣으니 제대로 작동하더라는 것이다.

두 눈을 의심하고 코드를 뜯어봐도 결과는 마찬가지였다.

기존에는 동일한 코드로 잘만 사용했었는데 말이다.

 

 

해결한 결과부터 말하면,

1. output을 argmax 해야했음

2. 이제서야 문제가 밝혀는 까닭은 기존에 사용했던 monai 버전과 달라서 그랬던 것. (기존에는 1.0.0 사용, 현재는 1.3.0 사용)

 

 

현상을 살펴보면..

각각 target, output, scratch 을 시각화면 아래와 같다.

import pickle
import matplotlib.pyplot as plt

with open(r'target.pkl', 'rb') as f:
    target = pickle.load(f)
print('target')
plt.imshow(target[:,:,0])
plt.show()

with open(r'output.pkl', 'rb') as f:
    output = pickle.load(f)
print('output')
plt.imshow(output[:,:,0])
plt.show()

with open(r'init_feature.pkl', 'rb') as f:
    init_feature = pickle.load(f)
print('scratch')
plt.imshow(init_feature[:,:,0])
plt.show()

 

당연히 target과 ooutput의 dice score가 높게 나오고 scratch와는 낮게 나와야하는 것이 올바른 상황임.

 

 

하지만 monai의 dice score 값은 동일하게 나옴..

import torch

## dice score를 계산하기 전에 pytorch style로 변환해야함
target = torch.unsqueeze(torch.tensor(target).permute(2,0,1), 0)
output = torch.unsqueeze(torch.tensor(output).permute(2,0,1), 0)
init_feature = torch.unsqueeze(torch.tensor(init_feature).permute(2,0,1), 0)
print(target.shape) # B, C, W, H


from monai.metrics import DiceMetric
Dice = DiceMetric()

score = Dice(y_pred=output, y=target)
print(score)

score = Dice(y_pred=init_feature, y=target)
print(score)

 

 

 

아까 말했든이 y에 output을 넣고 y_pred에 target을 넣으면 오히려 정상적으로 나오는 모습..

score = Dice(y_pred=target, y=output)
print(score)

score = Dice(y_pred=target, y=init_feature)
print(score)

 

 

해결 방법은 아래와 같이 명시적으로 argmax를 해주면 되긴 한다.

def OutputPostProcess(output, num_classes):
    output_arg = torch.argmax(output, axis=1)
    output_arg = torch.nn.functional.one_hot(output_arg, num_classes=num_classes)
    output_arg = output_arg.permute(0,3,1,2)
    return output_arg

score = Dice(y_pred=OutputPostProcess(output, num_classes=2), y=target)
print(score)

score = Dice(y_pred=OutputPostProcess(init_feature, num_classes=2), y=target)
print(score)

 

 

 

monai 1.0.0 버전에서는 굳이 argmax를 해주지 않아도 알아서 잘 나옴. (UserWarning이 뜨긴 하지만)

 

 

 

오늘의 교훈!

기존 코드와 뭔가 잘 안 맞는다 싶으면 버전을 먼저 확인해볼것!

 

끝!

반응형
Comments