MEGABYTE Predicting Million-byte Sequences with Multiscale Transformers

개인의견 : 패치 기반의 병렬 연산과 Cross-Attention을 통한 글로벌-로컬 상호작용, 그리고 토크나이저 free? 한 신기한 논문으로 보인다.

논문 : https://arxiv.org/abs/2305.07185

아래글은 Scholar GPT와 대화를 나눈 내용입니다.

Q : 논문의 핵심 내용과 강점, 알고리즘 설명, 그리고 한계점

Answer

논문 요약: MEGABYTE의 강점과 독창성

이 논문은 MEGABYTE라는 모델을 소개하며, 기존의 Transformer 모델이 긴 시퀀스(수백만 바이트 단위)의 데이터를 효과적으로 다루기 어려운 한계를 해결하고자 합니다. MEGABYTE는 시퀀스를 패치(patch) 단위로 분할한 뒤, 글로벌(Global) 및 로컬(Local) 모델을 활용해 더 큰 시퀀스에서도 효율적으로 학습할 수 있도록 구성되었습니다. 이 구조는 메모리 효율성 및 학습 속도를 크게 개선합니다.

강점

  1. 효율적인 메모리 및 계산 비용: MEGABYTE는 자가 주의(attention)의 비용을 부분적으로 줄이고, 패치 단위로 계산하여 대규모 시퀀스의 경우에서도 자원 소모를 최적화합니다.
  2. 병렬 처리 가능: 글로벌 모델을 통해 주요 정보를 캡처하고, 이를 로컬 모델에서 반복 사용함으로써 디코딩 단계에서 병렬 처리가 가능해 속도가 빨라집니다.
  3. 토크나이제이션이 필요 없는 구조: 일반적으로 긴 텍스트 시퀀스를 처리할 때 토크나이제이션을 사용하지만, MEGABYTE는 원시 바이트 수준의 시퀀스를 직접 모델링하여 다국어 및 여러 도메인에 쉽게 적용될 수 있습니다.

독창적인 지점

  • 패치 기반 모델링: 시퀀스를 정해진 크기의 패치로 나누어 글로벌 모델이 패치 간의 연관성을 학습하고, 로컬 모델이 패치 내부의 세부 정보를 학습하게 설계한 것이 독창적입니다.
  • Cross-patch Attention과 Strided Inference: 패치 간의 연관성을 활용해 학습 효율성을 높이는 Cross-patch Attention과 예측 정확성을 보완하는 Strided Inference 방법을 통해 긴 시퀀스 처리의 효율성과 정확도를 모두 강화하였습니다.

핵심 알고리즘의 예시 설명

  1. 패치 임베더 (Patch Embedder):
    • 바이트 시퀀스를 받아 지정된 크기 (P)로 패치하여 각 바이트를 임베딩합니다.
    • 예를 들어, 시퀀스가 “mebyte”이고 패치 크기 (P)가 2라고 할 때, “me”, “by”, “te”로 패치됩니다.
  2. 글로벌 모델:
    • 각 패치를 하나의 토큰으로 간주하고, 이전 패치들과의 관계를 학습하는 자가 주의 메커니즘을 사용해 문맥적 패치 표현을 생성합니다.
    • “me”, “by”, “te” 패치를 가지고 각 패치의 문맥적 정보를 모델링하여 최종적으로 전체 시퀀스에 대한 글로벌 정보를 캡처합니다.
  3. 로컬 모델:
    • 글로벌 모델의 출력을 바탕으로 각 패치 내부의 바이트 간 연관성을 학습합니다.
    • 예를 들어 “me” 패치에서 ‘m’ 다음에 ‘e’가 나올 확률을 예측하는 방식으로 작동합니다.
  4. 출력 예측:
    • 최종적으로 로컬 모델은 각 바이트에 대한 확률을 계산하여 다음 바이트를 예측합니다. 예를 들어 “m”이 주어졌을 때 ‘e’가 나올 확률을 계산하는 방식입니다.

