DeepSpeed 논문, ZeRO: Memory Optimization Towards Training A Trillion Parameter Models 리뷰

2023. 11. 29. 16:22·LLM/LLM

ZeRO: Memory Optimization Towards Training A Trillion Parameter Models 리뷰

https://arxiv.org/abs/1910.02054

Abstract

  • 현재 큰 모델을 학습시키는 방법은 매우 제한되어 있다. 메모리가 낭비되거나 연산이 늦어지는 등의 문제점이 존재한다.
    • Data Parallelism은 메모리가 매우 redundant하다.
    • Model Prallelism은 communication 비용이 매우 높아 연산 효율이 안좋다.
  • We develop a novel solution, Zero Redundancy Optimizer (ZeRO), to optimize memory, achieving both memory efficiency and scaling efficiency.

1 Extended Introduction

  • Model Parallelism을 통해 큰 모델을 학습시키는 것은 굉장히 힘든데, 이렇게 가정해보자
    • 1 Trillion Parameter를 가지는 모델을 학습시키면 한 노드에 20B씩 학습이 가능할 때 50노드가 필요하고 DGX-2 노드는 16GPU이니까 800-way parallelism..이 된다.
  • 효율적으로 학습시키려면 어떻게 되었든 Memory Redundancy를 잡아야한다.
  • 메모리는 대부분 아래와 같은 요인으로 인해 낭비된다.
    • optimizer states (Adam Optimizer와 같은 경우에 momentum과 variance)
    • gradients
    • parameters
    • 이 요인들을 OGP라 통칭
  • 그래서 ZeRO는 위 세개를 전부 다 나눠버렸다.
  • Optimization Stage를 세개로 가져감
    • Partitioning Optimizer States
    • Partitioning Graidents
    • Partitioning Parameters
  • ZeRO에서 Optimizer States만 최적화한 것을 ZeRO-OS라고 부른다.
  • 결과적으로
    • ZeRO-OS에서 학습하는 모델은 6B 정도는 V100에서 학습가능하게 만들었다. (기존에는 1.5B 정도가 한계)
    • Model Parallelism과 같이 100B정도까지 학습가능해진다. MegaTron은 20B정도 가능하다.
    • GPT-like 모델에 대해서 1.5B ~ 100B까지 6x 정도 throughput 향상을 가져왔다.

2 Background

진짜 Model Parallelism, Data Parallelism 설명이라 건너뜀

3 Where did all the memory go?

  • 1.5B정도의 GPT-2 모델을 학습시키면 16-bit training때 3GB정도 weight만 저장한다. 근데 왜 32GB 메모리인 V100에서 학습하기가 어려울까?
  • 대부분의 메모리를 사용하는 것
    • Activations
    • OGP States
    • Temporary Buffers
  • 뒤의 둘을 Optimize한다

3.1 Optimizer States, Gradients and Parameters

  • Mixed Precision Training
    • Parameter, Activation은 FP16으로 저장되고 high throughput을 보여준다.
    • 하지만 backward propagation을 제대로 계산하기 위해서 fp32버전의 parameter와 optimizer states도 들고 있어야 한다.
  • ADAM의 예시
    • Moel Parameter 개수: ψ
    • FP16 param: 2ψ bytes, FP16 Gradients: 2ψ bytes
    • FP32 Copy, param: 4ψ bytes, Momentum: 4ψ bytes, Variance: 4ψ bytes
    • 총 16ψ16 bytes.
    • GPT-2 (1.5B) 모델의 경우에 24GB의 메모리가 “최소한” 필요함

3.2 Temporary Buffers

  • Gradient All Reduce, Graident Norm등에서 buffer가 필요함
  • 전부 flatten되어서 주고 받아야하므로 4ψ bytes가 필요.
  • GPT-2의 경우 6GB의 메모리가 필요함

4 ZeRO: Insights and Overview

  • Efficiency는 아래 세개의 key insight에서 온다.
    • Data parallelism은 scaling efficiency가 더 좋다. 그 이유는 model parallelism은 computing을 복잡하게 만들면서 communication overhead를 늘리기 때문
    • Data parallelism은 model states를 전부 다 저장하기 때문에 memory inefficient하다. 그래도 Model Parallelism은 Memory Efficient하다.
    • Model Parallelism과 Data parallelism은 Model States를 Training time동안 전부 저장한다. 하지만 계속해서 매 시간마다 필요한 것은 아니다.
  • ZeRO는 그래서 OGP States를 replicating하는 대신 partition한다.

5 ZeRO: Memory Optimization

