[Perf] move quant before allgather in Allgather EP (#3420)
### What this PR does / why we need it?
Move quantization before the all-gather in the Allgather EP path, so the collective transfers int8 activations (plus per-token scales) instead of bf16 activations. Relies on
https://github.com/vllm-project/vllm-ascend/pull/3334.
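The reordering in one picture, as a minimal sketch (plain-torch symmetric per-token int8 quantization stands in for `torch_npu.npu_dynamic_quant`, and `dist` for the HCCL process group; the two helper functions are illustrative, not vllm-ascend APIs). int8 halves the gathered activation traffic relative to bf16, which is presumably where the TTFT gains in the table below come from:

```python
import torch
import torch.distributed as dist


def gather_then_quant(hidden_states: torch.Tensor, world_size: int):
    # Baseline order: all-gather bf16 activations, quantize afterwards.
    gathered = [torch.empty_like(hidden_states) for _ in range(world_size)]
    dist.all_gather(gathered, hidden_states)  # 2 bytes per element (bf16)
    x = torch.cat(gathered)
    scale = x.abs().amax(dim=-1, keepdim=True).float() / 127.0
    return (x / scale).round().to(torch.int8), scale


def quant_then_gather(hidden_states: torch.Tensor, world_size: int):
    # This PR's order: quantize locally, then all-gather int8 activations
    # plus one fp32 scale per token.
    scale = hidden_states.abs().amax(dim=-1, keepdim=True).float() / 127.0
    q = (hidden_states / scale).round().to(torch.int8)
    q_out = [torch.empty_like(q) for _ in range(world_size)]
    s_out = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(q_out, q)      # 1 byte per element instead of 2
    dist.all_gather(s_out, scale)  # 4 bytes per token, negligible
    return torch.cat(q_out), torch.cat(s_out)
```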
DeepSeek R1 W8A8 performance on A2 with
`HCCL_ALGO="level0:NA;level1:pipeline"`:
| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|----------|----------|----------|
| 4k | 375.21 | 364.99 |
| 16k | 1465.23 | 1421.75 |
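Mechanically, the PR threads the per-token scale produced by the early quantization through dispatch to the grouped matmuls. A condensed sketch of that flow, with simplified, hypothetical signatures (the real code lives in `PrepareAndFinalizeWithAllGather.prepare` and `MoECommMethod.fused_experts`):

```python
from typing import Optional, Tuple

import torch


def _all_gather(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for torch.ops.vllm.maybe_all_gather_and_maybe_unpad;
    # identity here so the sketch runs single-process.
    return x


def prepare(hidden_states: torch.Tensor,
            w8a8: bool) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Mirrors PrepareAndFinalizeWithAllGather.prepare after this PR:
    # quantize first, then gather int8 data and fp32 scales separately.
    if w8a8:
        scale = hidden_states.abs().amax(dim=-1).float() / 127.0
        q = (hidden_states / scale.unsqueeze(-1)).round().to(torch.int8)
        return _all_gather(q), _all_gather(scale)
    return _all_gather(hidden_states), None


hidden = torch.randn(4, 8, dtype=torch.bfloat16)
x, pertoken_scale = prepare(hidden, w8a8=True)
# Downstream (see the fused_experts changes below): a non-None scale means
# tokens are already int8, so npu_moe_init_routing_v2 disables its own
# quant_mode and the scale is forwarded to the grouped matmul as x_scale.
assert x.dtype == torch.int8 and pertoken_scale is not None
```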
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
@@ -189,6 +189,25 @@ def test_sp_for_qwen3_moe() -> None:
         vllm_model.generate(example_prompts, sampling_params)
 
 
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
+def test_models_distributed_deepseek_v2_lite_with_flashcomm_v1() -> None:
+    example_prompts = [
+        "test" * 1001,
+    ]
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+    with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"),
+                    dtype="auto",
+                    tensor_parallel_size=2,
+                    distributed_executor_backend="mp",
+                    enable_expert_parallel=True,
+                    enforce_eager=True,
+                    quantization="ascend") as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)
+
+
 @pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
@@ -458,6 +458,7 @@ class TestUnifiedApplyMLP(TestBase):
                                     dtype=torch.float32))
 
         hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
+        hidden_states_shape = hidden_states.shape
         w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
         w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
         w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
@@ -486,7 +487,7 @@ class TestUnifiedApplyMLP(TestBase):
         mock_npu_swiglu.assert_called_once()
         mock_npu_dynamic_quant.assert_called_once()
 
-        self.assertEqual(result.shape, hidden_states.shape)
+        self.assertEqual(result.shape, hidden_states_shape)
         self.assertEqual(result.dtype, torch.bfloat16)
 
     @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
@@ -568,6 +569,7 @@ class TestUnifiedApplyMLP(TestBase):
                                     dtype=torch.float32))
 
         hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
+        hidden_states_shape = hidden_states.shape
         w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
         w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
         w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
@@ -596,7 +598,7 @@ class TestUnifiedApplyMLP(TestBase):
         mock_npu_grouped_matmul_swiglu_quant.assert_called_once()
 
         self.assertTrue(mock_forward_context.with_quant)
-        self.assertEqual(result.shape, hidden_states.shape)
+        self.assertEqual(result.shape, hidden_states_shape)
         self.assertEqual(result.dtype, torch.bfloat16)
 
 
@@ -289,7 +289,7 @@ class AscendFusedMoE(FusedMoE):
 
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
 
-        setup_moe_comm_method(self.moe_config)
+        setup_moe_comm_method(self.moe_config, self.quant_method)
 
     def update_expert_map(self, new_expert_map):
         self.expert_map = new_expert_map
@@ -336,11 +336,17 @@ class AscendFusedMoE(FusedMoE):
             replace_allreduce=forward_context.sp_enabled,
             enable_shared_expert_dp=self.enable_shared_expert_dp)
 
+        if isinstance(hidden_states, tuple):
+            hidden_states, pertoken_scale = hidden_states
+        else:
+            pertoken_scale = None
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
             x=hidden_states,
             router_logits=router_logits,
+            pertoken_scale=pertoken_scale,
             top_k=self.top_k,
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
@@ -27,10 +27,14 @@ from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.fused_moe.prepare_finalize import (
     PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
-    PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
+    PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast, QuantType)
 from vllm_ascend.ops.fused_moe.token_dispatcher import (
     TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
     TokenDispatcherWithMC2, TokenDispatcherWithMoge)
+from vllm_ascend.quantization.w4a8_dynamic import \
+    AscendW4A8DynamicFusedMoEMethod
+from vllm_ascend.quantization.w8a8_dynamic import \
+    AscendW8A8DynamicFusedMoEMethod
 
 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
 
@@ -40,25 +44,43 @@ def get_moe_comm_method(
     return _MoECommMethods.get(moe_comm_type, None)
 
 
-def setup_moe_comm_method(moe_config):
-    _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
-    _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
-    _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
+def setup_moe_comm_method(moe_config, quant_method):
+    _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(
+        moe_config, quant_method)
+    _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(
+        moe_config, quant_method)
+    _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config, quant_method)
     _MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl(
-        moe_config)
+        moe_config, quant_method)
 
 
 class MoECommMethod(ABC):
     """Base class for MoE communication methods."""
 
-    def __init__(self, moe_config: FusedMoEConfig):
+    def __init__(self, moe_config: FusedMoEConfig, quant_method=None):
         self.model_type = get_current_vllm_config(
         ).model_config.hf_config.model_type
         self.moe_config = moe_config
 
         self.token_dispatcher = self._get_token_dispatcher()
+        self.quant_type = self._get_quant_type(quant_method)
+        self.with_quant = self.quant_type != QuantType.NONE
         self.prepare_finalize = self._get_prepare_finalize()
 
+    def _get_quant_type(self, quant_method) -> QuantType:
+        if not hasattr(quant_method,
+                       "quant_method") or quant_method.quant_method is None:
+            return QuantType.NONE
+
+        method = quant_method.quant_method
+
+        if isinstance(method, AscendW8A8DynamicFusedMoEMethod):
+            return QuantType.W8A8
+        elif isinstance(method, AscendW4A8DynamicFusedMoEMethod):
+            return QuantType.W4A8
+        else:
+            return QuantType.NONE
+
     def prepare(
         self,
         hidden_states: torch.Tensor,
@@ -90,8 +112,6 @@ class MoECommMethod(ABC):
         topk_ids: torch.Tensor,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
-        use_int8_w8a8: bool = False,
-        use_int4_w4a8: bool = False,
         global_num_experts: Optional[int] = None,
         expert_map: Optional[torch.Tensor] = None,
         w1_scale: Optional[torch.Tensor] = None,
@@ -109,10 +129,11 @@ class MoECommMethod(ABC):
         global_redundant_expert_num: int = 0,
         need_trans: bool = False,
         dynamic_eplb: bool = False,
-        mc2_mask: torch.Tensor = None):
+        mc2_mask: torch.Tensor = None,
+        pertoken_scale: Optional[torch.Tensor] = None):
         # Check constraints
         assert hidden_states.dtype in [
-            torch.float32, torch.float16, torch.bfloat16
+            torch.float32, torch.float16, torch.bfloat16, torch.int8
         ]
 
         moe_comm_method = get_forward_context().moe_comm_method
@@ -130,28 +151,29 @@ class MoECommMethod(ABC):
             dynamic_scale_for_share=dynamic_scale_for_share,
             mc2_mask=mc2_mask,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            with_quant=use_int8_w8a8 or use_int4_w4a8,
-            dynamic_eplb=dynamic_eplb)
+            with_quant=self.with_quant,
+            dynamic_eplb=dynamic_eplb,
+            pertoken_scale=pertoken_scale)
 
         permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
             results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")
 
-        mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
-                                       w1=w1,
-                                       w1_scale=w1_scale,
-                                       w2=w2,
-                                       w2_scale=w2_scale,
-                                       group_list=expert_tokens,
-                                       dynamic_scale=dynamic_scale,
-                                       group_list_type=group_list_type,
-                                       w1_scale_bias=w1_scale_bias,
-                                       w2_scale_bias=w2_scale_bias,
-                                       topk_scales=topk_scales,
-                                       with_quant=use_int8_w8a8
-                                       or use_int4_w4a8,
-                                       fusion=use_int8_w8a8,
-                                       need_trans=need_trans,
-                                       dynamic_eplb=dynamic_eplb)
+        mlp_output = unified_apply_mlp(
+            hidden_states=permuted_hidden_states,
+            w1=w1,
+            w1_scale=w1_scale,
+            w2=w2,
+            w2_scale=w2_scale,
+            group_list=expert_tokens,
+            dynamic_scale=dynamic_scale,
+            group_list_type=group_list_type,
+            w1_scale_bias=w1_scale_bias,
+            w2_scale_bias=w2_scale_bias,
+            topk_scales=topk_scales,
+            with_quant=self.with_quant,
+            fusion=self.quant_type == QuantType.W8A8,
+            need_trans=need_trans,
+            dynamic_eplb=dynamic_eplb)
 
         final_hidden_states = self.token_dispatcher.token_combine(
             hidden_states=mlp_output, context_metadata=context_metadata)
@@ -204,7 +226,8 @@ class AllGatherCommImpl(MoECommMethod):
             num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithAllGather(self.moe_config)
+        return PrepareAndFinalizeWithAllGather(self.moe_config,
+                                               self.quant_type)
 
 
 class MC2CommImpl(MoECommMethod):
@@ -221,7 +244,7 @@ class MC2CommImpl(MoECommMethod):
         return TokenDispatcherWithMC2()
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithMC2(self.moe_config)
+        return PrepareAndFinalizeWithMC2(self.moe_config, self.quant_type)
 
 
 class AlltoAllCommImpl(MoECommMethod):
@@ -241,7 +264,7 @@ class AlltoAllCommImpl(MoECommMethod):
             num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithAll2All(self.moe_config)
+        return PrepareAndFinalizeWithAll2All(self.moe_config, self.quant_type)
 
 
 class NaiveMulticastCommImpl(MoECommMethod):
@@ -270,4 +293,5 @@ class NaiveMulticastCommImpl(MoECommMethod):
             num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
+        return PrepareAndFinalizeWithNaiveMulticast(self.moe_config,
+                                                    self.quant_type)
@@ -72,8 +72,10 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         # Dispose the original unquantized hidden states
        # to save npu memory because they're no longer used.
         dispose_tensor(unquantized_hidden_states)
+        quantized_hidden_states = None
     else:
         pertoken_scale = dynamic_scale
+        quantized_hidden_states = hidden_states
 
     bias1, bias2 = None, None
     _output_dtype = w2_scale.dtype
@@ -92,6 +94,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             group_list=cumsum_group_list(group_list, group_list_type),
             weight_scale=w1_scale,
             x_scale=pertoken_scale)
+        if quantized_hidden_states is not None:
+            dispose_tensor(quantized_hidden_states)
     else:
         if w1_scale.dtype != torch.float32:
             w1_scale = w1_scale.to(torch.float32)
@@ -104,6 +108,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             group_type=0,
             group_list=group_list,
             output_dtype=torch.int32)[0]
+        if quantized_hidden_states is not None:
+            dispose_tensor(quantized_hidden_states)
         # act_fn: swiglu
         hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
             x=hidden_states,
@@ -148,6 +154,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             group_list=cumsum_group_list(group_list, group_list_type),
             weight_scale=w1_scale,
             x_scale=pertoken_scale)
+        if quantized_hidden_states is not None:
+            dispose_tensor(quantized_hidden_states)
     else:
         # gmm1: gate_up_proj
         hidden_states = torch_npu.npu_grouped_matmul(
@@ -161,6 +169,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             group_type=0,
             group_list=group_list,
             output_dtype=_output_dtype)[0]
+        if quantized_hidden_states is not None:
+            dispose_tensor(quantized_hidden_states)
         # act_fn: swiglu
         hidden_states = torch_npu.npu_swiglu(hidden_states)
         hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
@@ -15,11 +15,13 @@
 # This file is a part of the vllm-ascend project.
 
 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import Optional
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+import torch_npu
 from vllm.distributed import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_dp_group, get_tensor_model_parallel_rank,
@@ -30,6 +32,12 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 from vllm_ascend.utils import enable_sp
 
 
+class QuantType(Enum):
+    NONE = 0
+    W8A8 = 1
+    W4A8 = 2
+
+
 class PrepareAndFinalize(ABC):
     """
     Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
@@ -42,8 +50,11 @@ class PrepareAndFinalize(ABC):
     sizes, ranks, and communication settings.
     """
 
-    def __init__(self, moe_config: FusedMoEConfig):
+    def __init__(self,
+                 moe_config: FusedMoEConfig,
+                 quant_type: QuantType = QuantType.NONE):
         self.moe_config = moe_config
+        self.quant_type = quant_type
 
     @abstractmethod
     def prepare(
@@ -103,8 +114,10 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
     Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
     """
 
-    def __init__(self, moe_config: FusedMoEConfig):
-        super().__init__(moe_config)
+    def __init__(self,
+                 moe_config: FusedMoEConfig,
+                 quant_type: QuantType = QuantType.NONE):
+        super().__init__(moe_config, quant_type)
         self._restore_tp_across_dp()
 
     def _restore_tp_across_dp(self):
@@ -195,8 +208,10 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
     Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
     """
 
-    def __init__(self, moe_config: FusedMoEConfig):
-        super().__init__(moe_config)
+    def __init__(self,
+                 moe_config: FusedMoEConfig,
+                 quant_type: QuantType = QuantType.NONE):
+        super().__init__(moe_config, quant_type)
         self._restore_tp_across_dp()
 
     def _restore_tp_across_dp(self):
@@ -316,11 +331,20 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
+        pertoken_scale = None
+        if self.quant_type == QuantType.W8A8:
+            hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
+                hidden_states)
+            pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                pertoken_scale, True, True)
         hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
             hidden_states, True, True)
         router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
             router_logits, True, True)
 
+        if pertoken_scale is not None:
+            return (hidden_states, pertoken_scale), router_logits, None, None
+
         return hidden_states, router_logits, None, None
 
     def _prepare_with_dp_group(
@@ -57,20 +57,23 @@ class MoETokenDispatcher(ABC):
         return get_ep_group().world_size
 
     @abstractmethod
-    def token_dispatch(self,
-                       hidden_states: torch.Tensor,
-                       topk_weights: torch.Tensor,
-                       topk_ids: torch.Tensor,
-                       expert_map: Optional[torch.Tensor] = None,
-                       log2phy: Optional[torch.Tensor] = None,
-                       global_redundant_expert_num: int = 0,
-                       shared_experts: Optional[Any] = None,
-                       quantized_x_for_share: Optional[Any] = None,
-                       dynamic_scale_for_share: Optional[Any] = None,
-                       mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False,
-                       dynamic_eplb: bool = False):
+    def token_dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        expert_map: Optional[torch.Tensor] = None,
+        log2phy: Optional[torch.Tensor] = None,
+        global_redundant_expert_num: int = 0,
+        shared_experts: Optional[Any] = None,
+        quantized_x_for_share: Optional[Any] = None,
+        dynamic_scale_for_share: Optional[Any] = None,
+        mc2_mask: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        with_quant: bool = False,
+        dynamic_eplb: bool = False,
+        pertoken_scale: Optional[torch.Tensor] = None,
+    ):
         raise NotImplementedError("Dispatch function not implemented.")
 
     @abstractmethod
@@ -170,7 +173,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
                        with_quant: bool = False,
-                       dynamic_eplb: bool = False):
+                       dynamic_eplb: bool = False,
+                       pertoken_scale: Optional[torch.Tensor] = None):
         self.with_quant = with_quant
 
         # Apply log2phy if needed
@@ -339,7 +343,8 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
                        with_quant: bool = False,
-                       dynamic_eplb: bool = False):
+                       dynamic_eplb: bool = False,
+                       pertoken_scale: Optional[torch.Tensor] = None):
         self.with_quant = with_quant
         self.original_shape = hidden_states.shape
 
@@ -370,12 +375,14 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
             torch_npu.npu_moe_init_routing_v2(
                 hidden_states,
                 topk_ids,
+                scale=pertoken_scale,
                 active_num=num_tokens * self.top_k,
                 expert_num=global_num_experts,
                 expert_tokens_num_type=1,
                 expert_tokens_num_flag=True,
                 active_expert_range=[first_expert_idx, last_expert_idx],
-                quant_mode=1 if self.with_quant else -1,
+                quant_mode=1
+                if self.with_quant and pertoken_scale is None else -1,
             ))
         expert_tokens = expert_tokens.to(torch.int64)
         group_list_type = 1  # `count` mode
@@ -430,7 +437,8 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
                        with_quant: bool = False,
-                       dynamic_eplb: bool = False):
+                       dynamic_eplb: bool = False,
+                       pertoken_scale: Optional[torch.Tensor] = None):
         self.bsz, _ = hidden_states.shape
         flatten_topk_ids = topk_ids.view(-1)
         self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
@@ -518,7 +526,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
                        with_quant: bool = False,
-                       dynamic_eplb: bool = False):
+                       dynamic_eplb: bool = False,
+                       pertoken_scale: Optional[torch.Tensor] = None):
         self.with_quant = with_quant
         self.hidden_shape = hidden_states.shape
 
@@ -36,7 +36,7 @@ def _maybe_all_gather_and_maybe_unpad_impl(
         x = tensor_model_parallel_all_gather(x, 0)
         pad_size = forward_context.pad_size
         if pad_size > 0:
-            x = x[:-pad_size, :]
+            x = x[:-pad_size]
     else:
         x = get_ep_group().all_gather(x, 0)
         # unpad
@@ -50,8 +50,7 @@ def _maybe_all_gather_and_maybe_unpad_impl(
         offset = 0
         for idx in range(dp_size):
             num_tokens_dp = num_tokens_across_dp_cpu[idx]
-            result[offset:offset +
-                   num_tokens_dp, :] = x[idx, :num_tokens_dp, :]
+            result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp]
             offset += num_tokens_dp
         x = result
 
@@ -386,7 +386,6 @@ class AscendW4A8DynamicFusedMoEMethod:
             w2_scale_bias=layer.w2_scale_bias,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            use_int4_w4a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
             global_redundant_expert_num=global_redundant_expert_num,
@@ -143,6 +143,7 @@ class AscendW8A8DynamicFusedMoEMethod:
             and not ascend_config.torchair_graph_config.enabled)
 
         self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
+        self.in_dtype = vllm_config.model_config.dtype
 
         try:
             device_group = get_mc2_group().device_group
@@ -218,6 +219,7 @@ class AscendW8A8DynamicFusedMoEMethod:
         shared_experts: Optional[Any] = None,
         quantized_x_for_share: Optional[Any] = None,
         dynamic_scale_for_share: Optional[Any] = None,
+        pertoken_scale: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -242,18 +244,18 @@ class AscendW8A8DynamicFusedMoEMethod:
         if enable_force_load_balance:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
 
-        topk_weights = topk_weights.to(x.dtype)
+        topk_weights = topk_weights.to(self.in_dtype)
 
         moe_comm_method = get_forward_context().moe_comm_method
         return moe_comm_method.fused_experts(
             hidden_states=x,
+            pertoken_scale=pertoken_scale,
             w1=layer.w13_weight,
             w1_scale=layer.w13_weight_scale_fp32,
             w2=layer.w2_weight,
             w2_scale=layer.w2_weight_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            use_int8_w8a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
             global_redundant_expert_num=global_redundant_expert_num,