논문의 한계점

  1. 작은 패치에서의 정보 손실 가능성: 패치 단위로 시퀀스를 나누는 방식은 특정 패치의 정보가 충분히 전달되지 않거나 학습되지 않을 수 있습니다.
  2. 상대적으로 작은 스케일의 실험: 이 논문에서는 MEGABYTE의 가능성을 증명했지만, 더 큰 데이터셋과 모델 크기로의 확장이 이루어지지 않아 이 방식이 실제 대규모 환경에서도 동일한 성능을 보장하는지 검증이 필요합니다.
  3. 복잡성: MEGABYTE는 두 개의 모델(글로벌 및 로컬 모델)을 결합한 복잡한 구조로 인해 구현 및 조정이 까다로울 수 있으며, 하이퍼파라미터 조정에 상당한 리소스가 필요할 수 있습니다.

MEGABYTE는 매우 긴 시퀀스 모델링에서 높은 성능과 효율성을 보여주지만, 몇 가지 보완점을 통해 더욱 발전할 가능성이 있습니다.

Q : 패치를 어떻게 분할하는지 자세하게 설명해 줄 수 있니? 그리고 토크나이제이션이 없다는게 진짜로 low byte를 입력으로 넣는다는거지? 그리고 병렬처리를 하는 방법에 대해서도 자세하게 설명 부탁해

answer:

1. 패치 분할 방법

MEGABYTE는 긴 시퀀스를 고정된 크기의 패치로 분할하여 각 패치를 개별 단위로 학습합니다. 이 방식은 특히 긴 시퀀스를 처리할 때 효율성을 높입니다. 구체적인 분할 과정은 다음과 같습니다.

  1. 입력 시퀀스를 바이트 수준으로 나눔: 예를 들어 텍스트 “megabyte”가 있다면, 이를 문자 하나하나가 아닌 바이트 하나하나로 나누어 처리합니다.
  2. 고정된 패치 크기 지정: 사용자가 설정한 ( P ) 값에 따라 바이트 시퀀스를 고정된 크기의 블록(패치)으로 나눕니다. 예를 들어, 패치 크기 ( P )가 4라면, “megabyte” 시퀀스는 다음과 같은 패치로 나뉩니다: [mega], [byte].
  3. 각 패치 임베딩: 각 패치는 그 자체로 독립적인 정보 단위로 처리되며, 각 바이트가 임베딩 벡터로 변환되어 이 벡터들이 모여 패치 임베딩이 생성됩니다. 이를 통해 패치는 고차원 임베딩 공간에서 독립적인 벡터로 표현됩니다.
  4. 글로벌 및 로컬 모델로 전달: 글로벌 모델은 각 패치를 전체 시퀀스 맥락에서 이해하고, 로컬 모델은 패치 내부의 바이트 간 상호작용을 학습하는 방식으로 구조화됩니다.

2. 토크나이제이션 없이 Low Byte 입력

MEGABYTE 모델은 토크나이제이션을 사용하지 않고 바이트 수준의 원시 데이터를 입력으로 받는 방식을 채택하고 있습니다. 이는 다음과 같은 의미를 가집니다.

  • 바이트 수준 입력: 토크나이제이션을 통해 큰 단어를 토큰으로 변환하는 기존 접근과 달리, MEGABYTE는 원시 바이트 데이터 (예: UTF-8에서 각 문자에 해당하는 1바이트 혹은 멀티바이트) 자체를 입력으로 받아 모델링합니다. 따라서 “m”, “e”, “g”, “a”, “b”, “y”, “t”, “e” 각 문자를 각각의 바이트로서 그대로 처리합니다.
  • 언어 독립성: 일반적인 NLP에서는 언어별 토크나이저가 필요하지만, MEGABYTE는 언어 및 도메인에 관계없이 원시 바이트를 다루므로 다양한 종류의 데이터를 일관되게 처리할 수 있습니다.
  • 정보 손실 감소: 토크나이제이션 과정에서 생길 수 있는 정보 손실을 방지하여 원본 데이터를 그대로 모델에 반영할 수 있습니다. 특히 다국어 데이터나 텍스트가 아닌 이미지, 오디오 같은 다양한 도메인에 유리한 방식입니다.

3. 병렬 처리 방법

