Add weight transpose check (#2756)
### What this PR does / why we need it?
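In reinforcement learning scenarios the inference weights must be updated
from the training weights. The current inference path transposes the MoE
weights after loading, which changes their shape, so the training weights
no longer match and the weight update fails with a shape-mismatch error.
This change transposes an incoming weight before loading it whenever its
shape does not match the (already transposed) inference parameter.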
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: 6fb2788163
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -17,7 +17,7 @@ from unittest.mock import patch
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.ops.common_fused_moe import fused_experts_moge
|
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE, fused_experts_moge
|
||||||
|
|
||||||
|
|
||||||
class TestFusedExpertsMoGE(TestBase):
|
class TestFusedExpertsMoGE(TestBase):
|
||||||
@@ -67,3 +67,39 @@ class TestFusedExpertsMoGE(TestBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(output.shape, (4, 128))
|
self.assertEqual(output.shape, (4, 128))
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadWeight(TestBase):
    """Exercise ``_load_w13`` / ``_load_w2`` for both parameter layouts.

    Each loader is called once with a destination whose layout already
    matches the loaded weight and once with a destination in the
    transposed layout. The destination contents are asserted afterwards so
    that a wrong slice or a missing transpose fails the test.
    """

    @staticmethod
    def _make_moe():
        """Build an ``AscendFusedMoE`` without running its constructor."""
        # The weight loaders under test need no constructor state, so
        # ``__init__`` is replaced with a no-op for the instantiation.
        with patch.object(AscendFusedMoE, "__init__",
                          lambda self, *args, **kwargs: None):
            return AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)

    def test_load_w13_transpose(self):
        moe = self._make_moe()
        moe.hidden_size = 8

        # "w1" fills the first half of the fused w13 tensor, "w3" the
        # second half; ``start`` is the first index of the written half.
        for shard_id, start in (("w1", 0), ("w3", 4)):
            other = 4 - start

            # Destination layout matches the loaded weight: no transpose.
            expert_data = torch.randn(128, 8)
            before = expert_data.clone()
            loaded_weight = torch.randn(128, 4)
            moe._load_w13(expert_data, 1, shard_id, loaded_weight, 0)
            self.assertTrue(
                torch.equal(expert_data[:, start:start + 4], loaded_weight))
            self.assertTrue(
                torch.equal(expert_data[:, other:other + 4],
                            before[:, other:other + 4]))

            # Destination is transposed: the loaded weight must be
            # transposed and written along dim 0 instead.
            expert_data = torch.randn(8, 128)
            before = expert_data.clone()
            loaded_weight = torch.randn(128, 4)
            moe._load_w13(expert_data, 1, shard_id, loaded_weight, 0)
            self.assertTrue(
                torch.equal(expert_data[start:start + 4], loaded_weight.t()))
            self.assertTrue(
                torch.equal(expert_data[other:other + 4],
                            before[other:other + 4]))

    def test_load_w2_transpose(self):
        moe = self._make_moe()

        # Destination layout matches: rank 0 receives the first 4 columns.
        expert_data = torch.randn(128, 4)
        loaded_weight = torch.randn(128, 8)
        moe._load_w2(expert_data, 1, loaded_weight, 0)
        self.assertTrue(torch.equal(expert_data, loaded_weight[:, :4]))

        # Destination is transposed: rank 0 receives the first 4 rows of
        # the transposed weight.
        expert_data = torch.randn(4, 128)
        loaded_weight = torch.randn(128, 8)
        moe._load_w2(expert_data, 1, loaded_weight, 0)
        self.assertTrue(torch.equal(expert_data, loaded_weight.t()[:4]))
|
||||||
|
|||||||
@@ -134,6 +134,7 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs):
|
|||||||
self.use_aclgraph = (vllm_config.compilation_config.level
|
self.use_aclgraph = (vllm_config.compilation_config.level
|
||||||
== CompilationLevel.PIECEWISE
|
== CompilationLevel.PIECEWISE
|
||||||
and not vllm_config.model_config.enforce_eager)
|
and not vllm_config.model_config.enforce_eager)
|
||||||
|
self.transpose = True
|
||||||
|
|
||||||
|
|
||||||
def forward_oot_v01011(
|
def forward_oot_v01011(
|
||||||
@@ -261,6 +262,7 @@ def forward_oot(
|
|||||||
|
|
||||||
def process_weights_after_loading(self, layer):
|
def process_weights_after_loading(self, layer):
|
||||||
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
|
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
|
||||||
|
if self.transpose:
|
||||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
|
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
|
||||||
1, 2).contiguous()
|
1, 2).contiguous()
|
||||||
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
|
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
|
||||||
@@ -269,6 +271,14 @@ def process_weights_after_loading(self, layer):
|
|||||||
1, 2).contiguous()
|
1, 2).contiguous()
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||||
|
|
||||||
|
self.transpose = False
|
||||||
|
else:
|
||||||
|
w13_data = self._maybe_pad_weight(layer.w13_weight.data)
|
||||||
|
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
|
||||||
|
|
||||||
|
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||||
|
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||||
|
|
||||||
if not is_310p():
|
if not is_310p():
|
||||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||||
@@ -358,12 +368,11 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
num_redundant_experts,
|
num_redundant_experts,
|
||||||
has_bias,
|
has_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
setup_token_dispatchers(self.moe_config.ep_size,
|
setup_token_dispatchers(self.moe_config.ep_size,
|
||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
num_experts=self.global_num_experts,
|
num_experts=self.global_num_experts,
|
||||||
num_local_experts=self.local_num_experts)
|
num_local_experts=self.local_num_experts)
|
||||||
|
self.hidden_size = hidden_size
|
||||||
self.moe_config.tp_group = get_tp_group()
|
self.moe_config.tp_group = get_tp_group()
|
||||||
self.moe_config.dp_group = get_dp_group()
|
self.moe_config.dp_group = get_dp_group()
|
||||||
self.moe_config.ep_group = get_ep_group()
|
self.moe_config.ep_group = get_ep_group()
|
||||||
@@ -430,6 +439,61 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
|
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
|
def transpose_weight(self, loaded_weight, expert_data, shard_dim):
    """Transpose ``loaded_weight`` when it is laid out opposite to ``expert_data``.

    Keeps training and inference weight shapes compatible during RL weight
    updates: the destination parameter may already be stored transposed
    while the incoming weight is not.

    Args:
        loaded_weight: Tensor read from the checkpoint / weight update.
        expert_data: Destination tensor the weight will be copied into.
        shard_dim: Dimension (0 or 1) along which the weight is sharded.

    Returns:
        A ``(loaded_weight, shard_dim)`` tuple. When both leading
        dimensions of the two tensors differ, the weight is returned
        transposed (and contiguous) with ``shard_dim`` flipped; otherwise
        both values are returned unchanged.
    """
    # Tensors with fewer than two dimensions have no layout to flip, and
    # indexing ``shape[1]`` on them would raise IndexError. Pass them
    # through untouched.
    if loaded_weight.dim() < 2 or expert_data.dim() < 2:
        return loaded_weight, shard_dim

    rows_differ = loaded_weight.shape[0] != expert_data.shape[0]
    cols_differ = loaded_weight.shape[1] != expert_data.shape[1]
    # A sharded weight in the same layout still shares one dimension with
    # the destination; a mismatch on both is taken to mean the
    # destination is stored transposed.
    if rows_differ and cols_differ:
        shard_dim = int(not shard_dim)
        loaded_weight = loaded_weight.transpose(0, 1).contiguous()
    return loaded_weight, shard_dim
|
||||||
|
|
||||||
|
def _load_w13(self,
              expert_data: torch.Tensor,
              shard_dim: int,
              shard_id: str,
              loaded_weight: torch.Tensor,
              tp_rank: int,
              load_full: bool = False):
    """Copy a gate (``w1``) or up (``w3``) projection into the fused w13 tensor.

    Args:
        expert_data: Destination tensor holding both halves of w13.
        shard_dim: Dimension along which the weight is sharded.
        shard_id: ``"w1"`` for the first half, ``"w3"`` for the second.
        loaded_weight: Source weight to load.
        tp_rank: Tensor-parallel rank selecting the slice of the source.
        load_full: When true, the source is used as-is instead of being
            sliced for ``tp_rank``.
    """
    # Bring the source into the destination's layout first; this may flip
    # the sharding dimension.
    source, dim = self.transpose_weight(loaded_weight, expert_data,
                                        shard_dim)

    # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim.
    # Each of w1 / w3 occupies half of the destination along that dim.
    half = expert_data.shape[dim] // 2
    if not load_full:
        source = source.narrow(dim, half * tp_rank, half)

    # w1 (gate_proj) goes into the first logical weight of w13,
    # w3 (up_proj) into the second.
    if shard_id == "w1":
        offset = 0
    else:
        assert shard_id == "w3"
        offset = half

    expert_data.narrow(dim, offset, half).copy_(source)
|
||||||
|
|
||||||
|
def _load_w2(self,
             expert_data: torch.Tensor,
             shard_dim: int,
             loaded_weight: torch.Tensor,
             tp_rank: int,
             load_full: bool = False):
    """Copy a down projection (``w2``) into its destination tensor.

    Args:
        expert_data: Destination tensor for w2.
        shard_dim: Dimension along which the weight is sharded.
        loaded_weight: Source weight to load.
        tp_rank: Tensor-parallel rank selecting the slice of the source.
        load_full: When true, the source is used as-is instead of being
            sliced for ``tp_rank``.
    """
    # Bring the source into the destination's layout first; this may flip
    # the sharding dimension.
    source, dim = self.transpose_weight(loaded_weight, expert_data,
                                        shard_dim)

    # down_proj: "RowParallel" so tp sharding on input_dim. The
    # destination's extent along that dim is the per-rank slice size.
    span = expert_data.shape[dim]
    if not load_full:
        source = source.narrow(dim, span * tp_rank, span)

    # w2, down_proj: load into the only logical weight of w2.
    expert_data.copy_(source)
|
||||||
|
|
||||||
|
|
||||||
class AscendSharedFusedMoE(AscendFusedMoE):
|
class AscendSharedFusedMoE(AscendFusedMoE):
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user