[Python] pytorch 모델 저장하기 - state_dict()

2021. 1. 6. 16:37·Programming/Python
반응형

평소에 파이토치에서 모델을 저장할 때 torch.save(모델명, 모델 경로)만 사용해서 pickle 파일로 저장을 했었다.

이번에 딥러닝을 배우면서 state_dict 함수란것을 배웠고, 그것이 뭔지 좀 더 자세히 기록하기 위해 이와 같은 포스팅을 하게 되었다. 

 

모델을 저장하는데에는 두가지 방법이 있는 듯 하다. 

 

  • torch.save(model 명, 저장경로)
    • 사실 torch.save()는 모델 뿐 아니라 모든 객체를 pickle 파일로 저장할 수 있는 함수이다. 확장자 명 또한 사용자가 지정 가능하다. 
    • 저장된 모델을 불러오는 데에는 torch.load(저장경로)를 사용하면 된다. 
  • torch.save( [model명].state_dict(), 저장경로)
    • 모델의 매개변수들을 저장하는 방법
    • 모델을 불러올 때에는 [새로운 모델 명].load_state_dict()를 사용해서 매개변수를 적용한다. 

내가 포스팅하려는 것은 이 두번째 방식이다. 

우선 state_dict가 무엇인지 부터 알아보도록 하자.

 

state_dict란?

state_dict란 torch.nn.Module에서 모델로 학습할 때 각 layer마다 텐서로 매핑되는 매개변수(예를 들어 가중치, 편향과 같은)를 python dictionary 타입으로 저장한 객체이다. 학습 가능한 매개변수를 갖는 계층만이 모델의 state_dict에 항목을 가진다. 

torch.optim 또한 옵티마이저의 상태 뿐만 아니라 사요된 하이퍼 매개변수 정보다 포함된 state_dict를 갖는다.

 

한마디로 모델의 구조에 맞게 각 레이어마다의 매개변수를 tensor형태로 매핑해서 dictionary형태로 저장하는 것이다.

왜 그냥 torch.save()로 저장하는 것이 아니라 state_dict()로 저장하는지는 잘 모르겠지만... 조금 다르게 저장하고 불러오는 느낌이다. 

 

아래 예시는 참고링크에 있는 코드와 설명을 참고하여 작성했다. 

추론(inference)를 위해 모델 저장하기 & 불러오기

state_dict 저장하기 / 불러오기 (권장)

저장하기 : 

torch.save(model.state_dict(), PATH)

불러오기 : 

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
추론을 위해 모델을 저장할 때는 학습된 모델의 학습된 매개변수만 저장하면 됩니다. torch.save() 를 사용하여 모델의 state_dict 를 저장하는 것이 나중에 모델을 사용할 때 가장 유연하게 사용할 수 있는, 모델 저장 시 권장하는 방법입니다.

PyTorch에서는 모델을 저장할 때 .pt 또는 .pth 확장자를 사용하는 것이 일반적인 규칙입니다.
추론을 실행하기 전에는 반드시 model.eval() 을 호출하여 드롭아웃 및 배치 정규화를 평가 모드로 설정하여야 합니다. 이것을 하지 않으면 추론 결과가 일관성 없게 출력됩니다.

 

전체 모델 저장하기/불러오기

저장하기 :

torch.save(model, PATH)

불러오기 : 

# 모델 클래스는 어딘가에 반드시 선언되어 있어야 합니다
model = torch.load(PATH)
model.eval()
이 저장하기/불러오기 과정은 가장 직관적인 문법을 사용하며 적은 양의 코드를 사용합니다. 이러한 방식으로 모델을 저장하는 것은 Python의 pickle 모듈을 사용하여 전체 모듈을 저장하게 됩니다. 하지만 pickle은 모델 그 자체를 저장하지 않기 때문에 직렬화된 데이터가 모델을 저장할 때 사용한 특정 클래스 및 디렉토리 경로(구조)에 얽매인다는 것이 이 방식의 단점입니다. 대신에 클래스가 위치한 파일의 경로를 저장해두고, 불러오는 시점에 사용합니다. 이러한 이유 때문에, 만들어둔 코드를 다른 프로젝트에서 사용하거나 리팩토링 후에 다양한 이유로 동작하지 않을 수 있습니다.