MEGABYTE는 패치를 활용하여 병렬 처리를 효과적으로 수행할 수 있는 구조를 가지고 있습니다. 모델의 병렬 처리 방식은 다음과 같습니다.

  1. 글로벌 모델의 병렬성:
    • MEGABYTE의 글로벌 모델은 각 패치가 전체 맥락에서 어떻게 연결되는지를 학습합니다. 이는 각 패치 간 관계를 학습하는 역할을 하며, 각 패치의 임베딩은 독립적으로 계산될 수 있어 병렬 처리가 가능합니다.
    • 예를 들어, 시퀀스가 [patch1, patch2, patch3]로 분할되어 있을 때, 글로벌 모델은 각 패치의 상관 관계를 독립적으로 학습하여 동시에 여러 패치를 처리할 수 있습니다.
  2. 로컬 모델의 패치 내 병렬성:
    • 로컬 모델은 글로벌 모델의 출력을 사용해 개별 패치 내부의 바이트 간 상호작용을 학습합니다. 각 패치 내에서 바이트 단위의 예측을 수행하므로, 패치 내에서 병렬 처리가 가능합니다.
    • 패치 크기가 ( P )로 설정되어 있다면, MEGABYTE는 각 패치의 바이트들에 대해 동시에 계산을 진행할 수 있어 계산 속도를 높입니다.
  3. 병렬 디코딩:
    • 일반적인 Transformer 모델에서는 디코딩 단계에서 매 단계마다 이전 출력에 의존하여 순차적으로 계산해야 하지만, MEGABYTE는 패치 단위로 디코딩할 수 있습니다. 글로벌 모델이 먼저 전체 시퀀스의 맥락을 제공하므로 로컬 모델은 패치 내에서 병렬로 디코딩할 수 있습니다. 예를 들어, 패치1의 계산이 완료되면 곧바로 패치2의 계산이 시작될 수 있습니다.
    • MEGABYTE는 전체 시퀀스에 대해 글로벌 모델이 1회 계산되고 로컬 모델이 각 패치별로 병렬로 작동하므로, 긴 시퀀스에서 디코딩 속도가 크게 향상됩니다.

MEGABYTE의 구조는 긴 시퀀스를 패치로 나누고, 글로벌-로컬 모델을 병렬로 작동시킴으로써 기존 Transformer의 순차적 계산 한계를 극복하고자 한 점에서 독창적입니다.

Q : 병렬 디코딩에 대해서 자세한 이해를 하고싶은데 이 논문에서 decoding을 N step에서 N+1 step이 되면서 어떤 연산이 추가적으로 필요하고 어떤 데이터를 메모리에서 들고와야하는지를 자세하게 설명해 주겠니?

Answer :

MEGABYTE 논문에서 제안된 병렬 디코딩 방식은 글로벌(Global) 모델과 로컬(Local) 모델의 협력적 작동을 통해 기존 Transformer의 순차적 디코딩 방식과 달리 패치 단위 병렬 디코딩을 가능하게 합니다. 이 과정에서 각 단계(( N ) 단계에서 ( N+1 ) 단계로 넘어가는 것)마다 필요한 연산과 데이터 흐름을 보다 구체적으로 설명하겠습니다.

MEGABYTE 병렬 디코딩 과정의 핵심 구조

MEGABYTE는 시퀀스를 여러 개의 패치(Patch)로 나눈 뒤, 글로벌 모델과 로컬 모델이 협력하여 디코딩을 수행합니다. 글로벌 모델은 패치 간의 관계를 파악하고, 로컬 모델은 각 패치 내부의 바이트 간 상관성을 학습합니다. 글로벌 모델은 전체 맥락을 한 번에 계산하고 이를 저장하여 로컬 모델이 필요한 패치 정보를 병렬로 참조할 수 있게 합니다.

1. ( N ) 단계에서 ( N+1 ) 단계로 넘어가는 과정

디코딩 과정에서 시퀀스를 생성할 때, MEGABYTE는 특정 패치의 정보를 활용해 다음 단계로 진행하는데, 이때 글로벌 모델의 출력과 로컬 모델의 상태를 함께 사용합니다. 아래는 구체적인 과정입니다.

