[Feat]support dynamic quantization in allgather (#2841)
### What this PR does / why we need it?
[Feat]support dynamic quantization in allgather
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: main
- vLLM main:
5931b7e5d9
Signed-off-by: withHades <244036962@qq.com>
Signed-off-by: WithHades <244036962@qq.com>
This commit is contained in:
@@ -221,7 +221,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_dispatch_with_quant(self):
|
||||
def test_token_dispatch_without_quant(self):
|
||||
kwargs = {
|
||||
"apply_router_weight_on_input": False,
|
||||
"top_k": 2,
|
||||
@@ -241,6 +241,32 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_dispatch_with_quant(self):
|
||||
kwargs = {
|
||||
"apply_router_weight_on_input": False,
|
||||
"top_k": 2,
|
||||
"max_num_tokens": 100,
|
||||
"ep_size": 2,
|
||||
"num_experts": 128,
|
||||
}
|
||||
self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
|
||||
|
||||
hidden_states = torch.randn(3, 128)
|
||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
|
||||
results = self.dispatcher_quant.token_dispatch(hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
self.row_idx,
|
||||
None,
|
||||
with_quant=True)
|
||||
|
||||
self.assertIsNotNone(results["hidden_states"])
|
||||
self.assertIsNotNone(results["group_list"])
|
||||
self.assertIsNotNone(results["dynamic_scale"])
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_combine_with_expert_map(self):
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
|
||||
Reference in New Issue
Block a user