모델을 학습시킨 후 저장하고, 다시 불러오는 중에 아래와 같은 문제가 발생했다.
torch.save(model.state_dict, '~~.pt')와 같은 방식으로 저장했고,
model.load_state_dict(torch.load('~~.pt'))으로 불러왔을 뿐인데 에러는 다음과 같았다.
자세히 보면, model의 state_dict의 키가 맞지 않다는 것을 알 수 있다. 저장한 모델은 'module.' 키가 앞에 붙어있는 반면, 불러올 모델은 키가 붙지 않아 매핑이 안되는 문제였다.
이러한 문제가 생긴 이유는, 다중 GPU를 사용할 때 발생하는 것으로 파악됐다.
다중 GPU를 사용하면서 torch.nn.DataParallel을 사용하면 모델의 state_dict는 model.module의 형태로 저장된다. 이 때 module 형태의 state를 그냥 model.load_staete_dict()로 넣는다면 위와 같은 에러를 얻게 된다.
찾아보니 해결책이 몇가지 있었다.
1. torch.save(model.module.state_dict(), '~~.pt')로 module을 빼고 저장하기
먼저 저장을 통일되게 저장하는 방법이 있다.
nn.DataParallel로 학습시킨 모델이 model.module의 형태로 되어있다면, model.module.state_dict() 로 저장하면 state_dict에 module 키가 빠지게 된다. 조건문을 사용해서 nn.DataParallel을 사용한다면 module.state_dict로 저장하고, 미사용시에는 state_dict를 사용해서 저장하면 된다.
## model save
if isinstance(model, nn.DataParallel): ## 다중 GPU를 사용한다면
torch.save(model.module.state_dict(), '~~.pt') ## model.module 형태로 module.을 제거하고 저장
else:
torch.save(model.state_dict(), '~~.pt') ## 일반저장
## model load
if isinstance(model, nn.DataParallel): ## 다중 GPU를 사용하면
model.module.load_state_dict(torch.load('~~.pt')) ## model.module에 state_dict를 불러옴
else:
model.load_state_dict(torch.load('~~.pt')) ## 일반적으로 불러오기
2. model.load()한 후 module key 제거하기
이 부분은 조금 의아한(?) 야매스러운 방법으로 보여지긴 했는데, 파이토치 포럼에 있을 정도니 사용해 보기로 했다.
저장할 때에는 GPU 병렬에 상관없이, torch.save(model.state_dict(), '~~.pt')로 저장한다.
그러나 불러올 때 module이 있으면 module key를 제거하는 방법이다.
if isinstance(model, nn.DataParallel): # GPU 병렬사용 적용
model.load_state_dict(torch.load('~~.pt'))
else: # GPU 병렬사용을 안할 경우
state_dict = torch.load(model_parms.model_dir)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.` ## module 키 제거
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
이 외에도 아래 참고링크에 추가적인 설명을 글쓴이가 해놓았는데, GPU개수를 맞춰준다던가, 환경을 동일하게 해주는 방법이 있는데, 언제 모델을 학습하는, 돌리는 환경이 바뀔지 모르기 때문에 그것보다는 위 두가지 방법이 제일 쉽고 현명할 듯 하다.
참고링크:
'Code > Python' 카테고리의 다른 글
[Python] HTTP web server log dataframe으로 불러오기 (with pandas) (0) | 2022.02.13 |
---|---|
[Python] 주피터 노트북 테마 변경하기 (3) | 2021.05.17 |
[Python] Data Frame apply 함수 병렬처리 하는 방법 (0) | 2021.04.15 |
[Python] Numpy Float(소수) 출력 표현 설정하기 (0) | 2021.03.23 |
[Python] python 에서 이유를 알 수 없는 GPU 에러 정리(device-side assert triggered) (0) | 2021.02.04 |