(1) 글로벌 모델의 연산 및 메모리 상 데이터 유지

  • 글로벌 모델은 모든 패치의 맥락을 한번에 처리하여 패치별 임베딩 표현을 생성합니다. 이 과정은 시퀀스의 처음 단계에 한 번 수행되며, 이후 반복적인 글로벌 모델 연산은 필요하지 않습니다.
  • 글로벌 모델이 생성한 각 패치에 대한 맥락적 임베딩 표현 ( h_{\text{global-out}} )은 메모리에 저장되어 로컬 모델이 이후 단계에서 참조할 수 있습니다.
  • 예를 들어, 패치가 ( K )개라면 ( h_{\text{global-out}} = { h_{\text{global-out}}^1, h_{\text{global-out}}^2, …, h_{\text{global-out}}^K } ) 형태로 모든 패치에 대한 글로벌 표현이 저장됩니다.

(2) 로컬 모델의 연산 및 메모리 상 데이터 흐름

로컬 모델은 각 패치 내부의 바이트를 예측하는 역할을 하며, 글로벌 모델의 출력을 참조하여 각 패치 내 바이트를 병렬로 예측합니다. 구체적으로는:

  • ( N ) 단계에서 패치 ( k )의 바이트 ( x_t^{(k)} )를 생성하기 위해 로컬 모델의 입력이 필요합니다. 이 입력은 글로벌 모델의 출력인 ( h_{\text{global-out}}^k )와 해당 패치의 이전 바이트에 대한 임베딩을 결합하여 얻어집니다.
  • ( N+1 ) 단계에서 패치 ( k )의 다음 바이트 ( x_{t+1}^{(k)} )를 생성할 때도 동일하게 글로벌 모델의 ( h_{\text{global-out}}^k )을 참조하되, 이 값은 이전 단계와 동일하므로 메모리에서 다시 불러오기만 하면 됩니다.
  • 로컬 모델은 한 번에 전체 패치 내 바이트를 병렬로 생성하므로, 각 패치의 내부 바이트들에 대한 연산이 병렬로 진행될 수 있습니다.

(3) 각 단계에서의 추가 연산

  • ( N ) 단계에서 ( N+1 ) 단계로 넘어갈 때, 로컬 모델은 이전에 계산된 글로벌 모델의 출력을 반복적으로 참조하기 때문에 추가적인 글로벌 모델 계산은 필요하지 않습니다.
  • 로컬 모델의 경우 각 패치에 대한 내부 바이트 간 관계를 학습하는데, 이를 위해 로컬 모델의 내부 상태를 갱신하여 다음 바이트에 대한 예측에 활용합니다. 이때 필요한 추가 연산은 로컬 모델의 마스킹된 자기 주의 연산으로, 이는 특정 패치 내에서 이전 바이트들에만 접근할 수 있도록 제약하는 역할을 합니다.

2. 메모리에서 참조해야 하는 데이터

디코딩 과정에서 메모리에서 참조해야 하는 주요 데이터는 다음과 같습니다.

  • 글로벌 모델의 출력 ( h_{\text{global-out}} ): 각 패치에 대한 맥락적 정보가 저장되어 있어, 로컬 모델이 다음 바이트를 생성할 때마다 이를 참조합니다. 이는 고정된 값이므로 초기 단계에서 한 번만 계산되고 이후엔 불러오기만 합니다.
  • 로컬 모델의 이전 출력 및 상태: 로컬 모델은 패치 내에서 이전 바이트들의 정보를 필요로 하므로, 각 패치 내 바이트의 임베딩 및 로컬 모델의 상태가 메모리에 유지됩니다.
  • 패딩 및 초기화 임베딩: 로컬 모델이 패치의 첫 바이트를 예측할 때 사용할 패딩 값과 초기화 임베딩을 참조합니다. 이는 글로벌 모델의 패치 패딩과 비슷한 방식으로 동작합니다.

요약

MEGABYTE의 병렬 디코딩에서는 글로벌 모델이 시퀀스의 처음 단계에서 모든 패치에 대한 정보를 한 번에 계산하여 메모리에 저장하고, 이후 각 패치별로 로컬 모델이 병렬로 작동하며, 각 패치 내의 바이트 예측을 수행합니다. 이를 통해 기존 Transformer의 순차적 디코딩 방식을 개선하여 병렬 디코딩을 가능하게 합니다.