[Python]Pytorch - RuntimeError:Error(s) in loading state_dict ... : Missing key(s) in state_dict: ... Unexpected key(s) in state_dict:... GPU 병렬 사용 문제

2021. 4. 30. 10:41·Programming/Python
반응형

모델을 학습시킨 후 저장하고, 다시 불러오는 중에 아래와 같은 문제가 발생했다. 

 

torch.save(model.state_dict, '~~.pt')와 같은 방식으로 저장했고, 

model.load_state_dict(torch.load('~~.pt'))으로 불러왔을 뿐인데 에러는 다음과 같았다. 

자세히 보면, model의 state_dict의 키가 맞지 않다는 것을 알 수 있다. 저장한 모델은 'module.' 키가 앞에 붙어있는 반면, 불러올 모델은 키가 붙지 않아 매핑이 안되는 문제였다.

 

이러한 문제가 생긴 이유는, 다중 GPU를 사용할 때 발생하는 것으로 파악됐다. 

다중 GPU를 사용하면서 torch.nn.DataParallel을 사용하면 모델의 state_dict는 model.module의 형태로 저장된다. 이 때 module 형태의 state를 그냥 model.load_staete_dict()로 넣는다면 위와 같은 에러를 얻게 된다. 

찾아보니 해결책이 몇가지 있었다. 

 

1. torch.save(model.module.state_dict(), '~~.pt')로 module을 빼고 저장하기

먼저 저장을 통일되게 저장하는 방법이 있다.

nn.DataParallel로 학습시킨 모델이 model.module의 형태로 되어있다면, model.module.state_dict() 로 저장하면 state_dict에 module 키가 빠지게 된다. 조건문을 사용해서 nn.DataParallel을 사용한다면 module.state_dict로 저장하고, 미사용시에는 state_dict를 사용해서 저장하면 된다.

## model save
if isinstance(model, nn.DataParallel): ## 다중 GPU를 사용한다면
	torch.save(model.module.state_dict(), '~~.pt') ## model.module 형태로 module.을 제거하고 저장
else:
	torch.save(model.state_dict(), '~~.pt')  ## 일반저장
    
## model load
if isinstance(model, nn.DataParallel): ## 다중 GPU를 사용하면
	model.module.load_state_dict(torch.load('~~.pt')) ## model.module에 state_dict를 불러옴
else:
	model.load_state_dict(torch.load('~~.pt')) ## 일반적으로 불러오기

2. model.load()한 후 module key 제거하기

이 부분은 조금 의아한(?) 야매스러운 방법으로 보여지긴 했는데, 파이토치 포럼에 있을 정도니 사용해 보기로 했다. 

저장할 때에는 GPU 병렬에 상관없이, torch.save(model.state_dict(), '~~.pt')로 저장한다. 

그러나 불러올 때 module이 있으면 module key를 제거하는 방법이다. 

if isinstance(model, nn.DataParallel):  # GPU 병렬사용 적용
	model.load_state_dict(torch.load('~~.pt'))
else: # GPU 병렬사용을 안할 경우
	state_dict = torch.load(model_parms.model_dir) 
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
		name = k[7:] # remove `module.`  ## module 키 제거
		new_state_dict[name] = v
	model.load_state_dict(new_state_dict)

 

이 외에도 아래 참고링크에 추가적인 설명을 글쓴이가 해놓았는데, GPU개수를 맞춰준다던가, 환경을 동일하게 해주는 방법이 있는데, 언제 모델을 학습하는, 돌리는 환경이 바뀔지 모르기 때문에 그것보다는 위 두가지 방법이 제일 쉽고 현명할 듯 하다.

 

참고링크:

aigong.tistory.com/192

 

[Solution][Pytorch] RuntimeError: Error(s) in loading state_dict for ... : Missing key(s) in state_dict: ... Unexpected key(s

[Solution][Pytorch] RuntimeError: Error(s) in loading state_dict for ... :    Missing key(s) in state_dict: ... Unexpected key(s) in state_dict: ... 목차 분명 torch.save를 통해 model, optimizer..

aigong.tistory.com

discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/4

 

Missing keys & unexpected keys in state_dict when loading self trained model

Thanks for your suggestions! @ptrblck

discuss.pytorch.org

 

반응형

'Programming > Python' 카테고리의 다른 글

[Python] HTTP web server log dataframe으로 불러오기 (with pandas)  (0) 2022.02.13
[Python] 주피터 노트북 테마 변경하기  (3) 2021.05.17
[Python] Data Frame apply 함수 병렬처리 하는 방법  (0) 2021.04.15
[Python] Numpy Float(소수) 출력 표현 설정하기  (0) 2021.03.23
[Python] python 에서 이유를 알 수 없는 GPU 에러 정리(device-side assert triggered)  (0) 2021.02.04
'Programming/Python' 카테고리의 다른 글
  • [Python] HTTP web server log dataframe으로 불러오기 (with pandas)
  • [Python] 주피터 노트북 테마 변경하기
  • [Python] Data Frame apply 함수 병렬처리 하는 방법
  • [Python] Numpy Float(소수) 출력 표현 설정하기
자동화먹
자동화먹
많은 사람들에게 도움이 되는 생산적인 기록하기
    반응형
  • 자동화먹
    자동화먹의 생산적인 기록
    자동화먹
  • 전체
    오늘
    어제
    • 분류 전체보기 (144)
      • 생산성 & 자동화 툴 (30)
        • Notion (24)
        • Obsidian (0)
        • Make.com (1)
        • tips (5)
      • Programming (37)
        • Python (18)
        • Oracle (6)
        • Git (13)
      • AI Study (65)
        • DL_Basic (14)
        • ML_Basic (14)
        • NLP (21)
        • Marketing&Recommend (4)
        • chatGPT (0)
        • etc (12)
      • 주인장의 생각서랍 (10)
        • 생각정리 (4)
        • 독서기록 (6)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    Jupyter notebook
    git
    Transformer
    seq2seq
    노션첫걸음
    노션
    Google Cloud Platform
    gcp
    python기초
    파이토치
    nlp
    파이토치로 시작하는 딥러닝 기초
    딥러닝
    LSTM
    머신러닝
    cnn
    Python
    Github
    pytorch
    기초
    데이터베이스
    빅데이터분석
    데이터분석
    빅데이터
    ML
    dl
    git commit
    GPT
    notion
    자연어처리
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
자동화먹
[Python]Pytorch - RuntimeError:Error(s) in loading state_dict ... : Missing key(s) in state_dict: ... Unexpected key(s) in state_dict:... GPU 병렬 사용 문제
상단으로

티스토리툴바