[Python]Pytorch - RuntimeError:Error(s) in loading state_dict ... : Missing key(s) in state_dict: ... Unexpected key(s) in state_dict:... GPU 병렬 사용 문제
·
Programming/Python
모델을 학습시킨 후 저장하고, 다시 불러오는 중에 아래와 같은 문제가 발생했다. torch.save(model.state_dict, '~~.pt')와 같은 방식으로 저장했고, model.load_state_dict(torch.load('~~.pt'))으로 불러왔을 뿐인데 에러는 다음과 같았다. 자세히 보면, model의 state_dict의 키가 맞지 않다는 것을 알 수 있다. 저장한 모델은 'module.' 키가 앞에 붙어있는 반면, 불러올 모델은 키가 붙지 않아 매핑이 안되는 문제였다. 이러한 문제가 생긴 이유는, 다중 GPU를 사용할 때 발생하는 것으로 파악됐다. 다중 GPU를 사용하면서 torch.nn.DataParallel을 사용하면 모델의 state_dict는 model.module의 형태로 ..