[ML][pytorch] UNETR (21.03); UNEt TRansformers — code walkthrough and implementation
KimbgAI · 2022. 11. 21. 02:22
UNETR (UNEt TRansformers), as its name suggests, is a UNet-shaped architecture; its distinguishing feature is that the encoding path is replaced by a transformer to extract the feature maps.
This post covers implementing UNETR in pytorch.
The core piece, the ViT implementation, is explained in more detail in the blog post below (my own post, ha):
https://kimbg.tistory.com/31
The UNETR paper is available at the link below.
https://arxiv.org/abs/2103.10504
Key hyperparameters of the architecture (the token count these values imply is sketched right after this list):
- image size = (H:224, W:224, D:224)
- Layers = 12
- Multi-head = 8
- Embedding length = 768
- Patch size = 16
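For intuition, here is the token count these hyperparameters imply (illustrative arithmetic only, not from the original post):

H = W = D = 224
patch_size = 16
n_patches = (H // patch_size) * (W // patch_size) * (D // patch_size)
print(n_patches)  # 2744 tokens per volume, each a 768-dimensional embedding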
First, declare the basic conv3d blocks.
They correspond to the green, blue, yellow, and gray blocks in the architecture diagram.
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
from torch import Tensor
class SingleDeconv3DBlock(nn.Module):
    def __init__(self, in_planes, out_planes):
        super().__init__()
        # transposed conv doubles the spatial resolution (kernel 2, stride 2)
        self.block = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=2,
                                        padding=0, output_padding=0)

    def forward(self, x):
        return self.block(x)


class SingleConv3DBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size):
        super().__init__()
        # 'same' padding keeps the spatial resolution unchanged
        self.block = nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
                               padding=((kernel_size - 1) // 2))

    def forward(self, x):
        return self.block(x)


class Conv3DBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv3DBlock(in_planes, out_planes, kernel_size),
            nn.BatchNorm3d(out_planes),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class Deconv3DBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleDeconv3DBlock(in_planes, out_planes),
            SingleConv3DBlock(out_planes, out_planes, kernel_size),
            nn.BatchNorm3d(out_planes),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)
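As a quick sanity check (my own illustrative snippet, not from the original post): Conv3DBlock preserves the spatial resolution, while SingleDeconv3DBlock doubles it.

t = torch.rand(1, 8, 14, 14, 14)
print(Conv3DBlock(8, 16)(t).shape)         # torch.Size([1, 16, 14, 14, 14]) -- resolution preserved
print(SingleDeconv3DBlock(8, 8)(t).shape)  # torch.Size([1, 8, 28, 28, 28])  -- resolution doubled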
1. Embedding
- Split the input image into patches and add the position embedding.
- Since this is a segmentation task, no class token is used.
class Embeddings(nn.Module):
    def __init__(self, input_shape, patch_size=16, embed_dim=768, dropout=0.):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = input_shape[-4]
        # number of non-overlapping 3D patches
        self.n_patches = int((input_shape[-1] * input_shape[-2] * input_shape[-3])
                             / (patch_size * patch_size * patch_size))
        self.embed_dim = embed_dim
        # a strided Conv3d performs patch splitting and linear projection in one step
        self.patch_embeddings = nn.Conv3d(in_channels=self.in_channels, out_channels=self.embed_dim,
                                          kernel_size=self.patch_size, stride=self.patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, self.embed_dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.patch_embeddings(x)  # batch, embed_dim, height/patch, width/patch, depth/patch
        x = rearrange(x, "b n h w d -> b (h w d) n")  # batch, n_patches, embed_dim
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings
## input ##
shape = (1,1,224,224,224)
x = torch.rand(shape)
patch_size = 16
E = Embeddings(x.shape[1:])
summary(E, x.shape[1:], device='cpu')
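The summary calls in the following sections reuse the actual embedding output, so run the module once and keep the result (this assignment is assumed by the later cells):

embedding_x = E(x)
print(embedding_x.shape)  # torch.Size([1, 2744, 768])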
2. Multi Head Attention Block (MHA)
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values into one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        # split queries, keys and values into num_heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # dot product over the embedding axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # weighted sum of the values
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
MHA = MultiHeadAttention()
summary(MHA, embedding_x.shape[1:], device='cpu')
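Because the TransformerBlock later wraps this module in a residual connection (x = attn(x) + x), the output shape must match the input shape; a quick check (my own snippet):

out = MHA(embedding_x)
print(out.shape)  # torch.Size([1, 2744, 768]) -- same shape as the input, so it can be added residually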
3. Feed Forward Block (FF)
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int = 768, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )
FF = FeedForwardBlock()
summary(FF, embedding_x.shape[1:], device='cpu')
4. Transformer Block
This ties one MHA block and one FF block together into a single layer.
Also, since we need the features from the 3rd, 6th, 9th, and 12th transformer blocks, the feature maps of those layers are collected into a list.
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=768, num_heads=8, depth=12, dropout=0., extract_layers=[3, 6, 9, 12]):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(embed_dim, MultiHeadAttention(embed_dim, num_heads, dropout)),
                PreNorm(embed_dim, FeedForwardBlock(embed_dim, expansion=4))
            ]))
        self.extract_layers = extract_layers

    def forward(self, x):
        extracted = []
        for cnt, (attn, ff) in enumerate(self.layers):
            x = attn(x) + x   # residual connection around attention
            x = ff(x) + x     # residual connection around feed-forward
            if cnt + 1 in self.extract_layers:
                extracted.append(x)
        return extracted
TB = TransformerBlock()
summary(TB, embedding_x.shape[1:], device='cpu')
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
LayerNorm-1 [-1, 2744, 768] 1,536
Linear-2 [-1, 2744, 2304] 1,771,776
Dropout-3 [-1, 8, 2744, 2744] 0
Linear-4 [-1, 2744, 768] 590,592
========================= (omitted) ===========================
MultiHeadAttention-137 [-1, 2744, 768] 0
PreNorm-138 [-1, 2744, 768] 0
LayerNorm-139 [-1, 2744, 768] 1,536
Linear-140 [-1, 2744, 3072] 2,362,368
GELU-141 [-1, 2744, 3072] 0
Dropout-142 [-1, 2744, 3072] 0
Linear-143 [-1, 2744, 768] 2,360,064
PreNorm-144 [-1, 2744, 768] 0
================================================================
Total params: 85,054,464
Trainable params: 85,054,464
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 8.04
Forward/backward pass size (MB): 9759.42
Params size (MB): 324.46
Estimated Total Size (MB): 10091.92
----------------------------------------------------------------
You can confirm that a feature map comes out of each of the 3rd, 6th, 9th, and 12th layers.
embedding_x_list = TB(embedding_x)
for i in embedding_x_list:
print(i.shape)
torch.Size([1, 2744, 768])
torch.Size([1, 2744, 768])
torch.Size([1, 2744, 768])
torch.Size([1, 2744, 768])
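Before building the decoder, it helps to see how one of these (1, 2744, 768) token sequences is folded back into a 3D feature volume; the UNETR forward pass below does exactly this with transpose + view (illustrative snippet using the variables above):

z = embedding_x_list[-1]                            # (1, 2744, 768), output of layer 12
z = z.transpose(-1, -2).view(-1, 768, 14, 14, 14)   # 2744 tokens -> a 14 x 14 x 14 grid of 768-dim features
print(z.shape)                                      # torch.Size([1, 768, 14, 14, 14])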
UNETR
Now build UNETR from the blocks declared above.
I added a light_r option: with the paper's settings as-is, the model would not even build on my machine because of memory limits (OOM), so the conv3d channel counts had to be reduced. (monai's UNETR also trims the channel counts to keep it light.)
On top of that, a (224, 224, 224) input did not fit into VRAM either (my RTX 3060 weeps..), so I resized the input to (128, 128, 128).
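Concretely, light_r=4 shrinks the decoder channels as follows (just the arithmetic from the class below, shown for clarity):

light_r = 4
print([int(i / light_r) for i in [32, 64, 128, 256, 512, 1024]])  # [8, 16, 32, 64, 128, 256]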
class UNETR(nn.Module):
    def __init__(self, img_shape=(224, 224, 224), input_dim=3, output_dim=3,
                 embed_dim=768, patch_size=16, num_heads=8, dropout=0.1, light_r=4):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embed_dim = embed_dim
        self.img_shape = img_shape
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.num_layers = 12
        self.ext_layers = [3, 6, 9, 12]
        self.patch_dim = [int(x / patch_size) for x in img_shape]
        # light_r shrinks every decoder channel count to fit into limited VRAM
        self.conv_channels = [int(i / light_r) for i in [32, 64, 128, 256, 512, 1024]]

        self.embedding = Embeddings((input_dim, *img_shape))

        # Transformer Encoder
        self.transformer = TransformerBlock()

        # U-Net Decoder
        self.decoder0 = nn.Sequential(
            Conv3DBlock(input_dim, self.conv_channels[0], 3),
            Conv3DBlock(self.conv_channels[0], self.conv_channels[1], 3)
        )

        self.decoder3 = nn.Sequential(
            Deconv3DBlock(embed_dim, self.conv_channels[2]),
            Deconv3DBlock(self.conv_channels[2], self.conv_channels[2]),
            Deconv3DBlock(self.conv_channels[2], self.conv_channels[2])
        )

        self.decoder6 = nn.Sequential(
            Deconv3DBlock(embed_dim, self.conv_channels[3]),
            Deconv3DBlock(self.conv_channels[3], self.conv_channels[3]),
        )

        self.decoder9 = Deconv3DBlock(embed_dim, self.conv_channels[4])

        self.decoder12_upsampler = SingleDeconv3DBlock(embed_dim, self.conv_channels[4])

        self.decoder9_upsampler = nn.Sequential(
            Conv3DBlock(self.conv_channels[5], self.conv_channels[3]),
            Conv3DBlock(self.conv_channels[3], self.conv_channels[3]),
            Conv3DBlock(self.conv_channels[3], self.conv_channels[3]),
            SingleDeconv3DBlock(self.conv_channels[3], self.conv_channels[3])
        )

        self.decoder6_upsampler = nn.Sequential(
            Conv3DBlock(self.conv_channels[4], self.conv_channels[2]),
            Conv3DBlock(self.conv_channels[2], self.conv_channels[2]),
            SingleDeconv3DBlock(self.conv_channels[2], self.conv_channels[2])
        )

        self.decoder3_upsampler = nn.Sequential(
            Conv3DBlock(self.conv_channels[3], self.conv_channels[1]),
            Conv3DBlock(self.conv_channels[1], self.conv_channels[1]),
            SingleDeconv3DBlock(self.conv_channels[1], self.conv_channels[1])
        )

        self.decoder0_header = nn.Sequential(
            Conv3DBlock(self.conv_channels[2], self.conv_channels[1]),
            Conv3DBlock(self.conv_channels[1], self.conv_channels[1]),
            SingleConv3DBlock(self.conv_channels[1], output_dim, 1)
        )

    def forward(self, x):
        z0 = x
        x = self.embedding(x)
        z = self.transformer(x)
        z3, z6, z9, z12 = z
        # reshape the token sequences back into 3D feature volumes
        z3 = z3.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
        z6 = z6.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
        z9 = z9.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
        z12 = z12.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)

        z12 = self.decoder12_upsampler(z12)
        z9 = self.decoder9(z9)
        z9 = self.decoder9_upsampler(torch.cat([z9, z12], dim=1))
        z6 = self.decoder6(z6)
        z6 = self.decoder6_upsampler(torch.cat([z6, z9], dim=1))
        z3 = self.decoder3(z3)
        z3 = self.decoder3_upsampler(torch.cat([z3, z6], dim=1))
        z0 = self.decoder0(z0)
        output = self.decoder0_header(torch.cat([z0, z3], dim=1))
        return output
from torchsummary import summary
# x = torch.rand(1,1,224,224,224)
x = torch.rand(1,1,128,128,128)
# x = torch.rand(1,1,64,64,64)
model = UNETR(img_shape=x.shape[2:], input_dim=x.shape[1], output_dim=4,
embed_dim=768, patch_size=16, num_heads=8, dropout=0., light_r=4)
summary(model, x.shape[1:], device='cpu')
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv3d-1 [-1, 768, 8, 8, 8] 3,146,496
Dropout-2 [-1, 512, 768] 0
Embeddings-3 [-1, 512, 768] 0
LayerNorm-4 [-1, 512, 768] 1,536
Linear-5 [-1, 512, 2304] 1,771,776
Dropout-6 [-1, 8, 512, 512] 0
Linear-7 [-1, 512, 768] 590,592
MultiHeadAttention-8 [-1, 512, 768] 0
PreNorm-9 [-1, 512, 768] 0
LayerNorm-10 [-1, 512, 768] 1,536
Linear-11 [-1, 512, 3072] 2,362,368
GELU-12 [-1, 512, 3072] 0
========================== (omitted) ============================
BatchNorm3d-251 [-1, 16, 128, 128, 128] 32
ReLU-252 [-1, 16, 128, 128, 128] 0
Conv3DBlock-253 [-1, 16, 128, 128, 128] 0
Conv3d-254 [-1, 4, 128, 128, 128] 68
SingleConv3DBlock-255 [-1, 4, 128, 128, 128] 0
================================================================
Total params: 92,065,812
Trainable params: 92,065,812
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 8.00
Forward/backward pass size (MB): 7376.00
Params size (MB): 351.20
Estimated Total Size (MB): 7735.20
----------------------------------------------------------------
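A forward pass with the (1, 1, 128, 128, 128) input above should return a 4-channel segmentation volume at the full input resolution; a quick check (my own snippet, not from the original post):

with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 4, 128, 128, 128])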
The full code, from data preparation through preprocessing to training, is on GitHub:
https://github.com/kimbgAI/Segmentation3D
Looking just at the results: I used the MSD dataset, where the task is segmenting the spleen in abdominal CT.
I could confirm that the model trains reasonably well; just keep in mind that the implementation details differ from the paper.
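For reference, here is a minimal training-loop sketch of how such a model is typically trained; this is not the repository code, and the DiceLoss settings and the train_loader are assumptions:

from monai.losses import DiceLoss
from torch.optim import AdamW

# Assumption: train_loader yields dicts with "image" (B, 1, 128, 128, 128) and
# "label" (B, 1, 128, 128, 128) tensors; 4 output classes as in the summary above.
criterion = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

model.train()
for batch in train_loader:
    image, label = batch["image"], batch["label"]
    optimizer.zero_grad()
    logits = model(image)             # (B, 4, 128, 128, 128)
    loss = criterion(logits, label)   # Dice loss over softmax probabilities vs. one-hot labels
    loss.backward()
    optimizer.step()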
Done!!
As an aside, I wondered why the embedding dimension in UNETR is 768.
In ViT the patch size is 16, so each patch flattens to height (16) × width (16) × channels (3) = 768, which is where that number comes from.
I naturally assumed UNETR should do the same, but for a 3D input that makes the embedding dimension explode: height (16) × width (16) × depth (16) × channels (3) = 12,288.
Since the embeddings are mostly fed through fully connected layers, the amount of computation would grow enormously.
So I suspect the familiar length of 768 was kept instead.
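To put numbers on that (simple arithmetic using the fused qkv projection from the MultiHeadAttention above): the single qkv Linear layer alone would grow from about 1.8M to over 450M parameters.

emb = 768
print(emb * (emb * 3) + emb * 3)   # 1,771,776 -- matches the qkv Linear layers in the summaries above
emb = 16 * 16 * 16 * 3             # 12,288 if the ViT convention were carried over to 3D patches
print(emb * (emb * 3) + emb * 3)   # 453,021,696 for the same layer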