PyTorch에서는 모델을 저장할 때 .pt 또는 .pth 확장자를 사용하는 것이 일반적인 규칙입니다.

추론을 실행하기 전에는 반드시 model.eval() 을 호출하여 드롭아웃 및 배치 정규화를 평가 모드로 설정하여야 합니다. 이것을 하지 않으면 추론 결과가 일관성 없게 출력됩니다.

 

이렇게 다른 점을 이야기 해주는데, 현업에서 내가 사용할 때에는 torch.save()로 모델 자체를 저장하는 것 만으로도 큰 문제가 없었다. 오히려 그것을 더 많이 사용했다. 

그 모델의 class가 있으면 다른 프로젝트에서 사용하더래도 큰 문제가 없을 것이라고 예상하는 바이다. 

나중에 추가로 알게되는 사항이 있으면 재포스팅해야겠다. 

 

 

참고 : 

tutorials.pytorch.kr/beginner/saving_loading_models.html

 

모델 저장하기 & 불러오기 — PyTorch Tutorials 1.6.0 documentation

Note Click here to download the full example code 모델 저장하기 & 불러오기 Author: Matthew Inkawhich번역: 박정환 이 문서에서는 PyTorch 모델을 저장하고 불러오는 다양한 방법을 제공합니다. 이 문서 전체를 다

tutorials.pytorch.kr

 

반응형

'Programming > Python' 카테고리의 다른 글

[Python] 파이썬 파일 크기 사이즈 구하기 - os.path.getsize()  (0) 2021.01.12
[Python] 폴더 내 파일 리스트 가져오기 (os/glob)  (0) 2021.01.12
[Python] isinstance 함수 - 파이썬 자료형 확인하는 함수  (0) 2021.01.06
[Python] enumerate 함수  (0) 2020.12.31
[Python] Colab이란? Colab 구글 드라이브에서 사용하기  (0) 2020.12.28
'Programming/Python' 카테고리의 다른 글
  • [Python] 파이썬 파일 크기 사이즈 구하기 - os.path.getsize()
  • [Python] 폴더 내 파일 리스트 가져오기 (os/glob)
  • [Python] isinstance 함수 - 파이썬 자료형 확인하는 함수
  • [Python] enumerate 함수
자동화먹
자동화먹
많은 사람들에게 도움이 되는 생산적인 기록하기
    반응형
  • 자동화먹
    자동화먹의 생산적인 기록
    자동화먹
  • 전체
    오늘
    어제
    • 분류 전체보기 (144)
      • 생산성 & 자동화 툴 (30)
        • Notion (24)
        • Obsidian (0)
        • Make.com (1)
        • tips (5)
      • Programming (37)
        • Python (18)
        • Oracle (6)
        • Git (13)
      • AI Study (65)
        • DL_Basic (14)
        • ML_Basic (14)
        • NLP (21)
        • Marketing&Recommend (4)
        • chatGPT (0)
        • etc (12)
      • 주인장의 생각서랍 (10)
        • 생각정리 (4)
        • 독서기록 (6)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    머신러닝
    dl
    노션첫걸음
    빅데이터분석
    Transformer
    notion
    git
    데이터분석
    빅데이터
    Google Cloud Platform
    git commit
    LSTM
    python기초
    pytorch
    자연어처리
    cnn
    Github
    GPT
    노션
    nlp
    파이토치로 시작하는 딥러닝 기초
    딥러닝
    gcp
    ML
    파이토치
    데이터베이스
    seq2seq
    Python
    기초
    Jupyter notebook
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
자동화먹
[Python] pytorch 모델 저장하기 - state_dict()
상단으로

티스토리툴바