ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [논문 설명] Learnable Fourier Features for multi-dimensional spatial positional encoding 2
    Deep-learning 2022. 9. 5. 21:46

    ─ 들어가며 ─

    Learnable Fourier Features for multi-dimensional spatial positional encoding 1 포스팅에서 계속됩니다.

     

    수식: 

    1. 저자들이 제안하는 positional encoding이 뭐가 좋나요? + 성능을 높이기 위한 추가사항

    ① continuous 위치 / Unseen 위치의 inference에서도 잘 동작한다.

    continuous 위치의 경우 예를 들면,

    train set에서 (2, 3) 위치와 (4, 5)위치는 등장하는데, discrete embedding에서는

    $(PE_x(2) || PE_y(3))$과  $(PE_x(4) || PE_y(5))$ ($ || $는 concatenate)로 표현됩니다.

    여기서 inference때, 만약 학습 과정에서 등장하지 않은 (3, 4)가 등장한다면?

    우리는 3이 2와 4 사이에 있다는 것을 알지만, 학습때 x=3 위치가 등장하지 않았다면 discrete embedding table은 각 정수별로 학습이 별개로 되기 때문에 위치 "3"이 "2"와 "4" 사이에 있는 것을 알지 못합니다. 하지만 제안하는 Positional encoding은 (이후 나오는 MLP layer에서 non-linearity가 추가되긴 하지만)  2와 4 사이에서 3이 있다는 interpolation이 가능합니다(위치는 달라지더라도 $W_r$은 항시 사용되기 때문).

     

    주변을 보는 우리의 직관 vs 돌만 보는 딥러닝

    Unseen 위치에 대해서도 비슷한 원리로 위치 정보를 학습하는 $W_r$을 공유하므로 discrete encoding보다 성능이 좋다고 합니다.

     

    ② Learnable하다.

    ①을 보완해주는 이야기인데, 결국 positional encoding은 원래의 data에 더해주는 additional vector로, 이후 attention에서 내적 계산시 attention score 크기에 관여합니다. 그 말은 dataset 상에서 서로 attention이 잘 안잡히는 거리끼리는 서로 attention이 낮아지게 하며, 잘 잡히는 거리끼리는 서로 attention이 높아지도록 학습될 것입니다.

    학습되는 weight가 아닌 Transformer의 삼각함수 positional encoding은 고정된 attention weight를 가지므로 dataset에 adaptive한 거리 정보를 학습하지 못합니다.

     

     

    ③ position space를 몇 개의 그룹으로 나눌 수 있다.

    (이후 실험 설명의 widget captioning task가 예시)

    예를 들어 space가 3차원이라고 하면, <x, y> space는 서로 한 그룹으로 두고, z space를 별개의 space로 두면, <x, y> space끼리는 거리 개념이 적용되고, z space는 서로 다른 차원으로 두어서 단지 concatenate한 형태로 두어 $L_2$ distance보다 더 복잡한 distance로 위치 가정을 할 수 있다고 저자들은 설명합니다.

     

    ④ 추가 성능 개선을 위한 MLP

    성능을 더 높이기 위해 저자들은 2 layers의 non-linearity를 추가했습니다. 이를 사용한 것이 사용하지 않은 것보다 더 성능이 좋았음을 실험적으로 증명했습니다(아래 참조).

     

    Non-linearlity는

    $$ PE_x = \phi(r_x, \theta) W_p$$

    여기서 $r_x$가 원래 제안하는 positional encoding이고, $\phi( ,\theta)$와 $W_p$의 두 layer를 거쳐(거창하게 적어놓긴 했지만, weight, bias 둘 다 있고 중간 activation이 GELU인 MLP입니다) positional encoding을 완성합니다.

    import torch.nn as nn
    
    nn.Sequential(nn.Linear(w1, w2, bias=True), nn.GELU(), nn.Linear(w2, w3, bias=True))

    위 코드로 간단하게 구현됩니다.

     

     

    2. 실험 설명

    ① Image generation task

    Reformer 논문은 Transformer를 memory적으로 efficient하게 개량한 모델입니다. reformer에 대한 설명은 https://tech.scatterlab.co.kr/reformer-review/ scatterlab에서 잘 설명해두었으니 참고바랍니다.

     

    꼼꼼하고 이해하기 쉬운 Reformer 리뷰

    Review of Reformer: The Efficient Transformer

    tech.scatterlab.co.kr

    이 Reformer에서 image generation task를 수행하였는데, 동일한 모델 구조를 사용하면서, positional encoding만 저자들이 제안하는 것으로 바꿔서 실험을 진행하였습니다. 그 결과...

    Image generation

    위 그래프의 (a)는 Transformer의 삼각함수 PE, embedding table PE와 비교했을 때 bit-per-dim(정확히 일치하진 않지만 entropy와 거의 유사한 개념)이 낮게 나와서 더 확신에 찬 모델을 만들었다고 평가할 수 있습니다.

    (b)는 ablation study로 $W_r$(learnable fourier feature)만 쓰거나, $W_r$을 trainable하지 않게 해놓고 MLP만 걸고 실험하거나 하는 식으로 실험한 결과로 거의 비슷하지만, 둘 다 사용한 것이 성능이 제일 좋다고 이야기하고 있습니다.

     

    ②  Object detection

    Object detection task에서 Transformer를 사용한 DETR 논문을 ①과 마찬가지로 positional encoding만 바꾸어 실험한 것입니다. 결과는

    위 표와 같으며, 해석하자면

    Embedding table PE보다 삼각함수 PE가 좋고, 그것보다 우리가 제안한 게 더 위다.

    입니다.

     

    ③ Image classification

    Vision Transformer에서의 classification task를 마찬가지로 positional encoding만 바꾸어서 실험하였습니다.

    Vision Transformer 원 논문과 비교해보면, 이 논문에서의 baseline 성능이 더 떨어졌는데, 이에 대한 변명은 따로 등장하지 않습니다.

    그리고 성능 개선이 거의 없다고 설명하고 있고, 그 이유를

    "Vision Transformer에서 쓴 정도의 parameter 수(86M)이면 discrete하게 학습해도 충분히 Transformer의 capacity로 학습해버릴 수 있다."

    라고 설명하고 있습니다.

     

    ④ Widget captioning

    논문을 읽어보시면 이 task에 공을 쏟은 티가 좀 납니다.

    Widget captioning은 

    Mobile UI가 이미지로 주어지고, 이 화면에 대한 meta정보인 hierarchy structure(backend단에서의 구조를 말하는 듯합니다)를 입력으로 받아서 특정 위치의 widget(단추 등의 아이콘)이 어떤 기능을 하는지 Natural language generation을 하는 task입니다.

     

    이 task에서 widget의 위치는 4D position space로 표현됩니다.

     

    <top, left, bottom, right>로 x축 위치 2개, y축 위치 2개입니다.

    여기서 저자들은 4 dimension의 위치 정보를

    1. 하나의 hyperspace로 두고 실험하기도 하고, (아래 표에서의 1/4)
    2. independent한 두 space로 두기도 하고 (아래 표에서의 2/2)
    3. 각각의 independent space 4개로 두기도 하였습니다(아래 표에서의 4/1)

    그 결과,

    위와 같은 성능을 보였으며, 

    2/2 grouping이나 4/1 grouping으로 복잡한 distance 메커니즘을 사용한 것이 성능이 좋음을 확인하였습니다.

    다만 SPICE metric에서 18.4를 기록한 곳에 bold가 쳐져있지 않는데.. 사소한 실수로 생각됩니다.

     

    3. 정리

    Transformer를 쓰면서 상대적인 위치정보가 중요한 task에서 유용하게 쓰일 수 있다고 생각합니다.

    1. memory도 크게 잡아먹지 않고(오히려 줄여주는 역할)
    2. 구현도 그리 어렵지 않고
    3. general하게 여러 task에서 성능이 좋은 것도 확인이 되었기 때문입니다.

    댓글

Designed by Tistory.