提交vllm0.11.0开发分支

This commit is contained in:
chenyili
2025-12-10 17:51:24 +08:00
parent deab7dd0b6
commit 7c22d621fb
175 changed files with 31856 additions and 8683 deletions

View File

@@ -4,13 +4,12 @@ import torch
import torch.nn.functional as F
from typing import Optional, List, Dict, Any
from vllm.attention import AttentionType
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group,
)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.forward_context import ForwardContext, get_forward_context
@@ -20,10 +19,8 @@ from torch.library import custom_op, impl
from vllm.platforms import _Backend
class Attention(VllmAttention):
"""Attention"""
def __init__(
self,
num_heads: int,
@@ -75,8 +72,11 @@ class Attention(VllmAttention):
if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(query, key, value)
if self.use_output:
output_shape = output_shape if output_shape is not None else query.shape
output = torch.zeros(output_shape, dtype=query.dtype, device=query.device)
output_shape = (output_shape
if output_shape is not None else query.shape)
output = torch.zeros(output_shape,
dtype=query.dtype,
device=query.device)
hidden_size = output_shape[-1]
# We skip reshaping query, key and value tensors for the MLA
# backend since these tensors have different semantics and are
@@ -97,13 +97,16 @@ class Attention(VllmAttention):
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata, output=output
)
self.impl.forward(self,
query,
key,
value,
self_kv_cache,
attn_metadata,
output=output)
else:
torch.ops.vllm.unified_attention_with_output_kunlun(
query, key, value, output, self.layer_name
)
query, key, value, output, self.layer_name)
return output.view(-1, hidden_size)
else:
if self.use_direct_call:
@@ -112,15 +115,13 @@ class Attention(VllmAttention):
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata
)
return self.impl.forward(self, query, key, value,
self_kv_cache, attn_metadata)
else:
return unified_attention(query, key, value, self.layer_name)
return unified_attention(
query, key, value, self.layer_name)
#
# Rewritten from the MultiHeadAttention class in vllm.attention.layer
# 重写自 vllm.attention.layer 中的 MultiHeadAttention 类
class MultiHeadAttention(VllmMultiHeadAttention):
def __init__(
self,
@@ -130,15 +131,14 @@ class MultiHeadAttention(VllmMultiHeadAttention):
num_kv_heads: Optional[int] = None,
):
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
num_heads = num_heads,
head_size = head_size,
scale = scale,
num_kv_heads = num_kv_heads,
)
# kunlun only supports flash_attn
# kunlun只支持flash_attn
self.attn_backend = _Backend.FLASH_ATTN
def forward(
self,
query: torch.Tensor,
@@ -159,31 +159,34 @@ class MultiHeadAttention(VllmMultiHeadAttention):
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)
# kunlun only supports flash_attn
# kunlun只支持flash_attn
if self.attn_backend == _Backend.FLASH_ATTN:
from flash_attn import flash_attn_func
out = flash_attn_func(query, key, value, softmax_scale=self.scale)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(
query, key, value, scale=self.scale
)
out = xops.memory_efficient_attention_forward(query,
key,
value,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
query, key, value = (x.transpose(1, 2)
for x in (query, key, value))
out = F.scaled_dot_product_attention(query,
key,
value,
scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.transpose(1, 2)
for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
return out.reshape(bsz, q_len, -1)
def wait_for_kv_layer_from_connector(layer_name: str):
"""wait_for_kv_layer_from_connector"""
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
@@ -198,10 +201,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
assert isinstance(attn_metadata, dict)
connector.wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector(
layer_name: str, kv_cache_layer: List[torch.Tensor]
):
layer_name: str,
kv_cache_layer: List[torch.Tensor]):
"""maybe_save_kv_layer_to_connector"""
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
@@ -213,8 +215,8 @@ def maybe_save_kv_layer_to_connector(
if attn_metadata is None:
return
assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])
connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])
@custom_op("vllm::unified_attention_with_output_kunlun", mutates_args=())
def unified_attention_with_output_kunlun(
@@ -223,8 +225,7 @@ def unified_attention_with_output_kunlun(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
output_scale: Optional[torch.Tensor] = None,) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
@@ -232,26 +233,26 @@ def unified_attention_with_output_kunlun(
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self, query, key, value, kv_cache, attn_metadata, output=output)
self.impl.forward(self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def _fake_unified_attention_with_output_kunlun(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
output_scale: Optional[torch.Tensor] = None,) -> None:
return None
unified_attention_with_output_kunlun.register_fake(
_fake_unified_attention_with_output_kunlun
)
unified_attention_with_output_kunlun.register_fake(_fake_unified_attention_with_output_kunlun)
def unified_attention(
query: torch.Tensor,
@@ -268,7 +269,8 @@ def unified_attention(
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
output = self.impl.forward(self, query, key, value, kv_cache,
attn_metadata)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output
return output