[Test][BugFix] Fix dispatch_gmm_combine_decode test stability (#7097)
### What this PR does / why we need it?
This patch fixes the nightly failure:
1. Each test case now uses a copy of the global kwargs instead of a reference to
the shared dict, which prevents parameter pollution between cases (see the sketch below).
2. Weight initialization is added for the `eplb` + `w8a8_dynamic` scenario.
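
For illustration, a minimal sketch of the aliasing problem behind point 1 (the `BASE_KWARGS` name, keys, and test function names come from the test file; the default values here are placeholders):

```python
# Shared defaults for all test cases, as in the test module (values are illustrative).
BASE_KWARGS = {"batch_size": 16, "ep_world_size": 16, "moe_expert_num": 64}

def test_dispatch_gmm_combine_decode_base():
    # Before the fix, `custom_kwargs = BASE_KWARGS` bound the same dict object,
    # so the mutations below leaked into every test that ran afterwards.
    custom_kwargs = BASE_KWARGS.copy()  # shallow copy keeps BASE_KWARGS untouched
    custom_kwargs["batch_size"] = 32
    custom_kwargs["ep_world_size"] = 8

def test_dispatch_gmm_combine_decode_dynamic_eplb():
    # With the copy, this case still starts from the original defaults rather
    # than the values set by the previous test.
    custom_kwargs = BASE_KWARGS.copy()
    custom_kwargs["dynamic_eplb"] = True
```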
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
```python
pytest -sv tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py
```
```shell
===================================================================== 3 passed, 4 warnings in 194.86s (0:03:14) ======================================================================
```
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
Signed-off-by: wangli <wangli858794774@gmail.com>
tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py

```diff
@@ -320,6 +320,9 @@ class FusionOp(DecodeMoeOps):
                 self.gmm2_weight_scale_fp32 = [
                     weight.clone() for weight in gmm2_weight_scale.unbind(dim=0)
                 ]
+            else:
+                self.gmm1_weight_scale_fp32 = [torch.ones(1).npu().to(gmm1_weight.dtype)]
+                self.gmm2_weight_scale_fp32 = [torch.ones(1).npu().to(gmm2_weight.dtype)]
         else:
             self.gmm1_weight = [gmm1_weight.clone()]
             self.gmm2_weight = [gmm2_weight.clone()]
@@ -524,7 +527,7 @@ def run_once(local_rank_id,
 
 @torch.inference_mode()
 def test_dispatch_gmm_combine_decode_base():
-    custom_kwargs = BASE_KWARGS
+    custom_kwargs = BASE_KWARGS.copy()
     custom_kwargs["batch_size"] = 32
     custom_kwargs["ep_world_size"] = 8
     custom_kwargs["moe_expert_num"] = 32
@@ -539,7 +542,7 @@ def test_dispatch_gmm_combine_decode_base():
 
 @torch.inference_mode()
 def test_dispatch_gmm_combine_decode_with_mc2_mask():
-    custom_kwargs = BASE_KWARGS
+    custom_kwargs = BASE_KWARGS.copy()
     custom_kwargs["with_mc2_mask"] = True
     ep_world_size = custom_kwargs["ep_world_size"]
     custom_args = tuple(custom_kwargs.values())
@@ -548,7 +551,7 @@ def test_dispatch_gmm_combine_decode_with_mc2_mask():
 
 @torch.inference_mode()
 def test_dispatch_gmm_combine_decode_dynamic_eplb():
-    custom_kwargs = BASE_KWARGS
+    custom_kwargs = BASE_KWARGS.copy()
     custom_kwargs["dynamic_eplb"] = True
     ep_world_size = custom_kwargs["ep_world_size"]
     custom_args = tuple(custom_kwargs.values())
```