add weight transpose check. (#2756)
### What this PR does / why we need it?
In reinforcement learning scenarios, weight updates are required, but
the current inference applies a transpose operation to the weights,
altering their shape. This causes a shape mismatch with the training
weights, triggering an error during weight updates.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main:
6fb2788163
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -17,7 +17,7 @@ from unittest.mock import patch
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.common_fused_moe import fused_experts_moge
|
||||
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE, fused_experts_moge
|
||||
|
||||
|
||||
class TestFusedExpertsMoGE(TestBase):
|
||||
@@ -67,3 +67,39 @@ class TestFusedExpertsMoGE(TestBase):
|
||||
)
|
||||
|
||||
self.assertEqual(output.shape, (4, 128))
|
||||
|
||||
|
||||
class TestLoadWeight(TestBase):
|
||||
|
||||
def test_load_w13_transpose(self):
|
||||
with patch.object(AscendFusedMoE, "__init__",
|
||||
lambda self, *args, **kwargs: None):
|
||||
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
|
||||
moe.hidden_size = 8
|
||||
expert_data = torch.randn(128, 8)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(8, 128)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(128, 8)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(8, 128)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
|
||||
|
||||
def test_load_w2_transpose(self):
|
||||
with patch.object(AscendFusedMoE, "__init__",
|
||||
lambda self, *args, **kwargs: None):
|
||||
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
|
||||
expert_data = torch.randn(128, 4)
|
||||
loaded_weight = torch.randn(128, 8)
|
||||
moe._load_w2(expert_data, 1, loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(4, 128)
|
||||
loaded_weight = torch.randn(128, 8)
|
||||
moe._load_w2(expert_data, 1, loaded_weight, 0)
|
||||
|
||||
Reference in New Issue
Block a user