[main] [refactor] refactor fused_moe.py to enable token_dispatchers (#2570)
### What this PR does / why we need it?
Enable token_dispatcher to replace fused_experts_with_xxx in eager mode
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.1.1
- vLLM main:
704432af3c
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: sherie <963372609@qq.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
Co-authored-by: shiyuan680 <72335504+shiyuan680@users.noreply.github.com>
This commit is contained in:
@@ -22,7 +22,6 @@ from vllm.config import CacheConfig
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
|
||||
from vllm_ascend.models.deepseek_v2 import (
|
||||
CustomDeepseekV2DecoderLayer, CustomDeepseekV2ForCausalLM,
|
||||
CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
|
||||
CustomDeepseekV2MLP, CustomDeepseekV2MoE,
|
||||
CustomDeepseekV2RowParallelLinear,
|
||||
@@ -115,7 +114,8 @@ def mock_distributed():
|
||||
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
|
||||
_PP=pp_group), \
|
||||
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
|
||||
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \
|
||||
patch("torch.npu.current_device", return_value=0):
|
||||
yield
|
||||
|
||||
|
||||
@@ -266,54 +266,3 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
|
||||
kv_lora_rank=16,
|
||||
prefix="layers.1.self_attn")
|
||||
assert hasattr(attn, "q_proj")
|
||||
|
||||
|
||||
@patch("torch_npu.npu_add_rms_norm")
|
||||
@patch("torch_npu.npu_rms_norm")
|
||||
def test_custom_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm,
|
||||
mock_distributed, base_config,
|
||||
vllm_config):
|
||||
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
||||
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
|
||||
torch.randn(2, 128))
|
||||
base_config.n_routed_experts = 4
|
||||
layer = CustomDeepseekV2DecoderLayer(config=base_config,
|
||||
prefix="layers.0",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=CacheConfig(),
|
||||
quant_config=None)
|
||||
assert isinstance(layer.mlp, CustomDeepseekV2MoE)
|
||||
|
||||
x = torch.randn(2, 4, 128)
|
||||
positions = torch.arange(4).repeat(2, 1)
|
||||
|
||||
with patch.object(layer.self_attn, "forward", Mock(return_value=torch.randn(2, 4, 128))), \
|
||||
patch.object(layer.mlp, "forward", Mock(return_value=torch.randn(2, 4, 128))):
|
||||
hidden_states, residual = layer(positions, x, None)
|
||||
assert hidden_states.shape == (2, 4, 128)
|
||||
|
||||
base_config.n_routed_experts = None
|
||||
layer = CustomDeepseekV2DecoderLayer(config=base_config,
|
||||
prefix="layers.0",
|
||||
model_config=vllm_config.model_config,
|
||||
quant_config=None)
|
||||
assert isinstance(layer.mlp, CustomDeepseekV2MLP)
|
||||
|
||||
|
||||
def test_custom_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
|
||||
model = CustomDeepseekV2ForCausalLM(vllm_config=vllm_config)
|
||||
|
||||
input_ids = torch.randint(0, 10000, (2, 4))
|
||||
positions = torch.arange(4).repeat(2, 1)
|
||||
with patch.object(model.model,
|
||||
"forward",
|
||||
return_value=torch.randn(2, 4, 128)):
|
||||
output = model(input_ids, positions)
|
||||
assert output.shape == (2, 4, 128)
|
||||
|
||||
weights = [("model.embed_tokens.weight", torch.randn(10000, 128))]
|
||||
with patch(
|
||||
"vllm.model_executor.model_loader.weight_utils.default_weight_loader"
|
||||
):
|
||||
loaded = model.load_weights(weights)
|
||||
assert loaded is not None
|
||||
|
||||
Reference in New Issue
Block a user