[Model] Support DeepSeek-V4
vllm_mlu/attention/__init__.py (new file, +3 lines)
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
vllm_mlu/attention/layer.py (new file, +351 lines)
@@ -0,0 +1,351 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Any, cast

import torch
from torch import nn

import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import MLAAttentionImpl
from vllm.attention.layer import Attention, MLAAttention, _init_kv_cache_quant
from vllm.attention.selector import get_attn_backend

from vllm.config.cache import CacheConfig
from vllm.config.vllm import QuantizationConfig, VllmConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.platforms import current_platform
from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype
from vllm.v1.kv_cache_interface import KVCacheSpec

from vllm_mlu.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.v1.kv_cache_interface import (
    MLUFullAttentionSpec,
    MLUMLAAttentionSpec,
    MLUSlidingWindowSpec,
)


@maybe_transfer_kv_layer
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    kwargs: dict[str, Any] = {},
) -> torch.Tensor:
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: return the result of self.impl.forward and pass its kwargs parameter.
    '''
    output = self.impl.forward(
        self,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        kwargs=kwargs,
    )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    return output


class Attention_MluHijack(Attention):
    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # Block size may get updated after model loading; refresh it.
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        if self.sliding_window is not None:
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: replace SlidingWindowSpec with MLUSlidingWindowSpec.
            '''
            return MLUSlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
                sliding_window=self.sliding_window,
            )
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
        else:
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: replace FullAttentionSpec with MLUFullAttentionSpec.
            '''
            return MLUFullAttentionSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
            )
            '''
            ==================
            End of MLU Hijack
            ==================
            '''


class MLAAttention_MluHijack(MLAAttention):
    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_sparse: bool = False,
        indexer: object | None = None,
        **extra_impl_args,
    ) -> None:
        nn.Module.__init__(self)
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        # self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: insert num_kv_heads for the MLU platform.
        '''
        self.head_size = qk_nope_head_dim + qk_rope_head_dim
        self.num_kv_heads = extra_impl_args.pop("num_kv_heads", None)
        if self.num_kv_heads is None:
            self.num_kv_heads = num_heads

        self.decoder_attn_dtype = None
        decoder_attn_dtype = get_current_vllm_config().mlu_config.decoder_attn_dtype
        if decoder_attn_dtype in ["int8", "fp8_e4m3", "fp8"]:
            self.decoder_attn_dtype = (
                torch.int8 if decoder_attn_dtype == "int8"
                else torch.float8_e4m3fn
            )
        extra_impl_args['decoder_attn_dtype'] = self.decoder_attn_dtype
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False

        # Initialize KV cache quantization attributes.
        _init_kv_cache_quant(
            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
        )

        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )
        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        self.impl = impl_cls(
            self.num_heads,
            self.head_size,
            self.scale,
            self.num_kv_heads,
            None,  # alibi_slopes
            None,  # sliding_window
            kv_cache_dtype,
            None,  # logits_soft_cap
            AttentionType.DECODER,  # attn_type
            None,  # kv_sharing_target_layer_name
            **extra_impl_args,
        )
        self.dtype = dtype

        self.use_direct_call = not current_platform.opaque_attention_op()

        if current_platform.is_out_of_tree():
            self.use_direct_call = False

        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: support kv8 and DeepSeek V3.2.
        '''
        self.kv_cache = [
            [torch.tensor([]), torch.tensor([]), torch.tensor([])]
            for _ in range(
                get_current_vllm_config().parallel_config.pipeline_parallel_size
            )
        ]
        self.impl.use_mla = True
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

        self.use_sparse = use_sparse

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
            self.kv_cache_dtype, vllm_config.model_config
        )
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: replace MLAAttentionSpec with MLUMLAAttentionSpec.
        '''
        # DeepSeek V3.2 and V4 expose an indexer head dim in their config;
        # forward it to the KV cache spec.
        index_head_dim, index_n_heads = 0, 0
        if vllm_config.model_config.hf_text_config.model_type in (
            "deepseek_v32",
            "deepseek_v4",
        ):
            index_head_dim = vllm_config.model_config.hf_text_config.index_head_dim
            index_n_heads = 1

        return MLUMLAAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_size,
            dtype=kv_cache_dtype,
            cache_dtype_str=vllm_config.cache_config.cache_dtype,
            index_head_dim=index_head_dim,
            index_n_heads=index_n_heads,
        )
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
        kwargs: dict[str, Any] = {},
    ) -> torch.Tensor:
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)

        assert not self.use_direct_call, "MLU-V1 does not support direct call."
        if self.attn_backend.accept_output_buffer:
            output_lse = None
            output_shape = output_shape if output_shape is not None else query.shape
            output_shape = [output_shape[0], self.num_heads * self.v_head_dim]

            output = torch.empty(
                output_shape,
                dtype=self.dtype if query.dtype == torch.int8 else query.dtype,
                device=query.device,
            )
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.v_head_dim)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.v_head_dim)
            if not kwargs:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name
                )
                attn_output_list = output
            else:
                attn_output_list = unified_attention_with_output(
                    query, key, value, output, self.layer_name, kwargs=kwargs
                )
            if isinstance(attn_output_list, (list, tuple)) and len(attn_output_list) > 1:
                output_lse = attn_output_list[1]
            if output_lse is not None:
                return output.view(-1, hidden_size), output_lse
            else:
                return output.view(-1, hidden_size)
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
        else:
            return torch.ops.vllm.unified_attention(
                query, key, value, self.layer_name
            )


MluHijackObject.apply_hijack(
    Attention,
    Attention.get_kv_cache_spec,
    Attention_MluHijack.get_kv_cache_spec,
)
MluHijackObject.apply_hijack(
    MLAAttention,
    MLAAttention.__init__,
    MLAAttention_MluHijack.__init__,
)
MluHijackObject.apply_hijack(
    MLAAttention,
    MLAAttention.get_kv_cache_spec,
    MLAAttention_MluHijack.get_kv_cache_spec,
)
MluHijackObject.apply_hijack(
    MLAAttention,
    MLAAttention.forward,
    MLAAttention_MluHijack.forward,
)
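
Note: a minimal sketch (not part of this commit) of what each apply_hijack registration above is assumed to amount to: rebinding the vLLM method to its MLU replacement when this module is imported. The real MluHijackObject likely also records the original method so the patch can be reverted; apply_hijack_sketch below is a hypothetical stand-in, not the actual API.

# Hypothetical equivalent of one registration: plain monkey-patching.
def apply_hijack_sketch(cls, original_method, replacement_method):
    # After this, e.g. MLAAttention.forward resolves to MLAAttention_MluHijack.forward.
    setattr(cls, original_method.__name__, replacement_method)

apply_hijack_sketch(MLAAttention, MLAAttention.forward, MLAAttention_MluHijack.forward)
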
vllm_mlu/attention/utils/__init__.py (new file, +3 lines)
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
vllm_mlu/attention/utils/kv_transfer_utils.py (new file, +62 lines)
@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import inspect
from collections.abc import Callable
from functools import wraps

from vllm.distributed.kv_transfer import (
    get_kv_transfer_group,
    has_kv_transfer_group,
    is_v1_kv_transfer_group,
)


def maybe_transfer_kv_layer(func: Callable) -> Callable:
    """Decorator that handles KV layer transfer before and after execution of
    an attention layer, if enabled. Otherwise, the wrapper is a no-op.

    On entry: waits for the KV layer from the connector.
    On exit: saves the KV layer to the connector.
    """
    # Import at runtime to avoid a circular dependency.
    from vllm.attention.layer import get_attention_context

    # Inspect the signature ONCE when the decorator is applied.
    sig = inspect.signature(func)
    param_names = list(sig.parameters.keys())

    # Find the index of the 'layer_name' parameter.
    try:
        layer_name_index = param_names.index("layer_name")
    except ValueError as e:
        raise TypeError(
            f"Function {func.__name__} must have a 'layer_name' parameter"
        ) from e

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
            return func(*args, **kwargs)

        layer_name: str = args[layer_name_index]

        # Extract attention context (layer-specific metadata, layer, and kv_cache).
        attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
        connector = get_kv_transfer_group()
        if attn_metadata is None or not connector.has_connector_metadata():
            return func(*args, **kwargs)

        # Wait for the KV layer on entry.
        connector.wait_for_layer_load(layer_name)

        # Execute the wrapped function.
        result = func(*args, **kwargs)

        # Save the KV cache layer on exit.
        if kwargs is None or kwargs.get("save_kv_layer", True):
            connector.save_kv_layer(layer_name, kv_cache, attn_metadata)

        return result

    return wrapper
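
Note: a small usage sketch (not part of this commit). The wrapper reads 'layer_name' from the positional arguments by index, so callers should pass it positionally; when no v1 KV-transfer group is configured the decorator is a plain pass-through. my_attention_op below is hypothetical.

# Hypothetical attention entry point decorated the same way as
# unified_attention_with_output in layer.py above.
@maybe_transfer_kv_layer
def my_attention_op(query, key, value, output, layer_name, kwargs={}):
    # ... run the attention kernel, writing the result into `output` ...
    return output

# With a v1 KV-transfer connector active, the wrapper waits for this layer's KV
# before the call and saves it to the connector afterwards:
# out = my_attention_op(q, k, v, out, "model.layers.0.self_attn.attn")
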