Zhang, Y., Xiang, T., Hospedales, T.M., & Lu, H. (2017). Deep Mutual Learning. 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 4320-4328.
위의 논문에 대한 간단한 리뷰 및 코드 실습입니다
Knowledge Distillation Survey 논문 (Gou, J., Yu, B., Maybank, S.J., & Tao, D. (2020). Knowledge Distillation: A Survey. International Journal of Computer Vision, 129, 1789 - 1819.) 을 리뷰하던 중, DML 논문이 다른 KD 기법에 비해 Student 모델의 성능 향상이 제일 큰 것을 확인할 수 있었다.
그렇게 해서 보게된 DML (Deep Mutual Learning) 논문.
이 게시글에서는 Deep한 논문 리뷰 보다는 실습 위주로 선보일 것이다.
아키텍쳐는 위와 같다.
기존 Knowledge Distillation 과는 달리, Student 모델, Teacher 모델 구분하지 않으며, 신경망 2개를 동시에 학습하여 Kullback-Leibler Divergence (KL 발산) 값을 최소화하는 방향으로 서로 간의 예측 분포 (p1, p2) 을 학습한다.
DML의 장점이라고 하면 다음과 같다.
- 쌍방향 학습으로 일반화 성능의 향상
- 모델 구조가 달라도 학습 가능 (CNN <-> Transformer)
- 안정적인 일반화된 모델 학습 가능
그러나 단점도 존재한다.
- 동시에 학습하므로 리소스 과다 필요
코드 실습
공부도 할겸 코드로 뚝딱 해봤다. (ChatGPT는 마법의 요술 방망이가 맞다!..)
처음에 실습할 때에는 코드 작성 후 러프하게 리뷰했지만, 블로그에 글을 쓰면서 프로세스 플로우와 각 레이어의 원리, 코드 설명에 대해 자세하게 주석처리를 해놓았다.
모델선정
ResNet 모델과 MobileNet 모델로, CIFAR100 데이터셋을 활용하여 성능향상을 테스트 해보았다.
아래와 같이 ResNet 의 기본 구성요소인 Basic Block 을 정의했다.
# BasicBlock 클래스 정의 (ResNet의 기본 구성 요소)
class BasicBlock(nn.Module):
expansion = 1 # 확장 계수 (필터의 크기를 변경하는 데 사용)
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
# 첫 번째 3x3 컨볼루션 레이어: 입력 채널(in_planes) -> 출력 채널(planes), stride와 padding 설정
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes) # 첫 번째 Batch Normalization 레이어
# 두 번째 3x3 컨볼루션 레이어: 출력 채널 유지
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes) # 두 번째 Batch Normalization 레이어
# Shortcut 경로: 입력과 출력의 크기가 다르면 크기를 조정하는 레이어 추가
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes: # 크기 또는 채널 수가 다를 경우
self.shortcut = nn.Sequential(
# 1x1 컨볼루션: 입력 채널(in_planes) -> 출력 채널(planes * expansion), stride 적용
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes) # Batch Normalization
)
def forward(self, x):
# 입력 데이터에 대해 첫 번째 컨볼루션과 ReLU 활성화 함수 적용
out = F.relu(self.bn1(self.conv1(x)))
# 두 번째 컨볼루션 및 Batch Normalization 적용
out = self.bn2(self.conv2(out))
# Shortcut 경로를 통해 입력과 현재 출력 더하기
out += self.shortcut(x)
# 다시 ReLU 활성화 함수 적용
out = F.relu(out)
return out
그리고 아래와 같이 ResNet32 모델을 구성했다.
# ResNet32 모델 정의
class ResNet32(nn.Module):
def __init__(self, num_classes=100):
super(ResNet32, self).__init__()
self.in_planes = 16 # 입력 채널 크기 초기화
# 첫 번째 컨볼루션 레이어와 Batch Normalization
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) # 3채널(RGB) -> 16채널로 변환
self.bn1 = nn.BatchNorm2d(16)
# 세 개의 레이어 블록 생성 (각 블록은 여러 BasicBlock으로 구성)
# layer1: 채널 16 유지, 5개의 블록, stride=1
self.layer1 = self._make_layer(BasicBlock, 16, 5, stride=1)
# layer2: 채널 32로 증가, 5개의 블록, stride=2 (공간 크기 절반으로 감소)
self.layer2 = self._make_layer(BasicBlock, 32, 5, stride=2)
# layer3: 채널 64로 증가, 5개의 블록, stride=2
self.layer3 = self._make_layer(BasicBlock, 64, 5, stride=2)
# Adaptive Average Pooling (출력 크기를 1x1로 조정)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
# Fully Connected 레이어 (64채널 -> 클래스 수(num_classes) 출력)
self.fc = nn.Linear(64 * BasicBlock.expansion, num_classes)
# 블록 생성 함수: 특정 채널 크기와 블록 수로 BasicBlock을 연결
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1) # 첫 블록만 stride 적용
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride)) # BasicBlock 추가
self.in_planes = planes * block.expansion # 다음 블록의 입력 크기 업데이트
return nn.Sequential(*layers)
def forward(self, x):
# 첫 번째 컨볼루션과 BatchNorm
out = F.relu(self.bn1(self.conv1(x)))
# 세 개의 레이어 블록 통과
out = self.layer1(out) # 첫 번째 레이어 블록
out = self.layer2(out) # 두 번째 레이어 블록
out = self.layer3(out) # 세 번째 레이어 블록
# Adaptive Average Pooling 적용
out = self.avgpool(out)
# 1D 벡터로 변환
out = out.view(out.size(0), -1)
# Fully Connected 레이어로 출력
out = self.fc(out)
return out
pytorch hub에서 MobileNet 모델을 불러오고
Loss 와 optimizer, learning rate를 설정해준다
그리고 Mutual 학습 Loss를 다음과 같이 정의한다
def mutual_learning_loss(output1, output2):
kl_loss = nn.KLDivLoss(reduction='batchmean')
return kl_loss(nn.functional.log_softmax(output1, dim=1), nn.functional.softmax(output2, dim=1)) + \
kl_loss(nn.functional.log_softmax(output2, dim=1), nn.functional.softmax(output
각각 모델의 학습 epoch를 150번으로 설정 후 학습을 진행하였고, ResNet의 DML 전, 후 Accuracy와 Loss를 확인하였다.
결과는 다음과 같다
ResNet32 Independent: 52.73%
ResNet32 after DML: 56.77%
4% 가 상승했다!
물론 학습을 위한 시간이 더 걸리긴 했지만 유의미한 성능 향상이 있다는 것을 확인할 수 있었따.