[Python]Pytorch - RuntimeError:Error(s) in loading state_dict ... : Missing key(s) in state_dict: ... Unexpected key(s) in state_dict:... GPU 병렬 사용 문제
·
Programming/Python
모델을 학습시킨 후 저장하고, 다시 불러오는 중에 아래와 같은 문제가 발생했다. torch.save(model.state_dict, '~~.pt')와 같은 방식으로 저장했고, model.load_state_dict(torch.load('~~.pt'))으로 불러왔을 뿐인데 에러는 다음과 같았다. 자세히 보면, model의 state_dict의 키가 맞지 않다는 것을 알 수 있다. 저장한 모델은 'module.' 키가 앞에 붙어있는 반면, 불러올 모델은 키가 붙지 않아 매핑이 안되는 문제였다. 이러한 문제가 생긴 이유는, 다중 GPU를 사용할 때 발생하는 것으로 파악됐다. 다중 GPU를 사용하면서 torch.nn.DataParallel을 사용하면 모델의 state_dict는 model.module의 형태로 ..
[파이토치로 시작하는 딥러닝 기초]10.3 ImageFolder / 모델 저장 / 모델 불러오기
·
AI Study/DL_Basic
ImageFolder 나만의 데이터 셋 준비하기 ImageFolder란? 로컬에 저장된 이미지 데이터를 불러올 때 사용하는 pytorch 라이브러리 데이터를 준비할 때에는 명확하게 구분되는 사진을 사용해야 한다. 위와 같이 구분하는 label의 class 개수에 따라 folder를 생성하고 그 안에 해당 라벨에 맞는 이미지를 삽입한다. 데이터 불러오기 실습 import torchvision from torchvision import transforms from torch.utils.data import DataLoader from matplotlib.pyplot import imshow %matplotlib inline train_data = torchvision.datasets.ImageFolder(r..