2022-12-27 21:50:55 +08:00
# original source:
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
2023-01-07 05:42:47 +08:00
# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
2022-12-27 21:50:55 +08:00
# credit:
# Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
2023-01-05 17:37:17 +08:00
# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
2022-12-27 21:50:55 +08:00
# implementation of:
# "Self-attention Does Not Need O(n²) Memory":
# https://arxiv.org/abs/2112.05682v2
from functools import partial
import torch
from torch import Tensor
from torch . utils . checkpoint import checkpoint
import math
2023-01-08 04:08:21 +08:00
from typing import Optional , NamedTuple , List
2022-12-27 21:50:55 +08:00
2023-01-10 01:08:48 +08:00
2023-01-05 17:37:17 +08:00
def narrow_trunc(
    input: Tensor,
    dim: int,
    start: int,
    length: int
) -> Tensor:
    """Return a view of ``input`` covering ``length`` elements of ``dim`` from ``start``.

    If ``start + length`` would run past the end of ``dim``, the slice is clipped
    so that it stops at the last element instead (e.g. the final, shorter chunk).
    """
    remaining = input.shape[dim] - start
    return torch.narrow(input, dim, start, min(length, remaining))
2022-12-27 21:50:55 +08:00
2023-01-10 01:08:48 +08:00
2022-12-27 21:50:55 +08:00
class AttnChunk(NamedTuple):
    """Partial softmax-attention statistics for one key/value chunk.

    Produced by ``_summarize_chunk``. Field order matters: ``_query_chunk_attention``
    rebuilds instances positionally via ``AttnChunk(*map(torch.stack, zip(*chunks)))``
    and unpacks them positionally.
    """
    # bmm(exp(scores - max_score), value) for this chunk's keys
    exp_values: Tensor
    # exp(scores - max_score) summed over this chunk's keys (last dim)
    exp_weights_sum: Tensor
    # per-query maximum score within this chunk (detached, last dim squeezed)
    max_score: Tensor
2023-01-10 01:08:48 +08:00
class SummarizeChunk:
    """Call signature of a key/value-chunk summarizer.

    In the code visible here it is only used as a type annotation for
    ``_summarize_chunk`` with ``scale`` bound via ``partial`` (optionally wrapped in
    ``checkpoint``); it is never instantiated.
    """
    @staticmethod
    def __call__(
        query: Tensor,
        key: Tensor,
        value: Tensor,
    ) -> AttnChunk: ...
2023-01-10 01:08:48 +08:00
class ComputeQueryChunkAttn:
    """Call signature of a per-query-chunk attention function.

    In the code visible here it is only used as a type annotation for either
    ``_get_attention_scores_no_kv_chunking`` or ``_query_chunk_attention`` with
    their extra arguments bound via ``partial``; it is never instantiated.
    """
    @staticmethod
    def __call__(
        query: Tensor,
        key: Tensor,
        value: Tensor,
    ) -> Tensor: ...
2023-01-10 01:08:48 +08:00
2022-12-27 21:50:55 +08:00
def _summarize_chunk(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    scale: float,
) -> AttnChunk:
    """Reduce one key/value chunk to partial softmax-attention statistics.

    query, key and value are batched 3-D tensors (as required by baddbmm/bmm).
    Scores are shifted by their per-query maximum before exponentiation so the
    exponentials cannot overflow; the caller later rescales each chunk's results
    to a common maximum before combining them.

    Returns:
        AttnChunk(exp_values, exp_weights_sum, max_score) for this chunk.
    """
    # With beta=0 the broadcast zero tensor contributes nothing, so this is
    # simply scale * (query @ key^T).
    zero_input = torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype)
    scores = torch.baddbmm(zero_input, query, key.transpose(1, 2), alpha=scale, beta=0)

    # The shift is treated as a constant: no gradient flows through the max.
    row_max = scores.max(dim=-1, keepdim=True).values.detach()
    weights = (scores - row_max).exp()

    if query.device.type == 'mps':
        weighted_values = torch.bmm(weights, value)
    else:
        # Multiply in the weights' dtype, then return to value's dtype.
        weighted_values = torch.bmm(weights, value.to(weights.dtype)).to(value.dtype)

    return AttnChunk(weighted_values, weights.sum(dim=-1), row_max.squeeze(-1))