5.1 Pos: Optimizer State Partitioning

  • Nd의 Data parallelism degree라 할 때 optimizer states를 Nd로 똑같이 나눈다. 그리고 i_th data parallel process는 optimizer states를 해당 번호만 바꾼다. 그래서 optimizer states를 1Nd만 들고 있으면 된다.
  • all-gather를 하게 되면 전체 optimizer states가 나온다
  • 큰 Nd에 대해서는 4ψ에 근사하는 memory reduction을 보여준다.

여기서 든 의문은 data parallelism degree라고 하는 것은 결국 data parallel process별로 다른 배치를 들고 있을텐데 그럼 optimizer states가 서로 달라지지 않나?라는 것이다. 서로 다른 데이터에 대해 다른 gradient가 잡히지 않을까?

5.2 P_g: Gradient Partitioning

  • Optimizer States를 나누어놓았으니 Gradients도 나누어 놓는 것이 좋다. 
  • 그래도 Backward는 똑같이 해야하니 Reduce Scatter를 bucketization strategy와 함께 사용한다.
    • 각 프로세스= 1 bucket

5.3 P_p: Parameter Partitioning

  • Partition 밖의 Parameter는 forward, backward를 위해 필요하긴하다
  • 그래도 그런 것들을 줄이기 위해 broadcast를 통해 적절한 data parallel process로부터 받아서 계산한다
    • 이거보면 data parallel이어도 그냥 다 같이 계산하나?

5.4 C_B: Constant Buffer Size

  • (기존방식) 모델 파라미터값을 전송할때 어떤 큰 단위로 전달
  • (Constant Buffer Size 방식) 전달하는 단위를 특정 상수값으로 설정해놓음.  근데 크기가 그 이상이 되면 어떻게 처리하지?
  • 결국 Nd가 계속해서 커지면 8x까지 줄어든다

6 ZeRO: Communication Overhead

  • All Reduce = Reduce Scatter + All Gather이므로 2ψ만큼 데이터가 움직인다
  • Communication Overhead of Pos+g: Scatter Reduce (for Parameter update) + All Gather이므로 2ψ 만큼 데이터가 움직인다
  • Communication Overhead of Pos+g+p: forward propagation 때 parameter를 all gather로 주고받고 쓴 다음에는 버린다. 그 다음 backward 때는 역방향 1ψ-> 이거 잘 이해안감
  • 총 3ψ의 Communication Overhead. 기존과 비교하면 1.5x만의 overhead

7 ZeRO & Model Parallelism

  • ZeRO 쓰면 Model Parallelism은 조금 덜 필요함
  • 그래도 도움이 될 수 있을지도 모른다. 하지만 너무 힘든 작업

 

'LLM > LLM' 카테고리의 다른 글

[RAG] 논문요약 Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks  (1) 2024.02.16
Llama2 모델 파인튜닝 fine tuning_autotrain  (0) 2023.12.21
Alpaca는 LLama모델로 만들어졌다는거 알아요?  (1) 2023.10.18
'LLM/LLM' 카테고리의 다른 글
  • [RAG] 논문요약 Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
  • Llama2 모델 파인튜닝 fine tuning_autotrain
  • Alpaca는 LLama모델로 만들어졌다는거 알아요?
버터젤리
버터젤리
  • 버터젤리
    across the universe
    버터젤리
  • 전체
    오늘
    어제
    • 분류 전체보기 (128) N
      • 데이터 엔지니어 (0)
        • MLOPs (0)
      • 인프라 및 클라우드 (0)
        • Docker (0)
        • Kubernetes (0)
      • Development(개발) (2) N
        • Django (0)
        • 개발 Core (2) N
      • LLM (4)
        • 강화학습 (0)
        • LLM (4)
        • Generator (0)
      • PM (7)
        • IT Trends (0)
        • 세미나 후기 (7)
      • Deep learning (30)
        • 기초이론 (8)
        • 컴퓨터비전 (6)
        • 자연어처리 (5)
        • Anomaly Detection (6)
      • Machine learning (25)
      • Computer Science (26)
        • Linux (21)
        • 네트워크 (1)
        • 하드웨어 (4)
        • 운영체제(OS) (0)
      • 프로그래밍 언어 (17)
        • Python (8)
        • Pytorch (8)
        • Tensorflow (0)
      • Tools (14)
        • 주피터노트북 (7)
        • 깃(Git) (2)
        • 파이참 (5)
      • Book (2)
      • LIFE (0)
      • 창고 (0)
        • AI 인턴 (0)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    책임
    객사오
    batch normalization
    BN
    역할과책임
    jupyer notebook
    #git#github
    역할
    객체지향의사실과오해
    데코레이터
    @
    객체지향
    리눅스#파일이동#특정이름#포함
    nohup
    BatchNormalization
    백그라운드
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.1
버터젤리
DeepSpeed 논문, ZeRO: Memory Optimization Towards Training A Trillion Parameter Models 리뷰
상단으로

티스토리툴바