[Feat][Graph] Support DeepSeek with ACL Graph (#2707)
### What this PR does / why we need it?
In memory of #677 , a long overdue milestone. Now DeepSeek V3/R1 should
be OK with ACL Graph.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Working on it.
- vLLM version: v0.10.2
- vLLM main:
68dbde5dbb
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -41,9 +41,10 @@ def test_row_parallel_linear(cls, mock_distributed):
|
|||||||
assert output[0].shape == (2, 4, 64)
|
assert output[0].shape == (2, 4, 64)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("torch.ops.vllm.mla_forward")
|
||||||
@patch("torch_npu.npu_rms_norm")
|
@patch("torch_npu.npu_rms_norm")
|
||||||
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
|
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
|
||||||
base_config):
|
mock_distributed, base_config):
|
||||||
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
||||||
|
|
||||||
attn = CustomDeepseekV2MLAAttention(config=base_config,
|
attn = CustomDeepseekV2MLAAttention(config=base_config,
|
||||||
@@ -64,8 +65,8 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
|
|||||||
with patch.object(attn.mla_attn,
|
with patch.object(attn.mla_attn,
|
||||||
"__call__",
|
"__call__",
|
||||||
return_value=torch.randn(2, 4, 128)):
|
return_value=torch.randn(2, 4, 128)):
|
||||||
with pytest.raises(AssertionError):
|
attn(positions, x)
|
||||||
attn(positions, x)
|
mock_mla_forward.assert_called_once()
|
||||||
|
|
||||||
attn = CustomDeepseekV2MLAAttention(config=base_config,
|
attn = CustomDeepseekV2MLAAttention(config=base_config,
|
||||||
hidden_size=128,
|
hidden_size=128,
|
||||||
|
|||||||
@@ -215,21 +215,6 @@ class TestAscendConfig(TestBase):
|
|||||||
test_vllm_config.model_config = fake_model_config
|
test_vllm_config.model_config = fake_model_config
|
||||||
init_ascend_config(test_vllm_config)
|
init_ascend_config(test_vllm_config)
|
||||||
check_ascend_config(test_vllm_config, False)
|
check_ascend_config(test_vllm_config, False)
|
||||||
# aclgraph + deepseek model
|
|
||||||
with self.assertRaises(NotImplementedError):
|
|
||||||
test_vllm_config.additional_config = {
|
|
||||||
"torchair_graph_config": {
|
|
||||||
"enabled": False,
|
|
||||||
},
|
|
||||||
"refresh": True
|
|
||||||
}
|
|
||||||
model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
|
|
||||||
fake_model_config = ModelConfig(model=model_path)
|
|
||||||
fake_model_config.hf_config = PretrainedConfig()
|
|
||||||
fake_model_config.hf_config.model_type = "deepseek"
|
|
||||||
test_vllm_config.model_config = fake_model_config
|
|
||||||
init_ascend_config(test_vllm_config)
|
|
||||||
check_ascend_config(test_vllm_config, False)
|
|
||||||
|
|
||||||
def test_check_torchair_supported(self):
|
def test_check_torchair_supported(self):
|
||||||
test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
|
test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
|
||||||
|
|||||||
@@ -218,14 +218,8 @@ def check_ascend_config(vllm_config, enforce_eager):
|
|||||||
"it has been disabled automatically.")
|
"it has been disabled automatically.")
|
||||||
# aclgraph case
|
# aclgraph case
|
||||||
else:
|
else:
|
||||||
# aclgraph doesn't work with deepseek model and only qwen model is well tested.
|
|
||||||
if vllm_config.model_config:
|
if vllm_config.model_config:
|
||||||
model_type = vllm_config.model_config.hf_config.model_type
|
model_type = vllm_config.model_config.hf_config.model_type
|
||||||
if "deepseek" in model_type:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"ACL Graph does not support deepseek. Please "
|
|
||||||
"try torchair graph mode to serve deepseek models on vllm-ascend."
|
|
||||||
" Or set `enforce_eager=True` to use eager mode.")
|
|
||||||
if "qwen" not in model_type:
|
if "qwen" not in model_type:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"ACL Graph is currently experimental. Please "
|
"ACL Graph is currently experimental. Please "
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
|||||||
default_weight_loader, maybe_remap_kv_scale_name)
|
default_weight_loader, maybe_remap_kv_scale_name)
|
||||||
from vllm.model_executor.models.deepseek_v2 import \
|
from vllm.model_executor.models.deepseek_v2 import \
|
||||||
yarn_get_mscale # noqa: E501
|
yarn_get_mscale # noqa: E501
|
||||||
from vllm.model_executor.models.deepseek_v2 import ( # noqa: E501
|
from vllm.model_executor.models.deepseek_v2 import (
|
||||||
DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
|
DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
|
||||||
DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE,
|
DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE,
|
||||||
get_spec_layer_idx_from_weight_name)
|
get_spec_layer_idx_from_weight_name)
|
||||||
|
|||||||
@@ -25,10 +25,11 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig, get_current_vllm_config
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
|
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -80,6 +81,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
|
|||||||
self.qk_nope_head_dim = qk_nope_head_dim
|
self.qk_nope_head_dim = qk_nope_head_dim
|
||||||
self.qk_head_dim = qk_head_dim
|
self.qk_head_dim = qk_head_dim
|
||||||
self.v_head_dim = v_head_dim
|
self.v_head_dim = v_head_dim
|
||||||
|
self.prefix = prefix
|
||||||
|
|
||||||
self.mla_attn = Attention(
|
self.mla_attn = Attention(
|
||||||
num_heads=self.num_local_heads,
|
num_heads=self.num_local_heads,
|
||||||
@@ -107,15 +109,17 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
|
|||||||
o_proj=mla_modules.o_proj,
|
o_proj=mla_modules.o_proj,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
compilation_config = get_current_vllm_config().compilation_config
|
||||||
|
if prefix in compilation_config.static_forward_context:
|
||||||
|
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||||
|
compilation_config.static_forward_context[prefix] = self
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
kv_cache: Optional[torch.Tensor] = None,
|
kv_cache: Optional[torch.Tensor] = None,
|
||||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||||
forward_context = get_forward_context()
|
|
||||||
if kv_cache is None:
|
|
||||||
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
|
|
||||||
num_tokens = hidden_states.shape[0]
|
num_tokens = hidden_states.shape[0]
|
||||||
need_gather_q_kv = False
|
need_gather_q_kv = False
|
||||||
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
|
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
|
||||||
@@ -129,16 +133,47 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
|
|||||||
if num_tokens % self.tp_size:
|
if num_tokens % self.tp_size:
|
||||||
rows += 1
|
rows += 1
|
||||||
output_shape = (rows, hidden_states.shape[1])
|
output_shape = (rows, hidden_states.shape[1])
|
||||||
|
# FIXME: This does not seem right, should make sure the buffer is fixed
|
||||||
output = torch.empty(output_shape,
|
output = torch.empty(output_shape,
|
||||||
dtype=hidden_states.dtype,
|
dtype=hidden_states.dtype,
|
||||||
device=hidden_states.device)
|
device=hidden_states.device)
|
||||||
if forward_context.attn_metadata:
|
torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output,
|
||||||
attn_metadata = forward_context.attn_metadata[
|
self.prefix)
|
||||||
self.mla_attn.layer_name]
|
|
||||||
else:
|
|
||||||
attn_metadata = forward_context.attn_metadata
|
|
||||||
output = self.mla_attn.impl.forward(hidden_states, kv_cache,
|
|
||||||
attn_metadata, need_gather_q_kv,
|
|
||||||
output)
|
|
||||||
output = output.view(-1, output_shape[-1])
|
output = output.view(-1, output_shape[-1])
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def mla_forward(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
need_gather_q_kv: bool,
|
||||||
|
output: torch.Tensor,
|
||||||
|
layer_name: str,
|
||||||
|
) -> None:
|
||||||
|
forward_context: ForwardContext = get_forward_context()
|
||||||
|
self = forward_context.no_compile_layers[layer_name]
|
||||||
|
if forward_context.attn_metadata:
|
||||||
|
attn_metadata = forward_context.attn_metadata[self.mla_attn.layer_name]
|
||||||
|
else:
|
||||||
|
attn_metadata = forward_context.attn_metadata
|
||||||
|
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
|
||||||
|
self.mla_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
|
||||||
|
need_gather_q_kv, output)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def mla_forward_fake(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
need_gather_q_kv: bool,
|
||||||
|
output: torch.Tensor,
|
||||||
|
layer_name: str,
|
||||||
|
) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="mla_forward",
|
||||||
|
op_func=mla_forward,
|
||||||
|
mutates_args=["output"],
|
||||||
|
fake_impl=mla_forward_fake,
|
||||||
|
dispatch_key="PrivateUse1",
|
||||||
|
)
|
||||||
|
|||||||
@@ -227,8 +227,9 @@ class NPUPlatform(Platform):
|
|||||||
"When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
|
"When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
|
||||||
compilation_config.set_splitting_ops_for_v1()
|
compilation_config.set_splitting_ops_for_v1()
|
||||||
compilation_config.use_inductor = False
|
compilation_config.use_inductor = False
|
||||||
compilation_config.splitting_ops.extend(
|
compilation_config.splitting_ops.extend([
|
||||||
["vllm.unified_ascend_attention_with_output"])
|
"vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
|
||||||
|
])
|
||||||
update_aclgraph_sizes(vllm_config)
|
update_aclgraph_sizes(vllm_config)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
|||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||||
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
|
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
|
||||||
|
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
|
||||||
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.sample.logits_processor import build_logitsprocs
|
from vllm_ascend.sample.logits_processor import build_logitsprocs
|
||||||
@@ -412,7 +413,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
|
self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
|
||||||
self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
|
self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
|
||||||
|
|
||||||
self.mc2_tokens_capacity = 512 * self.parallel_config.tensor_parallel_size
|
# NOTE: Technically, MC2 can have 512 tokens each rank, but this will consume too much memory. The formula is:
|
||||||
|
# ((maxBs * tokenNeedSizeDispatch * ep_worldsize * localMoeExpertNum) + (maxBs * tokenNeedSizeCombine * (k + sharedExpertNum))) * 2
|
||||||
|
# so we have to limit the MC2 tokens to save memory, should fix this in the future.
|
||||||
|
self.mc2_tokens_capacity = 512
|
||||||
self.reserved_mc2_mask = torch.zeros(
|
self.reserved_mc2_mask = torch.zeros(
|
||||||
self.mc2_tokens_capacity,
|
self.mc2_tokens_capacity,
|
||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
@@ -2811,6 +2815,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# or enable more requests to be processed simultaneously.
|
# or enable more requests to be processed simultaneously.
|
||||||
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
|
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
|
||||||
continue
|
continue
|
||||||
|
if isinstance(attn_module, AscendMultiHeadLatentAttention):
|
||||||
|
continue
|
||||||
|
|
||||||
# TODO: Support other attention modules, e.g., cross-attention
|
# TODO: Support other attention modules, e.g., cross-attention
|
||||||
# TODO(lucas): move the attention specs into the model layers like
|
# TODO(lucas): move the attention specs into the model layers like
|
||||||
|
|||||||
Reference in New Issue
Block a user