2022. 11. 21.

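UNETR (UNEt TRansformers) is, as the name suggests, a UNet-shaped architecture. Its defining feature is that the encoder is replaced with a transformer, which extracts the feature maps.
This post walks through implementing UNETR in PyTorch.
The core piece, the ViT implementation, is covered in more detail in the blog post below (my own blog, heh).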

The UNETR paper is available at the link below.

Main hyperparameters of the architecture

  • image size = (H:224, W:224, D:224)
  • Layers = 12
  • Multi-head = 8
  • Embedding length = 768
  • Patch size = 16
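As a quick sanity check on these hyperparameters, the sequence length the transformer sees can be worked out by hand. The sketch below is only an illustration (the helper name `num_patches` is made up for this example): it counts the non-overlapping 16 x 16 x 16 patches in a volume.

```python
def num_patches(image_size, patch_size):
    """Number of non-overlapping cubic patches (tokens) in a 3D volume."""
    n = 1
    for side in image_size:
        if side % patch_size != 0:
            raise ValueError(f"{side} is not divisible by patch size {patch_size}")
        n *= side // patch_size
    return n


print(num_patches((224, 224, 224), 16))  # 14 * 14 * 14 = 2744
print(num_patches((128, 128, 128), 16))  # 8 * 8 * 8 = 512
```

These two values, 2744 and 512, are the token counts that appear in the torchsummary outputs later in the post.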
UNETR architecture

First, declare the basic conv3d blocks.
They correspond to the building blocks marked in green, blue, yellow, and gray in the figure above.

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
from torch import Tensor

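class SingleDeconv3DBlock(nn.Module):
    """Transposed conv with kernel 2 / stride 2: doubles every spatial dimension."""
    def __init__(self, in_planes, out_planes):
        super().__init__()
        self.block = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0)

    def forward(self, x):
        return self.block(x)

class SingleConv3DBlock(nn.Module):
    """Plain conv with 'same' padding: keeps the spatial size for odd kernel sizes."""
    def __init__(self, in_planes, out_planes, kernel_size):
        super().__init__()
        self.block = nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
                               padding=((kernel_size - 1) // 2))

    def forward(self, x):
        return self.block(x)

class Conv3DBlock(nn.Module):
    """Conv -> BatchNorm -> ReLU (matches the BatchNorm3d / ReLU rows in the summary output)."""
    def __init__(self, in_planes, out_planes, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv3DBlock(in_planes, out_planes, kernel_size),
            nn.BatchNorm3d(out_planes),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)

class Deconv3DBlock(nn.Module):
    """Deconv (2x upsampling) -> Conv -> BatchNorm -> ReLU."""
    def __init__(self, in_planes, out_planes, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleDeconv3DBlock(in_planes, out_planes),
            SingleConv3DBlock(out_planes, out_planes, kernel_size),
            nn.BatchNorm3d(out_planes),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)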
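To see why these blocks either keep or double the spatial size, the standard output-size formulas for convolution and transposed convolution (dilation 1) can be evaluated directly. This is a small framework-free sketch; the helper names `conv_out` and `deconv_out` are made up for illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one axis (dilation = 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Output length of a transposed convolution along one axis (dilation = 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# SingleConv3DBlock: kernel 3, stride 1, padding (3 - 1) // 2 = 1 keeps the size
print(conv_out(8, kernel=3, stride=1, padding=1))  # 8

# SingleDeconv3DBlock: kernel 2, stride 2, padding 0 doubles the size
print(deconv_out(8, kernel=2, stride=2))  # 16
```

So every deconv block is one 2x upsampling step, and four such steps take the 8 x 8 x 8 patch grid of a 128-voxel volume back to 128 x 128 x 128.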


1. Embedding

  1. Split the input image into patches and add the position embedding to them.
  2. Since this is a segmentation task, no class token is used.
class Embeddings(nn.Module):
    def __init__(self, input_shape, patch_size=16, embed_dim=768, dropout=0.):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = input_shape[-4]
        # number of patches (tokens) = image volume / patch volume
        self.n_patches = (input_shape[-1] * input_shape[-2] * input_shape[-3]) // (patch_size * patch_size * patch_size)
        self.embed_dim = embed_dim
        # a conv with kernel == stride == patch_size cuts the volume into patches and projects each one
        self.patch_embeddings = nn.Conv3d(in_channels=self.in_channels, out_channels=self.embed_dim,
                                          kernel_size=self.patch_size, stride=self.patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, self.embed_dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # batch, embed_dim, height/patch, width/patch, depth/patch
        x = self.patch_embeddings(x)
        # flatten the patch grid into a token axis: batch, n_patches, embed_dim
        x = rearrange(x, "b n h w d -> b (h w d) n")
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings
## input ##
shape = (1,1,224,224,224)
x = torch.rand(shape)
patch_size = 16

E = Embeddings(x.shape[1:])
embedding_x = E(x)  # (1, 2744, 768); reused as the input of the blocks below
summary(E, x.shape[1:], device='cpu')
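The pattern `"b n h w d -> b (h w d) n"` flattens the patch grid in row-major order, so the last axis (d) varies fastest. The sketch below spells out that mapping between a token index and its grid position in plain Python; `token_index` and `grid_position` are hypothetical helper names used only for this illustration.

```python
def token_index(h, w, d, grid):
    """Flat token index of grid cell (h, w, d) under row-major '(h w d)' flattening."""
    _, W, D = grid
    return (h * W + w) * D + d


def grid_position(n, grid):
    """Inverse mapping: recover (h, w, d) from a flat token index."""
    _, W, D = grid
    h, rem = divmod(n, W * D)
    w, d = divmod(rem, D)
    return h, w, d


grid = (14, 14, 14)                  # 224 / 16 patches per axis
print(token_index(1, 2, 3, grid))    # (1 * 14 + 2) * 14 + 3 = 227
print(grid_position(227, grid))      # (1, 2, 3)
print(grid_position(2743, grid))     # last token -> (13, 13, 13)
```

Because the flattening is a plain row-major reshape, the token sequence can later be reshaped back into a (h, w, d) grid without any reordering.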


2. Multi Head Attention Block (MHA)

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values in num_heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # dot product of every query with every key (sum over the head dimension d)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            # masked_fill is out-of-place, so the result has to be assigned back
            energy = energy.masked_fill(~mask, fill_value)
        # note: this scales by sqrt(emb_size); the usual formulation uses sqrt(emb_size / num_heads)
        scaling = self.emb_size ** (1/2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # weighted sum of the values (sum over the key axis)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
MHA = MultiHeadAttention()
summary(MHA, embedding_x.shape[1:], device='cpu')
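One detail worth checking by hand is the softmax scaling. The class above divides the attention scores by the square root of the full embedding size, whereas scaled dot-product attention as defined in "Attention Is All You Need" divides by the square root of the per-head dimension. The sketch below only compares the two numbers for this post's hyperparameters; `scaling_factors` is a made-up helper name.

```python
import math


def scaling_factors(emb_size=768, num_heads=8):
    """Return (head_dim, scale used in this post, scale of standard attention)."""
    head_dim = emb_size // num_heads
    scale_in_post = math.sqrt(emb_size)    # emb_size ** (1/2), as in the class above
    scale_standard = math.sqrt(head_dim)   # sqrt(d_k) with d_k the per-head dimension
    return head_dim, scale_in_post, scale_standard


head_dim, s_post, s_std = scaling_factors()
print(head_dim)                   # 96
print(round(s_post / s_std, 3))   # sqrt(num_heads) = 2.828
```

The two differ by a constant factor of sqrt(num_heads), so the version in this post produces a somewhat flatter softmax. Whether that matters in practice is not something this post measured.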


3. Feed Forward Block (FF)

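class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int = 768, expansion: int = 4, drop_p: float = 0.):
        # Linear -> GELU -> Dropout -> Linear, matching the rows of the summary output
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )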
FF = FeedForwardBlock()
summary(FF, embedding_x.shape[1:], device='cpu')


4. Transformer Block

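Here the MHA block and the FF block are combined into one transformer layer.
Because the features of the 3rd, 6th, 9th, and 12th transformer layers are needed by the decoder, the feature maps of those layers are collected and returned as a list.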

class PreNorm(nn.Module):
    """Apply LayerNorm before the wrapped module (pre-norm transformer layout)."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=768, num_heads=8, depth=12, dropout=0., extract_layers=[3,6,9,12]):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            # each layer is an (attention, feed-forward) pair
            self.layers.append(nn.ModuleList([
                PreNorm(embed_dim, MultiHeadAttention(embed_dim, num_heads, dropout)),
                PreNorm(embed_dim, FeedForwardBlock(embed_dim, expansion=4))
            ]))
        self.extract_layers = extract_layers

    def forward(self, x):
        extract_layers = []
        for cnt, (attn, ff) in enumerate(self.layers):
            # residual connections around attention and feed-forward
            x = attn(x) + x
            x = ff(x) + x
            # layers are counted from 1, so compare against cnt + 1
            if cnt+1 in self.extract_layers:
                extract_layers.append(x)
        return extract_layers
TB = TransformerBlock()
summary(TB, embedding_x.shape[1:], device='cpu')
        Layer (type)               Output Shape         Param #
         LayerNorm-1            [-1, 2744, 768]           1,536
            Linear-2           [-1, 2744, 2304]       1,771,776
           Dropout-3        [-1, 8, 2744, 2744]               0
            Linear-4            [-1, 2744, 768]         590,592

========================= (omitted) ==============================

MultiHeadAttention-137            [-1, 2744, 768]               0
         PreNorm-138            [-1, 2744, 768]               0
       LayerNorm-139            [-1, 2744, 768]           1,536
          Linear-140           [-1, 2744, 3072]       2,362,368
            GELU-141           [-1, 2744, 3072]               0
         Dropout-142           [-1, 2744, 3072]               0
          Linear-143            [-1, 2744, 768]       2,360,064
         PreNorm-144            [-1, 2744, 768]               0
Total params: 85,054,464
Trainable params: 85,054,464
Non-trainable params: 0
Input size (MB): 8.04
Forward/backward pass size (MB): 9759.42
Params size (MB): 324.46
Estimated Total Size (MB): 10091.92
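The parameter counts in this summary can be reproduced by hand, which is a useful check that the transformer really consists of 12 identical layers. The following is a minimal sketch in plain Python (the helper names are made up for this example). It counts weights and biases of the LayerNorm and Linear modules only, which are the only rows with parameters in the output above.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def layer_params(emb=768, expansion=4):
    """Parameters of one transformer layer: two PreNorm-wrapped sub-blocks."""
    layer_norm = 2 * emb                            # scale and shift
    qkv = linear_params(emb, 3 * emb)               # fused query/key/value projection
    proj = linear_params(emb, emb)                  # attention output projection
    ff1 = linear_params(emb, expansion * emb)       # feed-forward expansion
    ff2 = linear_params(expansion * emb, emb)       # feed-forward contraction
    return 2 * layer_norm + qkv + proj + ff1 + ff2


print(linear_params(768, 2304))   # 1771776, the Linear-2 row
print(layer_params())             # 7087872 per layer
print(12 * layer_params())        # 85054464, the "Total params" line
```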

We can confirm that feature maps come out of the 3rd, 6th, 9th, and 12th layers.

embedding_x_list = TB(embedding_x)
for i in embedding_x_list:
    print(i.shape)
torch.Size([1, 2744, 768])
torch.Size([1, 2744, 768])
torch.Size([1, 2744, 768])
torch.Size([1, 2744, 768])



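Now UNETR is built from the blocks declared above.
I added a light_r option, because with the settings taken straight from the paper the model would not even build due to memory problems (OOM).
So the conv3d channel counts had to be reduced by that factor. (MONAI's UNETR also seems to reduce the number of channels to make the model lighter.)
On top of that, a (224, 224, 224) image did not fit into VRAM either (my RTX 3060 is crying),
so I resized the input to (128, 128, 128).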
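A rough back-of-the-envelope calculation shows why the input size matters so much: the attention map grows with the square of the token count. The sketch below estimates the size of a single attention map only, assuming float32, batch size 1, and 8 heads. Real training memory is much larger because activations of every layer are kept for the backward pass. `attention_map_bytes` is a hypothetical helper.

```python
def attention_map_bytes(image_side, patch=16, heads=8, bytes_per_value=4):
    """Bytes of one (heads, tokens, tokens) float32 attention map for a cubic volume."""
    tokens = (image_side // patch) ** 3
    return heads * tokens * tokens * bytes_per_value


big = attention_map_bytes(224)     # 2744 tokens
small = attention_map_bytes(128)   # 512 tokens
print(big, small)                  # 240945152 8388608
print(small / 2**20)               # 8.0 MiB per attention map at 128^3
```

Going from 224 to 128 voxels per side cuts the token count from 2744 to 512, and the per-layer attention map shrinks from roughly 230 MiB to 8 MiB.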

class UNETR(nn.Module):
    def __init__(self, img_shape=(224, 224, 224), input_dim=3, output_dim=3, 
                 embed_dim=768, patch_size=16, num_heads=8, dropout=0.1, light_r=4):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embed_dim = embed_dim
        self.img_shape = img_shape
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.num_layers = 12
        self.ext_layers = [3, 6, 9, 12]

        # size of the patch grid along each axis, e.g. 128 / 16 = 8
        self.patch_dim = [int(x / patch_size) for x in img_shape]
        # channel widths of the conv decoder, shrunk by light_r to save memory
        self.conv_channels = [int(i/light_r) for i in [32, 64, 128, 256, 512, 1024]]

        self.embedding = Embeddings((input_dim,*img_shape), patch_size=patch_size, embed_dim=embed_dim)
        # Transformer Encoder
        self.transformer = \
            TransformerBlock(embed_dim=embed_dim, num_heads=num_heads, depth=self.num_layers,
                             dropout=dropout, extract_layers=self.ext_layers)

        # U-Net Decoder
        self.decoder0 = \
            nn.Sequential(
                Conv3DBlock(input_dim, self.conv_channels[0], 3),
                Conv3DBlock(self.conv_channels[0], self.conv_channels[1], 3)
            )

        self.decoder3 = \
            nn.Sequential(
                Deconv3DBlock(embed_dim, self.conv_channels[2]),
                Deconv3DBlock(self.conv_channels[2], self.conv_channels[2]),
                Deconv3DBlock(self.conv_channels[2], self.conv_channels[2])
            )

        self.decoder6 = \
            nn.Sequential(
                Deconv3DBlock(embed_dim, self.conv_channels[3]),
                Deconv3DBlock(self.conv_channels[3], self.conv_channels[3]),
            )

        self.decoder9 = \
            Deconv3DBlock(embed_dim, self.conv_channels[4])

        self.decoder12_upsampler = \
            SingleDeconv3DBlock(embed_dim, self.conv_channels[4])

        self.decoder9_upsampler = \
            nn.Sequential(
                Conv3DBlock(self.conv_channels[5], self.conv_channels[3]),
                Conv3DBlock(self.conv_channels[3], self.conv_channels[3]),
                Conv3DBlock(self.conv_channels[3], self.conv_channels[3]),
                SingleDeconv3DBlock(self.conv_channels[3], self.conv_channels[3])
            )

        self.decoder6_upsampler = \
            nn.Sequential(
                Conv3DBlock(self.conv_channels[4], self.conv_channels[2]),
                Conv3DBlock(self.conv_channels[2], self.conv_channels[2]),
                SingleDeconv3DBlock(self.conv_channels[2], self.conv_channels[2])
            )

        self.decoder3_upsampler = \
            nn.Sequential(
                Conv3DBlock(self.conv_channels[3], self.conv_channels[1]),
                Conv3DBlock(self.conv_channels[1], self.conv_channels[1]),
                SingleDeconv3DBlock(self.conv_channels[1], self.conv_channels[1])
            )

        self.decoder0_header = \
            nn.Sequential(
                Conv3DBlock(self.conv_channels[2], self.conv_channels[1]),
                Conv3DBlock(self.conv_channels[1], self.conv_channels[1]),
                SingleConv3DBlock(self.conv_channels[1], output_dim, 1)
            )

    def forward(self, x):
        z0 = x
        x = self.embedding(x)
        z = self.transformer(x)
        z3, z6, z9, z12 = z
        # tokens (batch, n_patches, embed_dim) -> volume (batch, embed_dim, h, w, d)
        z3 = z3.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
        z6 = z6.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
        z9 = z9.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
        z12 = z12.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)

        # upsample the deepest feature, then merge skip connections along the channel axis
        z12 = self.decoder12_upsampler(z12)
        z9 = self.decoder9(z9)
        z9 = self.decoder9_upsampler(torch.cat([z9, z12], dim=1))
        z6 = self.decoder6(z6)
        z6 = self.decoder6_upsampler(torch.cat([z6, z9], dim=1))
        z3 = self.decoder3(z3)
        z3 = self.decoder3_upsampler(torch.cat([z3, z6], dim=1))
        z0 = self.decoder0(z0)
        output = self.decoder0_header(torch.cat([z0, z3], dim=1))
        return output
from torchsummary import summary

# x = torch.rand(1,1,224,224,224)
x = torch.rand(1,1,128,128,128)
# x = torch.rand(1,1,64,64,64)
model = UNETR(img_shape=x.shape[2:], input_dim=x.shape[1], output_dim=4, 
              embed_dim=768, patch_size=16, num_heads=8, dropout=0., light_r=4)
summary(model, x.shape[1:], device='cpu')
        Layer (type)               Output Shape         Param #
            Conv3d-1         [-1, 768, 8, 8, 8]       3,146,496
           Dropout-2             [-1, 512, 768]               0
        Embeddings-3             [-1, 512, 768]               0
         LayerNorm-4             [-1, 512, 768]           1,536
            Linear-5            [-1, 512, 2304]       1,771,776
           Dropout-6          [-1, 8, 512, 512]               0
            Linear-7             [-1, 512, 768]         590,592
MultiHeadAttention-8             [-1, 512, 768]               0
           PreNorm-9             [-1, 512, 768]               0
        LayerNorm-10             [-1, 512, 768]           1,536
           Linear-11            [-1, 512, 3072]       2,362,368
             GELU-12            [-1, 512, 3072]               0

========================== (omitted) ============================

     BatchNorm3d-251    [-1, 16, 128, 128, 128]              32
            ReLU-252    [-1, 16, 128, 128, 128]               0
     Conv3DBlock-253    [-1, 16, 128, 128, 128]               0
          Conv3d-254     [-1, 4, 128, 128, 128]              68
SingleConv3DBlock-255     [-1, 4, 128, 128, 128]               0
Total params: 92,065,812
Trainable params: 92,065,812
Non-trainable params: 0
Input size (MB): 8.00
Forward/backward pass size (MB): 7376.00
Params size (MB): 351.20
Estimated Total Size (MB): 7735.20
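To make the decoder wiring easier to follow, the channel count and spatial size at each skip-connection concat can be traced without running the model. The sketch below is a framework-free illustration under the settings used here (128-voxel input, patch size 16, light_r=4). `decoder_trace` is a made-up helper, and it only mirrors the bookkeeping of the channel list and the number of 2x upsampling steps.

```python
def decoder_trace(img=128, patch=16, light_r=4):
    """Return (concat name, channels after concat, spatial side) for each merge."""
    c = [n // light_r for n in (32, 64, 128, 256, 512, 1024)]
    grid = img // patch  # side of the patch grid, e.g. 8
    trace = []
    # ups = number of 2x upsamplings the transformer features have gone through
    # by the time they reach this concat; ch = channels of each concatenated branch
    for name, ups, ch in [("z9+z12", 1, c[4]), ("z6+z9", 2, c[3]),
                          ("z3+z6", 3, c[2]), ("z0+z3", 4, c[1])]:
        trace.append((name, 2 * ch, grid * 2 ** ups))
    return trace


for row in decoder_trace():
    print(row)
# ('z9+z12', 256, 16)
# ('z6+z9', 128, 32)
# ('z3+z6', 64, 64)
# ('z0+z3', 32, 128)
```

Each concat doubles the channel count, which is why every upsampler's first Conv3DBlock takes the next-larger entry of conv_channels as its input width. The last merge happens at the full 128 x 128 x 128 resolution.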

The full code, from data preparation and preprocessing through training, is uploaded to GitHub!

Looking just at the results:
I used the MSD dataset; the task is to segment the spleen in abdominal CT.

Rows 1, 4, 7: abdominal CT (axial plane); rows 2, 5, 8: ground truth; rows 3, 6, 9: prediction


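I could confirm that the model does learn to a reasonable degree.
Please keep in mind that the implementation details differ from the paper!
I also wondered why the embedding dimension in UNETR is 768.
In ViT the patch size is 16, so each flattened patch has height (16) x width (16) x channels (3) = 768 values, which is where 768 comes from.
Naturally, I thought UNETR should follow the same logic,
but doing that in UNETR, which takes 3D images, makes the embedding dimension blow up:
(height (16) x width (16) x depth (16) x channels (3) = 12,288)
Since the embeddings mostly go through fully connected layers, the amount of computation would grow enormously.
So my guess is that this is why the same 768-dimensional embedding was kept.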
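The arithmetic behind this argument can be made concrete. The sketch below (plain Python, with hypothetical helper names) computes the flattened patch size in 2D and 3D and compares the weight count of a fused QKV projection, which maps d inputs to 3d outputs, at both embedding sizes.

```python
def flat_patch_dim(patch, channels, ndim):
    """Number of values in one flattened patch of a 2D image or 3D volume."""
    return patch ** ndim * channels


def qkv_weights(dim):
    """Weight count of a fused query/key/value projection (dim -> 3 * dim), biases excluded."""
    return dim * 3 * dim


vit_dim = flat_patch_dim(16, 3, ndim=2)     # 16 * 16 * 3
naive_3d = flat_patch_dim(16, 3, ndim=3)    # 16 * 16 * 16 * 3
print(vit_dim, naive_3d)                    # 768 12288
print(qkv_weights(naive_3d) // qkv_weights(vit_dim))  # 256
```

The embedding would be 16 times longer, and because the linear layers scale with the square of the embedding dimension, each of them would need 256 times as many weights.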
