add weight transpose check. (#2756)

### What this PR does / why we need it?
In reinforcement learning scenarios, weight updates are required, but
the current inference applies a transpose operation to the weights,
altering their shape. This causes a shape mismatch with the training
weights, triggering an error during weight updates.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.10.1.1
- vLLM main:
6fb2788163

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
sherie
2025-09-09 20:33:43 +08:00
committed by GitHub
parent e13c4ddb42
commit 93e28e6862
2 changed files with 109 additions and 9 deletions

View File

@@ -17,7 +17,7 @@ from unittest.mock import patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.common_fused_moe import fused_experts_moge
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE, fused_experts_moge
class TestFusedExpertsMoGE(TestBase):
@@ -67,3 +67,39 @@ class TestFusedExpertsMoGE(TestBase):
)
self.assertEqual(output.shape, (4, 128))
class TestLoadWeight(TestBase):
def test_load_w13_transpose(self):
with patch.object(AscendFusedMoE, "__init__",
lambda self, *args, **kwargs: None):
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
moe.hidden_size = 8
expert_data = torch.randn(128, 8)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
expert_data = torch.randn(8, 128)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
expert_data = torch.randn(128, 8)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
expert_data = torch.randn(8, 128)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
def test_load_w2_transpose(self):
with patch.object(AscendFusedMoE, "__init__",
lambda self, *args, **kwargs: None):
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
expert_data = torch.randn(128, 4)
loaded_weight = torch.randn(128, 8)
moe._load_w2(expert_data, 1, loaded_weight, 0)
expert_data = torch.randn(4, 128)
loaded_weight = torch.randn(128, 8)
moe._load_w2(expert_data, 1, loaded_weight, 0)