2023-01-10 01:08:48 +08:00
2022-12-27 21:50:55 +08:00
def _query_chunk_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
) -> Tensor:
    """Attend a query chunk over all keys, processing keys/values in chunks.

    Each key/value chunk of length ``kv_chunk_size`` (the last one may be shorter)
    is reduced by ``summarize_chunk`` to an AttnChunk whose exponentials are taken
    relative to that chunk's own per-query maximum. The partial results are then
    rescaled to the global per-query maximum and summed, giving the same result as
    a softmax over all keys without holding every chunk's score matrix at once.

    Args:
        query: ``[batch * num_heads, query_tokens, channels]``.
        key: ``[batch * num_heads, key_tokens, channels]``.
        value: ``[batch * num_heads, key_tokens, v_channels]``.
        summarize_chunk: callable producing an AttnChunk for one key/value chunk.
        kv_chunk_size: number of key/value tokens per chunk.

    Returns:
        ``[batch * num_heads, query_tokens, v_channels]`` attention output.
    """
    _, k_tokens, _ = key.shape

    def chunk_scanner(chunk_idx: int) -> AttnChunk:
        # narrow_trunc clips the final chunk when k_tokens % kv_chunk_size != 0.
        key_chunk = narrow_trunc(key, 1, chunk_idx, kv_chunk_size)
        value_chunk = narrow_trunc(value, 1, chunk_idx, kv_chunk_size)
        return summarize_chunk(query, key_chunk, value_chunk)

    # Plain Python ints, as chunk_scanner's annotation promises (torch.arange
    # would hand out 0-dim tensors and force tensor comparisons in narrow_trunc).
    chunks: List[AttnChunk] = [
        chunk_scanner(chunk_idx) for chunk_idx in range(0, k_tokens, kv_chunk_size)
    ]
    # Stack each field across chunks: leading dim is the chunk index.
    chunk_values, chunk_weights, chunk_max = AttnChunk(*map(torch.stack, zip(*chunks)))

    # Rescale every chunk from its local max to the global max over all chunks.
    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights
2023-01-10 01:08:48 +08:00
2022-12-27 21:50:55 +08:00
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking (
query : Tensor ,
key : Tensor ,
value : Tensor ,
scale : float ,
) - > Tensor :
attn_scores = torch . baddbmm (
2023-07-25 15:03:06 +08:00
torch . zeros ( 1 , 1 , 1 , device = query . device , dtype = query . dtype ) ,
2022-12-27 21:50:55 +08:00
query ,
key . transpose ( 1 , 2 ) ,
alpha = scale ,
beta = 0 ,
)
attn_probs = attn_scores . softmax ( dim = - 1 )
del attn_scores
2023-01-25 13:23:10 +08:00
hidden_states_slice = torch . bmm ( attn_probs , value ) if query . device . type == ' mps ' else torch . bmm ( attn_probs , value . to ( attn_probs . dtype ) ) . to ( value . dtype )
2022-12-27 21:50:55 +08:00
return hidden_states_slice
2023-01-10 01:08:48 +08:00
2022-12-27 21:50:55 +08:00
class ScannedChunk(NamedTuple):
    """Pairs a chunk index with that chunk's AttnChunk statistics.

    NOTE(review): not referenced by any code visible in this module; presumably
    kept for compatibility with the upstream implementation or external callers —
    confirm before removing.
    """
    # presumably the chunk's start offset along the token dimension — TODO confirm
    chunk_idx: int
    attn_chunk: AttnChunk
2023-01-10 01:08:48 +08:00
2022-12-27 21:50:55 +08:00
def efficient_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint=True,
):
    """Computes efficient dot-product attention given query, key, and value.
    This is efficient version of attention presented in
    https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
    Args:
        query: queries for calculating attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        key: keys for calculating attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        value: values to be used in attention with shape of
            `[batch * num_heads, tokens, v_channels_per_head]` (same token count as key;
            the channel count may differ from query/key).
        query_chunk_size: int: query chunks size
        kv_chunk_size: Optional[int]: key/value chunks size. if None (or 0): defaults to
            sqrt(key_tokens). The result is capped at key_tokens.
        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. Applied after the
            chunk size above has been resolved (whether it was given explicitly or defaulted),
            i.e. the chunk size becomes `max(chunk_size, kv_chunk_size_min)`, to ensure our
            chunk sizes don't get too small (smaller chunks = more chunks = less concurrent
            work done).
        use_checkpoint: bool: whether to use checkpointing (recommended True for training,
            False for inference)
    Returns:
        Output of shape `[batch * num_heads, query_tokens, v_channels_per_head]`.
    """
    batch_x_heads, q_tokens, q_channels_per_head = query.shape
    _, k_tokens, _ = key.shape
    _, _, v_channels_per_head = value.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        # narrow_trunc clips the final query chunk when it is shorter.
        return narrow_trunc(query, 1, chunk_idx, min(query_chunk_size, q_tokens))

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
    # Optionally recompute each chunk during backward instead of storing activations.
    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk

    compute_query_chunk_attn: ComputeQueryChunkAttn
    if k_tokens <= kv_chunk_size:
        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
        compute_query_chunk_attn = partial(
            _get_attention_scores_no_kv_chunking,
            scale=scale,
        )
    else:
        compute_query_chunk_attn = partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key=key,
            value=value,
        )

    # The output's channel count comes from value, not query; zeros_like(query)
    # broke the slice assignment below whenever the two differed.
    res = query.new_zeros((batch_x_heads, q_tokens, v_channels_per_head))
    for chunk_start in range(0, q_tokens, query_chunk_size):
        attn_scores = compute_query_chunk_attn(
            query=get_query_chunk(chunk_start),
            key=key,
            value=value,
        )
        res[:, chunk_start:chunk_start + attn_scores.shape[1], :] = attn_scores

    return res