평소에 파이토치에서 모델을 저장할 때 torch.save(모델명, 모델 경로)만 사용해서 pickle 파일로 저장을 했었다.
이번에 딥러닝을 배우면서 state_dict 함수란것을 배웠고, 그것이 뭔지 좀 더 자세히 기록하기 위해 이와 같은 포스팅을 하게 되었다.
모델을 저장하는데에는 두가지 방법이 있는 듯 하다.
- torch.save(model 명, 저장경로)
- 사실 torch.save()는 모델 뿐 아니라 모든 객체를 pickle 파일로 저장할 수 있는 함수이다. 확장자 명 또한 사용자가 지정 가능하다.
- 저장된 모델을 불러오는 데에는 torch.load(저장경로)를 사용하면 된다.
- torch.save( [model명].state_dict(), 저장경로)
- 모델의 매개변수들을 저장하는 방법
- 모델을 불러올 때에는 [새로운 모델 명].load_state_dict()를 사용해서 매개변수를 적용한다.
내가 포스팅하려는 것은 이 두번째 방식이다.
우선 state_dict가 무엇인지 부터 알아보도록 하자.
state_dict란?
state_dict란 torch.nn.Module에서 모델로 학습할 때 각 layer마다 텐서로 매핑되는 매개변수(예를 들어 가중치, 편향과 같은)를 python dictionary 타입으로 저장한 객체이다. 학습 가능한 매개변수를 갖는 계층만이 모델의 state_dict에 항목을 가진다.
torch.optim 또한 옵티마이저의 상태 뿐만 아니라 사요된 하이퍼 매개변수 정보다 포함된 state_dict를 갖는다.
한마디로 모델의 구조에 맞게 각 레이어마다의 매개변수를 tensor형태로 매핑해서 dictionary형태로 저장하는 것이다.
왜 그냥 torch.save()로 저장하는 것이 아니라 state_dict()로 저장하는지는 잘 모르겠지만... 조금 다르게 저장하고 불러오는 느낌이다.
아래 예시는 참고링크에 있는 코드와 설명을 참고하여 작성했다.
추론(inference)를 위해 모델 저장하기 & 불러오기
state_dict 저장하기 / 불러오기 (권장)
저장하기 :
torch.save(model.state_dict(), PATH)
불러오기 :
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
추론을 위해 모델을 저장할 때는 학습된 모델의 학습된 매개변수만 저장하면 됩니다. torch.save() 를 사용하여 모델의 state_dict 를 저장하는 것이 나중에 모델을 사용할 때 가장 유연하게 사용할 수 있는, 모델 저장 시 권장하는 방법입니다.
PyTorch에서는 모델을 저장할 때 .pt 또는 .pth 확장자를 사용하는 것이 일반적인 규칙입니다.
추론을 실행하기 전에는 반드시 model.eval() 을 호출하여 드롭아웃 및 배치 정규화를 평가 모드로 설정하여야 합니다. 이것을 하지 않으면 추론 결과가 일관성 없게 출력됩니다.
전체 모델 저장하기/불러오기
저장하기 :
torch.save(model, PATH)
불러오기 :
# 모델 클래스는 어딘가에 반드시 선언되어 있어야 합니다
model = torch.load(PATH)
model.eval()
이 저장하기/불러오기 과정은 가장 직관적인 문법을 사용하며 적은 양의 코드를 사용합니다. 이러한 방식으로 모델을 저장하는 것은 Python의 pickle 모듈을 사용하여 전체 모듈을 저장하게 됩니다. 하지만 pickle은 모델 그 자체를 저장하지 않기 때문에 직렬화된 데이터가 모델을 저장할 때 사용한 특정 클래스 및 디렉토리 경로(구조)에 얽매인다는 것이 이 방식의 단점입니다. 대신에 클래스가 위치한 파일의 경로를 저장해두고, 불러오는 시점에 사용합니다. 이러한 이유 때문에, 만들어둔 코드를 다른 프로젝트에서 사용하거나 리팩토링 후에 다양한 이유로 동작하지 않을 수 있습니다.
PyTorch에서는 모델을 저장할 때 .pt 또는 .pth 확장자를 사용하는 것이 일반적인 규칙입니다.
추론을 실행하기 전에는 반드시 model.eval() 을 호출하여 드롭아웃 및 배치 정규화를 평가 모드로 설정하여야 합니다. 이것을 하지 않으면 추론 결과가 일관성 없게 출력됩니다.
이렇게 다른 점을 이야기 해주는데, 현업에서 내가 사용할 때에는 torch.save()로 모델 자체를 저장하는 것 만으로도 큰 문제가 없었다. 오히려 그것을 더 많이 사용했다.
그 모델의 class가 있으면 다른 프로젝트에서 사용하더래도 큰 문제가 없을 것이라고 예상하는 바이다.
나중에 추가로 알게되는 사항이 있으면 재포스팅해야겠다.
참고 :
tutorials.pytorch.kr/beginner/saving_loading_models.html
'Code > Python' 카테고리의 다른 글
[Python] 파이썬 파일 크기 사이즈 구하기 - os.path.getsize() (0) | 2021.01.12 |
---|---|
[Python] 폴더 내 파일 리스트 가져오기 (os/glob) (0) | 2021.01.12 |
[Python] isinstance 함수 - 파이썬 자료형 확인하는 함수 (0) | 2021.01.06 |
[Python] enumerate 함수 (0) | 2020.12.31 |
[Python] Colab이란? Colab 구글 드라이브에서 사용하기 (0) | 2020.12.28 |