[Feat]support dynamic quantization in allgather (#2841)

### What this PR does / why we need it?
Support dynamic (per-token) quantization in the allgather token-dispatch path. `TokenDispatcherWithAllGather.token_dispatch` gains a `with_quant` flag: when set, `torch_npu.npu_moe_init_routing_v2` is called with `quant_mode=1` so it returns a per-token scale alongside the routed hidden states, and that scale is exposed to callers as `dynamic_scale` for the quantized grouped MLP (`unified_apply_mlp`).

### Does this PR introduce _any_ user-facing change?
No. The change is internal to the MoE token dispatcher and its tests.

### How was this patch tested?
With the new e2e test `test_token_dispatcher_with_all_gather_quant` and the new unit test `test_token_dispatch_with_quant` (see the diffs below).

- vLLM version: main
- vLLM main: 5931b7e5d9

Signed-off-by: withHades <244036962@qq.com>
Signed-off-by: WithHades <244036962@qq.com>
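For background: dynamic quantization computes each token's INT8 scale at runtime from the token itself, instead of using a calibration-time static scale. A minimal CPU sketch of the idea (illustrative only; in this PR the quantization is fused inside `npu_moe_init_routing_v2` when `quant_mode=1`):

```python
import torch

def dynamic_quant_per_token(x: torch.Tensor):
    """Quantize each token (row) to int8 with its own runtime scale."""
    # One scale per token, derived from that token's max magnitude.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127.0
    x_int8 = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return x_int8, scale.squeeze(-1)  # scales play the role of `dynamic_scale`

x = torch.randn(4, 8)
x_int8, pertoken_scale = dynamic_quant_per_token(x)
print(x_int8.dtype, pertoken_scale.shape)  # torch.int8 torch.Size([4])
```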
The Qwen3 MoE W8A8 e2e test no longer forces eager mode:

```diff
@@ -66,7 +66,6 @@ def test_models_distributed_Qwen3_MOE_W8A8():
             max_model_len=8192,
             tensor_parallel_size=2,
             quantization="ascend",
-            enforce_eager=True,
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
 
```
The token-dispatcher e2e test imports the grouped MLP entry point:

```diff
@@ -29,6 +29,7 @@ import torch_npu
 from vllm.model_executor.layers.activation import SiluAndMul
 
 from vllm_ascend.ops.moe.experts_selector import select_experts
+from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather
 
 NUM_EXPERTS = [8, 64]
```
and adds a quantized variant of the allgather dispatch test:

```diff
@@ -165,6 +166,87 @@ def test_token_dispatcher_with_all_gather(
     torch.npu.reset_peak_memory_stats()
 
 
+@pytest.mark.parametrize("m", [1, 33, 64])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("ep_size", EP_SIZE)
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("device", DEVICE)
+def test_token_dispatcher_with_all_gather_quant(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    ep_size: int,
+    dtype: torch.dtype,
+    device: str,
+):
+    context_mock = MagicMock()
+    context_mock.fused_moe_state = 0
+    with patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context",
+               return_value=context_mock):
+        a = torch.randn((m, k), device=device, dtype=dtype) / 10
+        # torch.randn cannot produce int8 tensors; use randint for the weights.
+        w1 = torch.randint(-128, 128, (e, k, 2 * n), device=device, dtype=torch.int8)
+        w1_scale = torch.empty((e, 2 * n), device=device, dtype=dtype)
+        w2 = torch.randint(-128, 128, (e, n, k), device=device, dtype=torch.int8)
+        w2_scale = torch.empty((e, k), device=device, dtype=dtype)
+
+        score = torch.randn((m, e), device=device, dtype=dtype)
+        expert_map = None
+        local_e = e
+
+        score = torch.softmax(score, dim=-1, dtype=dtype)
+        topk_weights, topk_ids = torch.topk(score, topk)
+        topk_ids = topk_ids.to(torch.int32)
+        row_idx = (torch.arange(
+            0,
+            m * topk,
+            device=device,
+            dtype=torch.int32,
+        ).view(topk, -1).permute(1, 0).contiguous())
+
+        dispatcher_kwargs = {
+            "num_experts": e,
+            "top_k": topk,
+            "num_local_experts": local_e,
+        }
+        dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
+
+        apply_router_weight_on_input = False
+        dispatch_output = dispatcher.token_dispatch(
+            hidden_states=a,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            row_idx=row_idx,
+            expert_map=expert_map,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            with_quant=True)
+
+        sorted_hidden_states = dispatch_output["hidden_states"]
+        group_list = dispatch_output["group_list"]
+        group_list_type = dispatch_output.get("group_list_type", 1)
+        dynamic_scale = dispatch_output["dynamic_scale"]
+
+        expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
+                                          w1=w1,
+                                          w1_scale=w1_scale,
+                                          w2=w2,
+                                          w2_scale=w2_scale,
+                                          group_list=group_list,
+                                          group_list_type=group_list_type,
+                                          dynamic_scale=dynamic_scale,
+                                          with_quant=True)
+        combined_output = dispatcher.token_combine(hidden_states=expert_output,
+                                                   bias=None)
+        assert combined_output.shape == (m, k)
+        gc.collect()
+        torch.npu.empty_cache()
+        torch.npu.reset_peak_memory_stats()
+
+
 @pytest.mark.parametrize("m", [1, 33, 64])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
```
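As an aside, the `row_idx` built in the test above lays out the `m * topk` expanded rows column-major, so token `i`'s `k`-th routed copy gets index `i + k * m`. A plain-torch check (runnable anywhere):

```python
import torch

m, topk = 3, 2
row_idx = (torch.arange(0, m * topk, dtype=torch.int32)
           .view(topk, -1).permute(1, 0).contiguous())
print(row_idx)
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]], dtype=torch.int32)
```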
In the unit tests, the existing dispatch case is renamed to reflect that it exercises the non-quantized path:

```diff
@@ -221,7 +221,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
 
         self.assertEqual(results["group_list_type"], 1)
 
-    def test_token_dispatch_with_quant(self):
+    def test_token_dispatch_without_quant(self):
         kwargs = {
             "apply_router_weight_on_input": False,
             "top_k": 2,
```
and a quantized counterpart is added:

```diff
@@ -241,6 +241,32 @@ class TestTokenDispatcherWithAllGather(TestBase):
 
         self.assertEqual(results["group_list_type"], 1)
 
+    def test_token_dispatch_with_quant(self):
+        kwargs = {
+            "apply_router_weight_on_input": False,
+            "top_k": 2,
+            "max_num_tokens": 100,
+            "ep_size": 2,
+            "num_experts": 128,
+        }
+        self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
+
+        hidden_states = torch.randn(3, 128)
+        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
+        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
+
+        results = self.dispatcher_quant.token_dispatch(hidden_states,
+                                                       topk_weights,
+                                                       topk_ids,
+                                                       self.row_idx,
+                                                       None,
+                                                       with_quant=True)
+
+        self.assertIsNotNone(results["hidden_states"])
+        self.assertIsNotNone(results["group_list"])
+        self.assertIsNotNone(results["dynamic_scale"])
+        self.assertEqual(results["group_list_type"], 1)
+
     def test_token_combine_with_expert_map(self):
         self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
         self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
```
In `TokenDispatcherWithAllGather` itself, the routing call now captures the per-token scale and requests quantization on demand:

```diff
@@ -367,7 +367,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
         last_expert_idx = self.num_experts_local
         global_num_experts = self.num_experts_local
 
-        sorted_hidden_states, self.expanded_row_idx, expert_tokens, _ = (
+        sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = (
             torch_npu.npu_moe_init_routing_v2(
                 hidden_states,
                 topk_ids,
@@ -376,7 +376,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
                 expert_tokens_num_type=1,
                 expert_tokens_num_flag=True,
                 active_expert_range=[first_expert_idx, last_expert_idx],
-                quant_mode=-1,
+                quant_mode=1 if self.with_quant else -1,
             ))
         expert_tokens = expert_tokens.to(torch.int64)
         group_list_type = 1  # `count` mode
```
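Per the naming above, with `quant_mode=1` the routing kernel also returns `pertoken_scale`. Downstream, that per-token scale combines with the per-channel weight scale to dequantize the INT8 matmul output. A conceptual CPU sketch of that W8A8 dequantization (not the fused Ascend grouped-matmul kernel):

```python
import torch

# INT8 activations with per-token scales; INT8 weights with per-channel scales.
x_int8 = torch.randint(-128, 128, (4, 64), dtype=torch.int8)
w_int8 = torch.randint(-128, 128, (64, 32), dtype=torch.int8)
pertoken_scale = torch.rand(4)   # one scale per dispatched token
w_scale = torch.rand(32)         # one scale per output channel

# Accumulate in int32, then apply both scales to recover a float result.
acc = x_int8.to(torch.int32) @ w_int8.to(torch.int32)
y = acc.to(torch.float32) * pertoken_scale[:, None] * w_scale[None, :]
print(y.shape)  # torch.Size([4, 32])
```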
Finally, the scale is surfaced in the dispatch result:

```diff
@@ -384,6 +384,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
             "group_list_type": group_list_type,
             "hidden_states": sorted_hidden_states,
             "group_list": expert_tokens,
+            "dynamic_scale": pertoken_scale if self.with_quant else None,
         }
 
     def token_combine(self,
```