Loading [MathJax]/jax/output/CommonHTML/jax.js

ABOUT ME

시체구덩이, 블루칼라형 개발자

Today
Yesterday
Total
  • [논문 리뷰] Perceiver AR: general-purpose, long-context autoregressive generation ICML 2022
    Deep-learning 2022. 9. 27. 11:39

    ─ 들어가며 ─

    Deep mind에서 이번 ICML에 나온 perceiver AR 논문 리뷰입니다.

    Transformer를 사용한 Autoregressive generation에서 매우 긴 길이에 대해서 memory efficient하게 좋은 성능을 내도록 개량한 모델입니다.

    Transformer는 sequence modeling을 굉장히 잘 하는 모델로 이미 소문이 나 있습니다. 다만, sequence 길이가 길어질수록 sequence 길이의 제곱에 비례하는 memory가 필요하다는 단점이 있습니다.

    따라서, 긴 길이의 sequence modeling은 큰 기업에서나 넘볼 수 있는 수준이었고, 그것도 책 정도로 아주 긴 길이에서는 한번에 학습이 불가능하고 context를 임의로 나누는 것이 불가피하다고 여겨져왔습니다.

     

    이번에 나온 Perceiver AR은 attention 계산 구조를 살짝 바꿈으로써, 매우 긴 길이에서의 sequence modeling을 성공해냈고, 여러 domain에서 안정적으로 sequence를 잘 만들어내는 방법을 제안하였습니다.

     

    아래 대부분의 내용들은 원 논문과 공식 ICML slides를 참고하였습니다.

    https://www.deepmind.com/publications/perceiver-ar-general-purpose-long-context-autoregressive-generation

     

    Perceiver AR: general-purpose, long-context autoregressive generation

    Real-world data is high-dimensional: a book, image, or musical performance can easily contain hundreds of thousands of elements even after compression. However, the most commonly used autoregressive models, Transformers, are prohibitively expensive to scal

    www.deepmind.com

     

    ─ 1. Attention in Transformer ─

    Transformer의 sequence modeling에서 핵심은 attention입니다.

     

    Attention의 계산 방식은 아래와 같습니다.

    1) 먼저 sequence 두 개P={p1,p2,p3,...,pm}, Q={q1,q2,q3,...,qn}가 필요합니다.

    (여기서 둘이 같은 sequence를 사용하면 self-attention이고, 다른 sequence를 사용하면 cross-attention입니다.)

    2) P가 query sequence고, Q가 key/value sequence라고 하면, Wquery(P),Wkey(Q),Wvalue(Q)를 계산하며, 보통 W는 bias=False인 linear layer입니다. Wquery(P):[B,m,H], Wkey(Q):[B,n,H]
    3) Wquery(p1)Wkey(q1)과의 내적, Wquery(p1)Wkey(q2), Wquery(p1)Wkey(q3), ... Wquery(p1)Wkey(qn)의 내적 n개를 모두 구한 후 softmax합니다. 내적 결과: [B,m,n]

    (여기서 p1q{1,2,...,n} 중에서 softmax 결과값이 높은 qi이 "attention이 높게 잡힌 key다"라고들 표현합니다.)

    내적 결과와 Wvalue(Q)의 matmul로 attention 결과가 생성됩니다.

    [B,m,n][B,n,H]=[B,m,H]

     

    이 결과는 길이는 P의 length m이지만, 갖고 있는 context는 Q의 context를 갖고 있습니다.

    내적 결과가 weight의 역할을 하고, weight에 matmul되는 것이 Wvalue(Q)이기 때문입니다.

     

    이 계산 과정을 위해서는 위 []로 감싼 shape에서 설명했듯, B개의 batch size와, mn개의 H dimension vector가 필요합니다. 

    memory를 많이 필요로 하는 원인의 핵심은

    O(mn)

    의 memory complexity입니다.

    mn이 커질수록, self-attention에서는 두 sequence가 같으니, sequence 길이의 제곱에 비례해서 memory를 필요로 하게 됩니다.

     

    빅테크 기업에서 많이 시도하듯이, GPU의 양으로 밀어붙이려고 해도, 최소한 batch size 1만큼은 올릴 수 있어야 합니다.

    https://icml.cc/media/icml-2022/Slides/17886.pdf

    하지만 위 사진과 같이 batch size 1에서도 10k 길이 이상의 attention 계산에 Out Of Memory가 발생한다고 이야기합니다.

    이래서는 10k 이상의 길이에 대한 full-attention은 multi-gpu로도 process할 수 없습니다.

     

    이에 Perceiver AR에서는,

    Attention을 계산하는 query, key를 바꿈으로써 memory issue를 해결하고, 매우 긴 길이의 autoregressive generation이 가능한 Transformer 모델을 제안하였습니다.

     

    ─ 2. Perceiver AR ─

    self-attention은 앞서 설명했듯,  O(L2) memory complexity가 필요한 network입니다.

    길이에 제곱하는 memory complexity 때문에 긴 sequence의 데이터를 모델로 하여금 학습시키기 어렵습니다.

     

    이것을 위해서 self attention을 2단계로 나눔으로써 sequence의 receptive field는 sequence 전체로 두면서, memory적으로는 효율적인 Transformer를 설계하였습니다.

     

    1단계: L 길이를 MN으로(L=M+N) 나눕니다.

    전체 L sequence를 앞의 M과 나머지 N개로 나누며, N이 작은 것이 핵심입니다.

    Sequence를 P={p1,p2,...,pm,pm+1,...,pL}라고 하고, 

    P={p1,p2,...,pL},PN={pm+1,...,pL}

    로 두었을 때,

    P가 key-value이고, PN이 query인 causal cross-attention을 수행합니다.

    sequence 후반부에 대해서만 query화

    cross-attention이라곤 하지만 사실상 자기 자신 sequence를 보고있으므로 self attention처럼 여기는 것이 더 이해가 쉬울 것이라 생각합니다.

    causal은, query PNP를 볼 때 qurey index가 key index보다는 크거나 같게 한다는 의미입니다.

    그렇지 않으면 inference에서는 미래의 sequence sample을 필요로하게 되므로 training 환경과 consistent한 inference가 불가능합니다.

     

    여기서 필요한 memory의 complexity는 O(LN)이 되어 N을 hyperparameter로 비교적 작은 수로 둔다면, L에는 linear하게 비례하는 memory complexity로 제한할 수 있습니다.

    (논문에서는 1024부터, 최대 16384까지 높게 잡았습니다. 결국 한 GPU에 1 batch size만 올릴수 있으면 Multi GPU를 써서 저자본인(google)들은 GPU 개수로 때려박을수 있다는 자신감...)

    그 결과 생성되는 attention 결과 길이는 N이고, context는 P sequence 전체 정보를 갖고 있는 sequence Z를 얻게 됩니다.

     

    2단계: Z의 self-attention

    ZN길이의 sequence입니다. N은 hyperparameter로 N 길이 sequence의 self-attention은 sequence 전체에 대해 attention을 수행하는 것보다 memory적으로 부담이 더 적습니다.

     

    윗부분 주목, 비교적 적은 수의 self-attention

    이렇게 attention을 설계하는 것은 기존의 Transformer attention 구조에서 self-attention query를 sequence 후반부만 적용하는 것과 equivalent합니다.

     

    ─ 3. 실험 결과 ─

    1) 속도 / memory 효율성

    효율적인 Perceiver AR

    파란색이 Perceiver AR의 결과이고, 나머지 색깔이 Transformer / Transformer-XL의 case입니다.

    모두 TPUv4 한대에서 batch-size 1로 진행한 결과이고, 모델의 사이즈가 커질수록, training 속도가 감소하지만, Perceiver AR이 긴 length sample에 대한 학습에서 더 빠른 것을 확인하실 수 있습니다.

    위 그림에서 이미 설명하였지만, Transformer / Transformer XL에서는 10k길이 이상을 넘기지 못하고 OOM이 뜬 것에 비하면 Perceiver AR은 32k에서도 대부분 OOM없이 견뎌내는 모습을 그래프상에서 확인할 수 있습니다.

     

    2) Extremely-long sequence 실험 결과

    ImageNet 결과
    book(PG-19) dataset

    위 표는 ImageNet을 64*64*3(RGB)로 downsampling 후 generation task를 Perceiver AR을 적용한 것이고, 아래는 Project Gutenberg dataset에 대한 결과입니다.

     

    두 결과 모두 모델이 "보다 확신을 갖고 선택하는 지표"상(bit per dim / perplexity)에서 더 좋은 결과를 보임을 확인하였습니다.

     

    3)  Music generation

    https://magenta.tensorflow.org/perceiver-ar

     

    Autoregressive long-context music generation with Perceiver AR

    We present our work on music generation with Perceiver AR, an autoregressive architecture that is able to generate high-quality samples as long as 65k tokens...

    magenta.tensorflow.org

    Google Magenta에서 준비한 demo에서는 긴 길이의 음악 생성이 가능함을 보여주고 있습니다.

    기존 다른 논문의 music generation demo에서는 주로 3분을 잘 넘기지 못했는데 이 demo page에서는 6~7분의 sample이 잘 생성된 것을 확인하실 수 있습니다.

     

     

    ─ 4. 정리 ─

    어려운 수학적 기반이나 알고리즘 없이 Transformer에 대한 지식만 갖고있다면 쉽게 적용해볼 수 있고 성능 면에서도 긴 길이에 대해 좋은 결과를 내는 Perceiver AR이 제안되었습니다.

     

    아마 지금 이후의 아직 발표되지 않은 GPT 등의 초거대 모델은 Perceiver AR 구조를 활용해서 엄청나게 긴 context length의 정보를 몇 천대의 multiGPU를 사용하여 학습시키는 방향이 될 것이라 예상합니다.

    댓글

Designed by Tistory.