https://d2l.ai/chapter_attention-mechanisms-and-transformers/multihead-attention.html
11.5. Multi-Head Attention
실제로 동일한 쿼리, 키 및 값 세트가 주어지면 모델이 시퀀스 내에서 다양한 범위(예: 단거리 대 장거리)의 종속성을 캡처하는 것과 같이 동일한 attention 메커니즘의 다양한 동작에서 얻은 지식을 결합하기를 원할 수 있습니다. 따라서 어텐션 메커니즘이 쿼리, 키 및 값의 다른 representation 하위 공간을 공동으로 사용하도록 허용하는 것이 유리할 수 있습니다.
To this end, instead of performing a single attention pooling, queries, keys, and values can be transformed with ℎ independently learned linear projections. Then these ℎ projected queries, keys, and values are fed into attention pooling in parallel. In the end, ℎ attention pooling outputs are concatenated and transformed with another learned linear projection to produce the final output. This design is called multi-head attention, where each of the ℎ attention pooling outputs is a head (Vaswani et al., 2017). Using fully connected layers to perform learnable linear transformations, Fig. 11.5.1 describes multi-head attention.
이를 위해 single attention pooling을 수행하는 대신 쿼리, 키 및 값을 ℎ independently 학습된 linear projections으로 변환할 수 있습니다. 그런 다음 이러한 ℎ 프로젝션된 쿼리, 키 및 값이 어텐션 풀링에 병렬로 공급됩니다. 결국 ℎ 어텐션 풀링 출력은 최종 출력을 생성하기 위해 다른 학습된 선형 프로젝션과 연결 및 변환됩니다. 이 디자인을 multi-head attention이라고 하며, 각 ℎ 어텐션 풀링 출력은 head 입니다(Vaswani et al., 2017). 그림 11.5.1은 학습 가능한 선형 변환을 수행하기 위해 fully connected layers를 사용하여 multi-head attention을 설명합니다.
import math
import torch
from torch import nn
from d2l import torch as d2l
11.5.1. Model
Before providing the implementation of multi-head attention, let’s formalize this model mathematically. Given a query q∈ℝdq, a key k∈ℝdk, and a value v∈ℝdv, each attention head ℎi (i=1,…,ℎ) is computed as
멀티 헤드 어텐션을 구현하기 전에 이 모델을 수학적으로 공식화해 보겠습니다. 쿼리 q∈ℝdq, 키 k∈ℝdk 및 값 v∈ℝdv가 주어지면 각 주의 헤드 ℎi(i=1,…,ℎ)는 다음과 같이 계산됩니다.
Based on this design, each head may attend to different parts of the input. More sophisticated functions than the simple weighted average can be expressed.
이 design에 따라 각 헤드는 입력의 다른 부분에 attend 할 것입니다. simple weighted average보다 더 정교한 기능을 표현할 수 있습니다.
11.5.2. Implementation
In our implementation, we choose the scaled dot-product attention for each head of the multi-head attention. To avoid significant growth of computational cost and parameterization cost, we set pq=pk=pv=po/ℎ. Note that ℎ heads can be computed in parallel if we set the number of outputs of linear transformations for the query, key, and value to pqℎ=pkℎ=pvℎ=po. In the following implementation, po is specified via the argument num_hiddens.
이 구현에서 우리는 multi-head attention의 각 헤드에 대해 scaled dot-product attention을 선택합니다. 계산 비용과 매개변수화 비용의 상당한 증가를 피하기 위해 pq=pk=pv=po/ℎ로 설정합니다. 쿼리, 키 및 값에 대한 선형 변환의 출력 수를 pqℎ=pkℎ=pvℎ=po로 설정하면 ℎ 헤드를 병렬로 계산할 수 있습니다. 다음 구현에서 po는 num_hiddens 인수를 통해 지정됩니다.
class MultiHeadAttention(d2l.Module): #@save
"""Multi-head attention."""
def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
super().__init__()
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
self.W_o = nn.LazyLinear(num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
# Shape of queries, keys, or values:
# (batch_size, no. of queries or key-value pairs, num_hiddens)
# Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
# After transposing, shape of output queries, keys, or values:
# (batch_size * num_heads, no. of queries or key-value pairs,
# num_hiddens / num_heads)
queries = self.transpose_qkv(self.W_q(queries))
keys = self.transpose_qkv(self.W_k(keys))
values = self.transpose_qkv(self.W_v(values))
if valid_lens is not None:
# On axis 0, copy the first item (scalar or vector) for num_heads
# times, then copy the next item, and so on
valid_lens = torch.repeat_interleave(
valid_lens, repeats=self.num_heads, dim=0)
# Shape of output: (batch_size * num_heads, no. of queries,
# num_hiddens / num_heads)
output = self.attention(queries, keys, values, valid_lens)
# Shape of output_concat: (batch_size, no. of queries, num_hiddens)
output_concat = self.transpose_output(output)
return self.W_o(output_concat)
위 코드는 멀티헤드 어텐션(Multi-head Attention)을 구현한 클래스를 나타냅니다.
- __init__: 클래스 초기화 메서드입니다. 멀티헤드 어텐션을 구현하기 위해 필요한 구성 요소들을 초기화합니다. num_hiddens는 각 헤드에서 사용할 히든 유닛의 수, num_heads는 헤드의 개수, dropout은 드롭아웃 비율입니다. 나머지 인자들은 부가적인 인자들입니다.
- forward: 멀티헤드 어텐션의 순전파를 정의한 메서드입니다. 입력으로 queries, keys, values를 받습니다. 이들은 (batch_size, no. of queries or key-value pairs, num_hiddens) 형태의 텐서입니다. valid_lens는 유효한 길이 정보로, (batch_size,) 혹은 (batch_size, no. of queries) 형태의 텐서입니다. 이 메서드는 멀티헤드 어텐션을 구성하는 다양한 선형 변환과 어텐션 연산을 수행합니다.
- transpose_qkv: 어텐션 연산을 위해 Queries, Keys, Values를 준비하는 과정을 담당합니다. 이 메서드는 입력으로 받은 텐서를 일정한 변환을 거쳐 차원을 조절하고 헤드의 개수만큼 복제합니다.
- attention: 멀티헤드 어텐션의 주요 어텐션 연산을 수행합니다. d2l.DotProductAttention 클래스를 통해 어텐션 가중치를 계산합니다.
- transpose_output: 어텐션 연산 결과를 다시 원래 형태로 되돌리는 메서드입니다. 멀티헤드 어텐션 연산 후 헤드별 결과를 원래의 형태로 병합합니다.
총평하면, 이 클래스는 멀티헤드 어텐션의 순전파를 구현한 것으로, 여러 헤드로 어텐션을 계산하고, 각 헤드의 결과를 병합하여 최종 결과를 반환합니다.
더 자세히 분석하면 아래와 같습니다.
class MultiHeadAttention(d2l.Module):
"""Multi-head attention."""
def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
super().__init__()
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
self.W_o = nn.LazyLinear(num_hiddens, bias=bias)
- MultiHeadAttention 클래스는 d2l.Module 클래스를 상속하여 멀티헤드 어텐션을 구현합니다.
- __init__ 메서드는 클래스의 초기화를 담당합니다. 멀티헤드 어텐션에 필요한 구성 요소들을 초기화합니다.
- num_hiddens는 각 헤드에서 사용할 히든 유닛의 수를 나타냅니다.
- num_heads는 멀티헤드 어텐션의 헤드 개수를 나타냅니다.
- dropout은 드롭아웃 비율을 나타냅니다.
- bias는 선형 변환에 편향을 사용할지 여부를 결정합니다.
- self.attention은 멀티헤드 어텐션에서 사용할 어텐션 객체입니다.
- self.W_q, self.W_k, self.W_v는 선형 변환을 통해 Queries, Keys, Values를 준비하는 데 사용됩니다.
- self.W_o는 최종 출력을 만들기 위한 선형 변환입니다.
def forward(self, queries, keys, values, valid_lens):
queries = self.transpose_qkv(self.W_q(queries))
keys = self.transpose_qkv(self.W_k(keys))
values = self.transpose_qkv(self.W_v(values))
if valid_lens is not None:
valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
output = self.attention(queries, keys, values, valid_lens)
output_concat = self.transpose_output(output)
return self.W_o(output_concat)
- forward 메서드는 멀티헤드 어텐션의 순전파 연산을 수행합니다.
- self.transpose_qkv 메서드를 통해 Queries, Keys, Values를 준비합니다.
- valid_lens가 주어진 경우, 헤드 수에 맞게 복제하여 준비합니다.
- 멀티헤드 어텐션 연산을 수행하는 self.attention 객체를 호출합니다.
- 어텐션 연산 결과를 원래 형태로 변환하고 선형 변환을 거쳐 최종 출력을 만듭니다.
In summary, the MultiHeadAttention class implements the forward pass of multi-head attention, where queries, keys, and values are prepared using linear transformations, and then attention is computed using the self.attention object. The final output is obtained by transforming the attention output using another linear transformation.
요약하면 MultiHeadAttention 클래스는 선형 변환을 사용하여 쿼리, 키 및 값을 준비한 다음 self.attention 개체를 사용하여 주의를 계산하는 다중 헤드 어텐션의 정방향 전달을 구현합니다. 최종 출력은 다른 선형 변환을 사용하여 어텐션 출력을 변환하여 얻습니다.
To allow for parallel computation of multiple heads, the above MultiHeadAttention class uses two transposition methods as defined below. Specifically, the transpose_output method reverses the operation of the transpose_qkv method.
multiple heads의 병렬 계산을 허용하기 위해 위의 MultiHeadAttention 클래스는 아래에 정의된 두 가지 transposition 방법을 사용합니다. 구체적으로 transpose_output 메서드는 transpose_qkv 메서드의 동작을 반대로 합니다.
@d2l.add_to_class(MultiHeadAttention) #@save
def transpose_qkv(self, X):
"""Transposition for parallel computation of multiple attention heads."""
# Shape of input X: (batch_size, no. of queries or key-value pairs,
# num_hiddens). Shape of output X: (batch_size, no. of queries or
# key-value pairs, num_heads, num_hiddens / num_heads)
X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
# Shape of output X: (batch_size, num_heads, no. of queries or key-value
# pairs, num_hiddens / num_heads)
X = X.permute(0, 2, 1, 3)
# Shape of output: (batch_size * num_heads, no. of queries or key-value
# pairs, num_hiddens / num_heads)
return X.reshape(-1, X.shape[2], X.shape[3])
@d2l.add_to_class(MultiHeadAttention) #@save
def transpose_output(self, X):
"""Reverse the operation of transpose_qkv."""
X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
위의 코드는 MultiHeadAttention 클래스에 두 개의 새로운 메서드를 추가하는 부분입니다. 이 두 메서드는 멀티헤드 어텐션의 연산을 병렬로 처리하기 위해 입력 데이터의 형태를 변환하는 역할을 수행합니다. 각 메서드에 대해 한 줄씩 설명해보겠습니다.
@d2l.add_to_class(MultiHeadAttention)
def transpose_qkv(self, X):
"""Transposition for parallel computation of multiple attention heads."""
X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
X = X.permute(0, 2, 1, 3)
return X.reshape(-1, X.shape[2], X.shape[3])
- transpose_qkv 메서드는 Queries, Keys, Values를 준비하기 위한 변환을 수행합니다.
- X는 입력 데이터로, shape는 (batch_size, no. of queries or key-value pairs, num_hiddens)입니다.
- 먼저, 입력 데이터의 shape을 조정하여 각 헤드에 대한 차원을 추가합니다. 따라서 X의 shape는 (batch_size, no. of queries or key-value pairs, num_heads, num_hiddens / num_heads)이 됩니다.
- 그 다음, permute를 사용하여 헤드 차원을 뒤로 이동시킵니다. 이렇게 하면 병렬 계산이 가능해집니다. 결과적으로 X의 shape는 (batch_size, num_heads, no. of queries or key-value pairs, num_hiddens / num_heads)가 됩니다.
- 마지막으로, reshape을 통해 다시 원래의 형태로 되돌립니다. 이렇게 하면 X의 shape는 (batch_size * num_heads, no. of queries or key-value pairs, num_hiddens / num_heads)이 됩니다.
@d2l.add_to_class(MultiHeadAttention)
def transpose_output(self, X):
"""Reverse the operation of transpose_qkv."""
X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
- transpose_output 메서드는 transpose_qkv 메서드의 역변환을 수행합니다.
- 입력 데이터 X를 다시 원래의 형태로 되돌리기 위해 역변환을 수행합니다.
- 먼저, X의 shape을 재조정하여 현재의 차원을 헤드 차원으로 되돌립니다. 따라서 X의 shape는 (batch_size * num_heads, no. of queries or key-value pairs, num_hiddens / num_heads)이 됩니다.
- 그 다음, 다시 permute를 사용하여 차원을 원래대로 조정합니다. 결과적으로 X의 shape는 (batch_size * num_heads, no. of queries or key-value pairs, num_hiddens / num_heads)가 됩니다.
- 마지막으로, reshape을 통해 원래의 형태로 되돌립니다. 이렇게 하면 X의 shape는 (batch_size, no. of queries or key-value pairs, num_hiddens)가 됩니다.
이렇게 두 메서드를산을 한, 연산 성능을 향상시킬 수 있습니다.
Let’s test our implemented MultiHeadAttention class using a toy example where keys and values are the same. As a result, the shape of the multi-head attention output is (batch_size, num_queries, num_hiddens).
키와 값이 동일한 toy example를 사용하여 구현된 MultiHeadAttention 클래스를 테스트해 보겠습니다. 결과적으로 멀티 헤드 어텐션 출력의 모양은 (batch_size, num_queries, num_hiddens)입니다.
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention(X, Y, Y, valid_lens),
(batch_size, num_queries, num_hiddens))
위의 코드는 MultiHeadAttention 클래스의 인스턴스를 생성하고 실제 데이터를 이용하여 어텐션 연산을 수행하는 부분입니다. 코드를 한 줄씩 설명하겠습니다.
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
- num_hiddens는 히든 차원의 크기를 나타내고, num_heads는 멀티헤드 어텐션의 헤드 수를 나타냅니다.
- 위 코드는 이러한 설정을 사용하여 MultiHeadAttention 클래스의 인스턴스인 attention을 생성합니다. 드롭아웃 비율은 0.5로 설정됩니다.
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
- batch_size는 미니배치 크기, num_queries는 쿼리의 개수, num_kvpairs는 키-값 쌍의 개수를 나타냅니다.
- valid_lens는 각 데이터 포인트에 대한 유효한 길이입니다.
- X와 Y는 더미 데이터로, X는 쿼리 데이터, Y는 키-값 데이터를 나타냅니다. 데이터는 모두 히든 차원 크기로 채워진 텐서입니다.
d2l.check_shape(attention(X, Y, Y, valid_lens),
(batch_size, num_queries, num_hiddens))
- 위 코드는 실제 어텐션 연산을 수행하고 그 결과의 shape을 확인합니다.
- attention 인스턴스에 X, Y, Y, valid_lens를 전달하여 어텐션을 계산합니다. 여기서 쿼리와 키-값 데이터가 같은 데이터셋 Y를 사용하는 self-attention을 수행합니다.
- 어텐션 결과의 shape이 (batch_size, num_queries, num_hiddens)와 일치하는지 확인합니다.
즉, 위의 코드는 멀티헤드 어텐션 클래스를 이용하여 더미 데이터를 사용하여 실제 어텐션 연산을 수행하고 그 결과의 shape을 확인하는 예시입니다.
11.5.3. Summary
Multi-head attention combines knowledge of the same attention pooling via different representation subspaces of queries, keys, and values. To compute multiple heads of multi-head attention in parallel, proper tensor manipulation is needed.
Multi-head attention은 쿼리, 키 및 값의 다른 representation 하위 공간을 통해 동일한 attention pooling에 대한 지식을 결합합니다. 멀티 헤드 어텐션의 여러 헤드를 병렬로 계산하려면 적절한 텐서 조작이 필요합니다.
Multi Head Attention이란?
'Multi-Head Attention'은 어텐션 메커니즘의 한 종류로, 하나의 어텐션 가중치만을 사용하는 것이 아니라 여러 개의 어텐션 가중치를 병렬로 계산하여 다양한 관점에서 정보를 추출하는 방법을 의미합니다.
일반적인 어텐션 메커니즘은 쿼리(Query)와 키(Key) 그리고 값(Value)의 세 가지 입력을 받아서 쿼리와 각 키 사이의 유사도를 계산한 후, 이 유사도를 가중치로 사용하여 값들을 가중합하여 최종 출력을 얻습니다. 이러한 메커니즘은 쿼리와 키가 비슷한 패턴을 가지는 경우에 유용하게 작동합니다.
하지만 멀티헤드 어텐션은 여러 개의 어텐션 헤드를 사용하여 쿼리, 키, 값에 각각 다른 선형 변환을 수행한 후 병렬로 어텐션 연산을 수행합니다. 이렇게 함으로써 네트워크는 여러 가지 다른 관점에서 정보를 추출하고, 각 헤드의 어텐션 출력을 합치는 것으로 다양한 특징을 동시에 고려할 수 있습니다. 이렇게 다양한 관점에서 정보를 수집하면 모델의 성능이 향상되는 경우가 많습니다.
멀티헤드 어텐션은 주로 트랜스포머(Transformer)와 같은 딥러닝 모델에서 사용되며, 자연어 처리와 기계 번역과 같은 과제에서 특히 효과적인 결과를 보이는 기술입니다.
11.5.4. Exercises
- Visualize attention weights of multiple heads in this experiment.
- Suppose that we have a trained model based on multi-head attention and we want to prune least important attention heads to increase the prediction speed. How can we design experiments to measure the importance of an attention head?
'Dive into Deep Learning > D2L Attention Mechanisms and Transformer' 카테고리의 다른 글
D2L - 11.9. Large-Scale Pretraining with Transformers (0) | 2023.08.10 |
---|---|
D2L - 11.8. Transformers for Vision (0) | 2023.08.10 |
D2L - 11.7. The Transformer Architecture (0) | 2023.08.09 |
D2L - 11.6. Self-Attention and Positional Encoding (0) | 2023.08.09 |
D2L - 11.4. The Bahdanau Attention Mechanism (0) | 2023.08.08 |
D2L - 11.3. Attention Scoring Functions (0) | 2023.08.07 |
D2L - 11.2. Attention Pooling by Similarity (0) | 2023.08.06 |
D2L - 11.1. Queries, Keys, and Values (0) | 2023.08.05 |
D2L - 11. Attention Mechanisms and Transformers (0) | 2023.08.03 |