提交vllm0.11.0开发分支
This commit is contained in:
@@ -4,13 +4,12 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, List, Dict, Any
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.distributed.kv_transfer import (
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group,
|
||||
)
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
@@ -20,10 +19,8 @@ from torch.library import custom_op, impl
|
||||
|
||||
from vllm.platforms import _Backend
|
||||
|
||||
|
||||
class Attention(VllmAttention):
|
||||
"""Attention"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
@@ -75,8 +72,11 @@ class Attention(VllmAttention):
|
||||
if attn_metadata.enable_kv_scales_calculation:
|
||||
self.calc_kv_scales(query, key, value)
|
||||
if self.use_output:
|
||||
output_shape = output_shape if output_shape is not None else query.shape
|
||||
output = torch.zeros(output_shape, dtype=query.dtype, device=query.device)
|
||||
output_shape = (output_shape
|
||||
if output_shape is not None else query.shape)
|
||||
output = torch.zeros(output_shape,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
# We skip reshaping query, key and value tensors for the MLA
|
||||
# backend since these tensors have different semantics and are
|
||||
@@ -97,13 +97,16 @@ class Attention(VllmAttention):
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(
|
||||
self, query, key, value, self_kv_cache, attn_metadata, output=output
|
||||
)
|
||||
self.impl.forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self_kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
else:
|
||||
torch.ops.vllm.unified_attention_with_output_kunlun(
|
||||
query, key, value, output, self.layer_name
|
||||
)
|
||||
query, key, value, output, self.layer_name)
|
||||
return output.view(-1, hidden_size)
|
||||
else:
|
||||
if self.use_direct_call:
|
||||
@@ -112,15 +115,13 @@ class Attention(VllmAttention):
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
return self.impl.forward(
|
||||
self, query, key, value, self_kv_cache, attn_metadata
|
||||
)
|
||||
return self.impl.forward(self, query, key, value,
|
||||
self_kv_cache, attn_metadata)
|
||||
else:
|
||||
return unified_attention(query, key, value, self.layer_name)
|
||||
return unified_attention(
|
||||
query, key, value, self.layer_name)
|
||||
|
||||
|
||||
#
|
||||
# Rewritten from the MultiHeadAttention class in vllm.attention.layer
|
||||
# 重写自 vllm.attention.layer 中的 MultiHeadAttention 类
|
||||
class MultiHeadAttention(VllmMultiHeadAttention):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -130,15 +131,14 @@ class MultiHeadAttention(VllmMultiHeadAttention):
|
||||
num_kv_heads: Optional[int] = None,
|
||||
):
|
||||
super().__init__(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads = num_heads,
|
||||
head_size = head_size,
|
||||
scale = scale,
|
||||
num_kv_heads = num_kv_heads,
|
||||
)
|
||||
|
||||
# kunlun only supports flash_attn
|
||||
# kunlun只支持flash_attn
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@@ -159,31 +159,34 @@ class MultiHeadAttention(VllmMultiHeadAttention):
|
||||
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
||||
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
||||
|
||||
# kunlun only supports flash_attn
|
||||
# kunlun只支持flash_attn
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
out = flash_attn_func(query, key, value, softmax_scale=self.scale)
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query, key, value, scale=self.scale
|
||||
)
|
||||
out = xops.memory_efficient_attention_forward(query,
|
||||
key,
|
||||
value,
|
||||
scale=self.scale)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
|
||||
query, key, value = (x.transpose(1, 2)
|
||||
for x in (query, key, value))
|
||||
out = F.scaled_dot_product_attention(query,
|
||||
key,
|
||||
value,
|
||||
scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
query, key, value = (x.transpose(1, 2)
|
||||
for x in (query, key, value))
|
||||
from torch_xla.experimental.custom_kernel import flash_attention
|
||||
|
||||
out = flash_attention(query, key, value, sm_scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
return out.reshape(bsz, q_len, -1)
|
||||
|
||||
|
||||
def wait_for_kv_layer_from_connector(layer_name: str):
|
||||
"""wait_for_kv_layer_from_connector"""
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
@@ -198,10 +201,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
|
||||
assert isinstance(attn_metadata, dict)
|
||||
connector.wait_for_layer_load(layer_name)
|
||||
|
||||
|
||||
def maybe_save_kv_layer_to_connector(
|
||||
layer_name: str, kv_cache_layer: List[torch.Tensor]
|
||||
):
|
||||
layer_name: str,
|
||||
kv_cache_layer: List[torch.Tensor]):
|
||||
"""maybe_save_kv_layer_to_connector"""
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
return
|
||||
@@ -213,8 +215,8 @@ def maybe_save_kv_layer_to_connector(
|
||||
if attn_metadata is None:
|
||||
return
|
||||
assert isinstance(attn_metadata, dict)
|
||||
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])
|
||||
|
||||
connector.save_kv_layer(layer_name, kv_cache_layer,
|
||||
attn_metadata[layer_name])
|
||||
|
||||
@custom_op("vllm::unified_attention_with_output_kunlun", mutates_args=())
|
||||
def unified_attention_with_output_kunlun(
|
||||
@@ -223,8 +225,7 @@ def unified_attention_with_output_kunlun(
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
output_scale: Optional[torch.Tensor] = None,) -> None:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
@@ -232,26 +233,26 @@ def unified_attention_with_output_kunlun(
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(self, query, key, value, kv_cache, attn_metadata, output=output)
|
||||
self.impl.forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
|
||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
|
||||
|
||||
def _fake_unified_attention_with_output_kunlun(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
output_scale: Optional[torch.Tensor] = None,) -> None:
|
||||
return None
|
||||
|
||||
|
||||
unified_attention_with_output_kunlun.register_fake(
|
||||
_fake_unified_attention_with_output_kunlun
|
||||
)
|
||||
|
||||
unified_attention_with_output_kunlun.register_fake(_fake_unified_attention_with_output_kunlun)
|
||||
|
||||
def unified_attention(
|
||||
query: torch.Tensor,
|
||||
@@ -268,7 +269,8 @@ def unified_attention(
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
|
||||
output = self.impl.forward(self, query, key, value, kv_cache,
|
||||
attn_metadata)
|
||||
|
||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
return output
|
||||
return output
|
||||
Reference in New Issue
Block a user