[MOE] Move weight transpose to wakeup for RL scenarios (#4626)

### What this PR does / why we need it?
In reinforcement learning scenarios, inference applies a transpose operation
to the fused-MoE weights, which previously forced custom weight-loading hooks
to reconcile the training and inference layouts during RL weight updates. For
a cleaner architecture, the weight transpose logic is moved to the worker's
`wake_up` step: weights are transposed back to the training layout on wake-up,
and `process_weights_after_loading` re-applies the inference transpose after
each load.
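To make the intended layout round-trip concrete, here is a minimal,
self-contained sketch (toy shapes and variable names, not the actual
vllm-ascend code): weights are transposed into the inference layout after
loading, and transposed back at wake-up so an RL trainer can copy new weights
in directly.

```python
import torch

E, H, I = 4, 8, 16   # toy sizes: experts, hidden, intermediate

# Training/checkpoint layout for the down projection (w2): [E, H, I].
w2 = torch.randn(E, H, I)

# After loading, the weight is transposed to the inference layout [E, I, H]
# (what process_weights_after_loading does in the first hunk below).
w2_infer = torch.nn.Parameter(w2.transpose(1, 2).contiguous(),
                              requires_grad=False)

# At wake_up the weight is transposed back to [E, H, I], so an RL weight
# update can copy_() new training weights without any reshaping.
w2_train = torch.nn.Parameter(w2_infer.transpose(1, 2), requires_grad=False)
assert w2_train.shape == (E, H, I)
```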

### Does this PR introduce _any_ user-facing change?

No. The change is internal to weight loading and wake-up.
### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: lhp-deep <liuhaopeng1@huawei.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
Commit b230e7e987 (parent 58db21f56a), authored by lhp-deep on 2025-12-08 20:34:52 +08:00 and committed by GitHub.
7 changed files with 132 additions and 120 deletions.


```diff
@@ -56,29 +56,18 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         super().__init__(moe=moe)
         self.dynamic_eplb = get_ascend_config().dynamic_eplb
-        self.transpose = True

     def process_weights_after_loading(self, layer):
         super(UnquantizedFusedMoEMethod,
               self).process_weights_after_loading(layer)
-        if self.transpose:
-            w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
-                1, 2).contiguous()
-            layer.w13_weight = torch.nn.Parameter(w13_data,
-                                                  requires_grad=False)
-            w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
-                1, 2).contiguous()
-            layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
-            self.transpose = False
-        else:
-            w13_data = self._maybe_pad_weight(layer.w13_weight.data)
-            layer.w13_weight = torch.nn.Parameter(w13_data,
-                                                  requires_grad=False)
-            w2_data = self._maybe_pad_weight(layer.w2_weight.data)
-            layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
+        w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
+        w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

         if get_ascend_device_type() != AscendDeviceType._310P and is_enable_nz(
         ):
```
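With the `self.transpose` flag gone, the transpose runs on every load; the
`wake_up` hunk further down guarantees the weight is back in the training
layout beforehand. As a quick illustration of what `.transpose(1, 2).contiguous()`
does to an expert weight (toy shapes, not the real tensors):

```python
import torch

w13 = torch.randn(4, 32, 8)   # toy [num_experts, 2 * intermediate, hidden]
w13_t = w13.transpose(1, 2)   # view with swapped strides, same storage
print(w13_t.shape)            # torch.Size([4, 8, 32])
print(w13_t.is_contiguous())  # False: transpose alone only changes strides
w13_t = w13_t.contiguous()    # materializes the permuted memory layout
print(w13_t.is_contiguous())  # True
```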
```diff
@@ -389,61 +378,6 @@ class AscendFusedMoE(FusedMoE):
         return final_hidden_states

-    def transpose_weight(self, loaded_weight, expert_data, shard_dim):
-        # Ensure training and inference weight shapes match during RL
-        # weight updates.
-        if (len(loaded_weight.shape) >= 2 and len(expert_data.shape) >= 2
-                and loaded_weight.shape[1] != expert_data.shape[1]
-                and loaded_weight.shape[0] != expert_data.shape[0]):
-            shard_dim = int(not shard_dim)
-            loaded_weight = loaded_weight.transpose(0, 1).contiguous()
-        return loaded_weight, shard_dim
-
-    def _load_w13(self,
-                  expert_data: torch.Tensor,
-                  shard_dim: int,
-                  shard_id: str,
-                  loaded_weight: torch.Tensor,
-                  tp_rank: int,
-                  load_full: bool = False):
-        # Index the loaded weight for tp sharding.
-        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
-        loaded_weight, shard_dim = self.transpose_weight(
-            loaded_weight, expert_data, shard_dim)
-        shard_size = expert_data.shape[shard_dim] // 2
-        if not load_full:
-            loaded_weight = loaded_weight.narrow(shard_dim,
-                                                 shard_size * tp_rank,
-                                                 shard_size)
-        # Narrow parameter and load.
-        # w1, gate_proj: Load into first logical weight of w13.
-        if shard_id == "w1":
-            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
-        # w3, up_proj: Load into second logical weight of w13.
-        else:
-            assert shard_id == "w3"
-            expert_data = expert_data.narrow(shard_dim, shard_size,
-                                             shard_size)
-        expert_data.copy_(loaded_weight)
-
-    def _load_w2(self,
-                 expert_data: torch.Tensor,
-                 shard_dim: int,
-                 loaded_weight: torch.Tensor,
-                 tp_rank: int,
-                 load_full: bool = False):
-        # Index the loaded weight for tp sharding.
-        # down_proj: "RowParallel", so tp sharding on input_dim.
-        # Narrow parameter and load.
-        loaded_weight, shard_dim = self.transpose_weight(
-            loaded_weight, expert_data, shard_dim)
-        shard_size = expert_data.shape[shard_dim]
-        if not load_full:
-            loaded_weight = loaded_weight.narrow(shard_dim,
-                                                 shard_size * tp_rank,
-                                                 shard_size)
-        # w2, down_proj: Load into only logical weight of w2.
-        expert_data.copy_(loaded_weight)
-
 class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
```
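For reference, the behavior of the deleted `transpose_weight` helper can be
exercised in isolation. Below is a self-contained sketch of the same
shape-mismatch check (hypothetical tensors and a renamed standalone function,
not the class method):

```python
import torch

def transpose_if_mismatched(loaded_weight, expert_data, shard_dim):
    # Mirror of the deleted helper: if the incoming (training-layout) weight
    # and the inference-layout parameter disagree on both dims, transpose the
    # weight and flip the tensor-parallel sharding dimension.
    if (loaded_weight.dim() >= 2 and expert_data.dim() >= 2
            and loaded_weight.shape[1] != expert_data.shape[1]
            and loaded_weight.shape[0] != expert_data.shape[0]):
        shard_dim = int(not shard_dim)
        loaded_weight = loaded_weight.transpose(0, 1).contiguous()
    return loaded_weight, shard_dim

loaded = torch.randn(32, 8)   # training layout [2 * intermediate, hidden]
param = torch.empty(8, 32)    # inference layout [hidden, 2 * intermediate]
loaded, dim = transpose_if_mismatched(loaded, param, shard_dim=0)
print(loaded.shape, dim)      # torch.Size([8, 32]) 1
```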


```diff
@@ -176,9 +176,28 @@ class NPUWorker(WorkerBase):
         allocator = CaMemAllocator.get_instance()
         allocator.wake_up(tags=tags)

+        # Transpose the fused-MoE weights back to the training layout so
+        # that an RL trainer can copy new weights in without any reshaping.
+        hidden_size = self.vllm_config.model_config.hf_config.hidden_size
+        model = self.model_runner.model
+        for name, param in model.named_parameters():
+            if 'w2_weight' in name and param.shape[2] == hidden_size:
+                parts = name.split('.')
+                param_name = parts[-1]
+                parent_module = model.get_submodule(".".join(parts[:-1]))
+                w2_data = param.transpose(1, 2)
+                w2_data = torch.nn.Parameter(w2_data, requires_grad=False)
+                setattr(parent_module, param_name, w2_data)
+            elif 'w13_weight' in name and param.shape[1] == hidden_size:
+                parts = name.split('.')
+                param_name = parts[-1]
+                parent_module = model.get_submodule(".".join(parts[:-1]))
+                w13_data = param.transpose(1, 2)
+                w13_data = torch.nn.Parameter(w13_data, requires_grad=False)
+                setattr(parent_module, param_name, w13_data)
+
         # Restore the buffers after level 2 sleep
         if len(self._sleep_saved_buffers):
             model = self.model_runner.model
             for name, buffer in model.named_buffers():
                 if name in self._sleep_saved_buffers:
                     buffer.data.copy_(self._sleep_saved_buffers[name].data)
```
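The parameter swap in `wake_up` relies on `nn.Module.get_submodule` plus
`setattr` to rebind a parameter on its parent module. A standalone sketch of
that pattern on a toy module (all names here are illustrative, not the
NPUWorker code):

```python
import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    # Toy stand-in for a fused-MoE layer holding an inference-layout weight.
    def __init__(self):
        super().__init__()
        # Inference layout: [num_experts, intermediate, hidden].
        self.w2_weight = nn.Parameter(torch.randn(2, 16, 8),
                                      requires_grad=False)

model = nn.Sequential(ToyMoE())
hidden_size = 8

for name, param in model.named_parameters():
    if 'w2_weight' in name and param.shape[2] == hidden_size:
        parts = name.split('.')
        parent = model.get_submodule('.'.join(parts[:-1]))
        # Rebind the attribute so the module now holds the [E, H, I] tensor.
        setattr(parent, parts[-1],
                nn.Parameter(param.transpose(1, 2), requires_grad=False))

print(model[0].w2_weight.shape)   # torch.Size([2, 8, 16])
```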