ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [논문 리뷰] Perceiver AR: general-purpose, long-context autoregressive generation ICML 2022
    Deep-learning 2022. 9. 27. 11:39

    ─ 들어가며 ─

    Deep mind에서 이번 ICML에 나온 perceiver AR 논문 리뷰입니다.

    Transformer를 사용한 Autoregressive generation에서 매우 긴 길이에 대해서 memory efficient하게 좋은 성능을 내도록 개량한 모델입니다.

    Transformer는 sequence modeling을 굉장히 잘 하는 모델로 이미 소문이 나 있습니다. 다만, sequence 길이가 길어질수록 sequence 길이의 제곱에 비례하는 memory가 필요하다는 단점이 있습니다.

    따라서, 긴 길이의 sequence modeling은 큰 기업에서나 넘볼 수 있는 수준이었고, 그것도 책 정도로 아주 긴 길이에서는 한번에 학습이 불가능하고 context를 임의로 나누는 것이 불가피하다고 여겨져왔습니다.

     

    이번에 나온 Perceiver AR은 attention 계산 구조를 살짝 바꿈으로써, 매우 긴 길이에서의 sequence modeling을 성공해냈고, 여러 domain에서 안정적으로 sequence를 잘 만들어내는 방법을 제안하였습니다.

     

    아래 대부분의 내용들은 원 논문과 공식 ICML slides를 참고하였습니다.

    https://www.deepmind.com/publications/perceiver-ar-general-purpose-long-context-autoregressive-generation

     

    Perceiver AR: general-purpose, long-context autoregressive generation

    Real-world data is high-dimensional: a book, image, or musical performance can easily contain hundreds of thousands of elements even after compression. However, the most commonly used autoregressive models, Transformers, are prohibitively expensive to scal

    www.deepmind.com

     

    ─ 1. Attention in Transformer ─

    Transformer의 sequence modeling에서 핵심은 attention입니다.

     

    Attention의 계산 방식은 아래와 같습니다.

    1) 먼저 sequence 두 개$P=\{p_1, p_2, p_3, ..., p_m\}$, $Q=\{q_1, q_2, q_3, ..., q_n\}$가 필요합니다.

    (여기서 둘이 같은 sequence를 사용하면 self-attention이고, 다른 sequence를 사용하면 cross-attention입니다.)

    2) $P$가 query sequence고, $Q$가 key/value sequence라고 하면, $W_{query}(P), W_{key}(Q), W_{value}(Q)$를 계산하며, 보통 $W$는 bias=False인 linear layer입니다. $W_{query}(P): [B, m, H]$, $W_{key}(Q): [B, n, H]$
    3) $W_{query}(p_1)$과 $W_{key}(q_1)$과의 내적, $W_{query}(p_1)$과 $W_{key}(q_2)$, $W_{query}(p_1)$과 $W_{key}(q_3)$, ... $W_{query}(p_1)$과 $W_{key}(q_n)$의 내적 n개를 모두 구한 후 softmax합니다. 내적 결과: $[B, m, n]$

    (여기서 $p_1$과 $q_{\{1, 2, ..., n\}}$ 중에서 softmax 결과값이 높은 $q_i$이 "attention이 높게 잡힌 key다"라고들 표현합니다.)

    내적 결과와 $W_{value}(Q)$의 matmul로 attention 결과가 생성됩니다.

    $[B, m, n] \cdot [B, n, H] = [B, m, H]$

     

    이 결과는 길이는 $P$의 length m이지만, 갖고 있는 context는 $Q$의 context를 갖고 있습니다.

    내적 결과가 weight의 역할을 하고, weight에 matmul되는 것이 $W_{value}(Q)$이기 때문입니다.

     

    이 계산 과정을 위해서는 위 $[ ]$로 감싼 shape에서 설명했듯, $B$개의 batch size와, $m*n$개의 $H$ dimension vector가 필요합니다. 

    memory를 많이 필요로 하는 원인의 핵심은

    $$O(mn)$$

    의 memory complexity입니다.

    $m$과 $n$이 커질수록, self-attention에서는 두 sequence가 같으니, sequence 길이의 제곱에 비례해서 memory를 필요로 하게 됩니다.

     

    빅테크 기업에서 많이 시도하듯이, GPU의 양으로 밀어붙이려고 해도, 최소한 batch size 1만큼은 올릴 수 있어야 합니다.

    https://icml.cc/media/icml-2022/Slides/17886.pdf

    하지만 위 사진과 같이 batch size 1에서도 10k 길이 이상의 attention 계산에 Out Of Memory가 발생한다고 이야기합니다.

    이래서는 10k 이상의 길이에 대한 full-attention은 multi-gpu로도 process할 수 없습니다.

     

    이에 Perceiver AR에서는,

    Attention을 계산하는 query, key를 바꿈으로써 memory issue를 해결하고, 매우 긴 길이의 autoregressive generation이 가능한 Transformer 모델을 제안하였습니다.

     

    ─ 2. Perceiver AR ─

    self-attention은 앞서 설명했듯,  $O(L^2)$ memory complexity가 필요한 network입니다.

    길이에 제곱하는 memory complexity 때문에 긴 sequence의 데이터를 모델로 하여금 학습시키기 어렵습니다.

     

    이것을 위해서 self attention을 2단계로 나눔으로써 sequence의 receptive field는 sequence 전체로 두면서, memory적으로는 효율적인 Transformer를 설계하였습니다.

     

    1단계: $L$ 길이를 $M$과 $N$으로($L = M + N$) 나눕니다.

    전체 $L$ sequence를 앞의 $M$과 나머지 $N$개로 나누며, $N$이 작은 것이 핵심입니다.

    Sequence를 $P=\{p_1, p_2, ..., p_m, p_{m+1}, ..., p_L\}$라고 하고, 

    $$P=\{p_1, p_2, ..., p_L\}, \\ P^N=\{p_{m+1}, ..., p_L\}$$

    로 두었을 때,

    $P$가 key-value이고, $P^N$이 query인 causal cross-attention을 수행합니다.

    sequence 후반부에 대해서만 query화

    cross-attention이라곤 하지만 사실상 자기 자신 sequence를 보고있으므로 self attention처럼 여기는 것이 더 이해가 쉬울 것이라 생각합니다.

    causal은, query $P^N$가 $P$를 볼 때 qurey index가 key index보다는 크거나 같게 한다는 의미입니다.

    그렇지 않으면 inference에서는 미래의 sequence sample을 필요로하게 되므로 training 환경과 consistent한 inference가 불가능합니다.

     

    여기서 필요한 memory의 complexity는 $O(LN)$이 되어 $N$을 hyperparameter로 비교적 작은 수로 둔다면, $L$에는 linear하게 비례하는 memory complexity로 제한할 수 있습니다.

    (논문에서는 1024부터, 최대 16384까지 높게 잡았습니다. 결국 한 GPU에 1 batch size만 올릴수 있으면 Multi GPU를 써서 저자본인(google)들은 GPU 개수로 때려박을수 있다는 자신감...)

    그 결과 생성되는 attention 결과 길이는 $N$이고, context는 $P$ sequence 전체 정보를 갖고 있는 sequence $Z$를 얻게 됩니다.

     

    2단계: $Z$의 self-attention

    $Z$는 $N$길이의 sequence입니다. $N$은 hyperparameter로 $N$ 길이 sequence의 self-attention은 sequence 전체에 대해 attention을 수행하는 것보다 memory적으로 부담이 더 적습니다.

     

    윗부분 주목, 비교적 적은 수의 self-attention

    이렇게 attention을 설계하는 것은 기존의 Transformer attention 구조에서 self-attention query를 sequence 후반부만 적용하는 것과 equivalent합니다.

     

    ─ 3. 실험 결과 ─

    1) 속도 / memory 효율성

    효율적인 Perceiver AR

    파란색이 Perceiver AR의 결과이고, 나머지 색깔이 Transformer / Transformer-XL의 case입니다.

    모두 TPUv4 한대에서 batch-size 1로 진행한 결과이고, 모델의 사이즈가 커질수록, training 속도가 감소하지만, Perceiver AR이 긴 length sample에 대한 학습에서 더 빠른 것을 확인하실 수 있습니다.

    위 그림에서 이미 설명하였지만, Transformer / Transformer XL에서는 10k길이 이상을 넘기지 못하고 OOM이 뜬 것에 비하면 Perceiver AR은 32k에서도 대부분 OOM없이 견뎌내는 모습을 그래프상에서 확인할 수 있습니다.

     

    2) Extremely-long sequence 실험 결과

    ImageNet 결과
    book(PG-19) dataset

    위 표는 ImageNet을 64*64*3(RGB)로 downsampling 후 generation task를 Perceiver AR을 적용한 것이고, 아래는 Project Gutenberg dataset에 대한 결과입니다.

     

    두 결과 모두 모델이 "보다 확신을 갖고 선택하는 지표"상(bit per dim / perplexity)에서 더 좋은 결과를 보임을 확인하였습니다.

     

    3)  Music generation

    https://magenta.tensorflow.org/perceiver-ar

     

    Autoregressive long-context music generation with Perceiver AR

    We present our work on music generation with Perceiver AR, an autoregressive architecture that is able to generate high-quality samples as long as 65k tokens...

    magenta.tensorflow.org

    Google Magenta에서 준비한 demo에서는 긴 길이의 음악 생성이 가능함을 보여주고 있습니다.

    기존 다른 논문의 music generation demo에서는 주로 3분을 잘 넘기지 못했는데 이 demo page에서는 6~7분의 sample이 잘 생성된 것을 확인하실 수 있습니다.

     

     

    ─ 4. 정리 ─

    어려운 수학적 기반이나 알고리즘 없이 Transformer에 대한 지식만 갖고있다면 쉽게 적용해볼 수 있고 성능 면에서도 긴 길이에 대해 좋은 결과를 내는 Perceiver AR이 제안되었습니다.

     

    아마 지금 이후의 아직 발표되지 않은 GPT 등의 초거대 모델은 Perceiver AR 구조를 활용해서 엄청나게 긴 context length의 정보를 몇 천대의 multiGPU를 사용하여 학습시키는 방향이 될 것이라 예상합니다.

    댓글

Designed by Tistory.