[MOE] move weight transpose to wake_up for RL scenarios (#4626)

### What this PR does / why we need it?
In reinforcement learning scenarios, inference currently applies a
transpose operation to the MoE weights at load time. For a cleaner
architecture, this PR moves the weight-transpose step into `wake_up`
(see the sketch below).
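
For context, the end-to-end reload path this enables looks roughly as follows. This is a condensed sketch based on `examples/offline_external_launcher.py` as modified below, not verbatim code; the helpers `patch_vllm_moe_model_weight_loader` and `load_and_merge_safetensors` come from that script.

```python
# Sketch of the RL reload flow (condensed from the example script below).
# Assumes `llm` was created with enable_sleep_mode=True.
from vllm.model_executor.model_loader.utils import process_weights_after_loading

llm.sleep(level=1)  # free weights/KV cache while the trainer runs
llm.wake_up()       # with this PR, wake_up also flips MoE weights back to
                    # the trainer layout, so reloaded weights fit directly

runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
patch_vllm_moe_model_weight_loader(runmodel)  # helper from the example script
runmodel.load_weights(load_and_merge_safetensors(model_path).items())

# Re-apply inference-side weight processing (transpose/pad/NZ) exactly once.
device = next(runmodel.parameters()).device
process_weights_after_loading(runmodel,
                              llm.llm_engine.vllm_config.model_config, device)

outputs = llm.generate(prompts, sampling_params)
```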

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: lhp-deep <liuhaopeng1@huawei.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
Author: lhp-deep
Committed via GitHub on 2025-12-08 20:34:52 +08:00
parent 58db21f56a · commit b230e7e987
7 changed files with 132 additions and 120 deletions


@@ -205,6 +205,7 @@ jobs:
           pytest -sv tests/e2e/multicard/test_pipeline_parallel.py
           pytest -sv tests/e2e/multicard/test_prefix_caching.py
           pytest -sv tests/e2e/multicard/test_qwen3_moe.py
+          pytest -sv tests/e2e/multicard/test_offline_weight_load.py
   e2e-4-cards:
     name: multicard-4


@@ -70,6 +70,9 @@ from safetensors.torch import load_file
 from vllm.utils.mem_constants import GiB_bytes
 from vllm.utils.network_utils import get_open_port
+from vllm.model_executor.model_loader.utils import \
+    process_weights_after_loading

 os.environ["VLLM_USE_MODELSCOPE"] = "True"
 os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@@ -219,15 +222,6 @@ def main(
         gpu_memory_utilization = 0.95,
         enable_sleep_mode=enable_sleep_mode,
     )
-    model_path = model
-    runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
-    patch_vllm_moe_model_weight_loader(runmodel)
-    sd = load_and_merge_safetensors(model_path)
-    runmodel.load_weights(sd.items())
-    print('load state dict done')
-    tp_ranks = get_tp_group().ranks
-    print(f'TP RANKS: {tp_ranks}')
     outputs = llm.generate(prompts, sampling_params)

     if enable_sleep_mode:
@@ -242,6 +236,20 @@ def main(
             assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
         llm.wake_up()
+        model_path = model
+        runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
+        patch_vllm_moe_model_weight_loader(runmodel)
+        sd = load_and_merge_safetensors(model_path)
+        runmodel.load_weights(sd.items())
+        print('load state dict done')
+        tp_ranks = get_tp_group().ranks
+        print(f'TP RANKS: {tp_ranks}')
+        vllm_config = llm.llm_engine.vllm_config.model_config
+        device = next(runmodel.parameters()).device
+        process_weights_after_loading(runmodel, vllm_config, device)
         outputs_after_wakeup = llm.generate(prompts, sampling_params)
         if rank == 0:
             # cmp output


@@ -0,0 +1,74 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Run `pytest tests/multicard/test_offline_load_weight.py`.
+"""
+import os
+import subprocess
+import sys
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+
+MODELS = ["Qwen/Qwen3-30B-A3B"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
+def test_offline_weight_load_and_sleepmode(model):
+    script = Path(
+        __file__
+    ).parent.parent.parent.parent / "examples" / "offline_external_launcher.py"
+    env = os.environ.copy()
+    cmd = [
+        sys.executable,
+        str(script),
+        "--model",
+        model,
+        "--tp-size",
+        "2",
+        "--node-size",
+        "1",
+        "--node-rank",
+        "0",
+        "--proc-per-node",
+        "2",
+        "--trust-remote-code",
+        "--enable-sleep-mode",
+        "--temperature",
+        "0",
+        "--model-weight-gib",
+        "0.8",
+    ]
+    print(f"Running subprocess: {' '.join(cmd)}")
+    proc = subprocess.run(
+        cmd,
+        env=env,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        timeout=600,
+    )
+    output = proc.stdout.decode(errors='ignore')
+    print(output)
+    assert "Generated text:" in output
+    assert "Sleep and wake up successfully!!" in output
+    assert proc.returncode == 0


@@ -25,8 +25,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
-from vllm_ascend.ops.fused_moe.fused_moe import (
-    AscendFusedMoE, AscendUnquantizedFusedMoEMethod)
+from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
 from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
                                                unified_apply_mlp)
 from vllm_ascend.utils import AscendDeviceType, adapt_patch
@@ -595,39 +594,3 @@ class TestUnifiedApplyMLP(TestBase):
         self.assertTrue(mock_forward_context.with_quant)
         self.assertEqual(result.shape, hidden_states_shape)
         self.assertEqual(result.dtype, torch.bfloat16)
-
-
-class TestLoadWeight(TestBase):
-
-    def test_load_w13_transpose(self):
-        with patch.object(AscendFusedMoE, "__init__",
-                          lambda self, *args, **kwargs: None):
-            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
-
-            expert_data = torch.randn(128, 8)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
-
-            expert_data = torch.randn(8, 128)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
-
-            expert_data = torch.randn(128, 8)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
-
-            expert_data = torch.randn(8, 128)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
-
-    def test_load_w2_transpose(self):
-        with patch.object(AscendFusedMoE, "__init__",
-                          lambda self, *args, **kwargs: None):
-            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
-
-            expert_data = torch.randn(128, 4)
-            loaded_weight = torch.randn(128, 8)
-            moe._load_w2(expert_data, 1, loaded_weight, 0)
-
-            expert_data = torch.randn(4, 128)
-            loaded_weight = torch.randn(128, 8)
-            moe._load_w2(expert_data, 1, loaded_weight, 0)


@@ -281,9 +281,22 @@ class TestNPUWorker(TestBase):
         mock_allocator = MagicMock()
         mock_allocator_class.get_instance.return_value = mock_allocator
+        mock_hidden_size = MagicMock()
+        mock_hf_config = MagicMock()
+        mock_hf_config.hidden_size = mock_hidden_size
+        mock_model_config = MagicMock()
+        mock_model_config.hf_config = mock_hf_config
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config = mock_model_config
+        mock_model_runner = MagicMock()
+        mock_model_runner.model = MagicMock()
+
         # Create worker mock
         with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
             worker = NPUWorker()
+            worker.model_runner = mock_model_runner
+            worker.vllm_config = mock_vllm_config
             worker._sleep_saved_buffers = {}

             # Test wake_up method
             worker.wake_up(tags=["test_tag"])


@@ -56,29 +56,18 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         super().__init__(moe=moe)
         self.dynamic_eplb = get_ascend_config().dynamic_eplb
-        self.transpose = True

     def process_weights_after_loading(self, layer):
         super(UnquantizedFusedMoEMethod,
               self).process_weights_after_loading(layer)
-        if self.transpose:
-            w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
-                1, 2).contiguous()
-            layer.w13_weight = torch.nn.Parameter(w13_data,
-                                                  requires_grad=False)
-            w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
-                1, 2).contiguous()
-            layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
-            self.transpose = False
-        else:
-            w13_data = self._maybe_pad_weight(layer.w13_weight.data)
-            layer.w13_weight = torch.nn.Parameter(w13_data,
-                                                  requires_grad=False)
-            w2_data = self._maybe_pad_weight(layer.w2_weight.data)
-            layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
+        w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
+
+        w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

         if get_ascend_device_type() != AscendDeviceType._310P and is_enable_nz(
         ):
@@ -389,61 +378,6 @@ class AscendFusedMoE(FusedMoE):
         return final_hidden_states

-    def transpose_weight(self, loaded_weight, expert_data, shard_dim):
-        # Ensure training and inference weight shapes match during RL weight updates
-        if (len(loaded_weight.shape) >= 2 and len(expert_data.shape) >= 2 and \
-                loaded_weight.shape[1] != expert_data.shape[1] and \
-                loaded_weight.shape[0] != expert_data.shape[0]
-        ):
-            shard_dim = int(not shard_dim)
-            loaded_weight = loaded_weight.transpose(0, 1).contiguous()
-        return loaded_weight, shard_dim
-
-    def _load_w13(self,
-                  expert_data: torch.Tensor,
-                  shard_dim: int,
-                  shard_id: str,
-                  loaded_weight: torch.Tensor,
-                  tp_rank: int,
-                  load_full: bool = False):
-        # Index the loaded weight for tp sharding.
-        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
-        loaded_weight, shard_dim = self.transpose_weight(
-            loaded_weight, expert_data, shard_dim)
-        shard_size = expert_data.shape[shard_dim] // 2
-        if not load_full:
-            loaded_weight = loaded_weight.narrow(shard_dim,
-                                                 shard_size * tp_rank,
-                                                 shard_size)
-
-        # Narrow parameter and load.
-        # w1, gate_proj: Load into first logical weight of w13.
-        if shard_id == "w1":
-            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
-        # w3, up_proj: Load into second logical weight of w13.
-        else:
-            assert shard_id == "w3"
-            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
-        expert_data.copy_(loaded_weight)
-
-    def _load_w2(self,
-                 expert_data: torch.Tensor,
-                 shard_dim: int,
-                 loaded_weight: torch.Tensor,
-                 tp_rank: int,
-                 load_full: bool = False):
-        # Index the loaded weight for tp sharding.
-        # down_proj: "RowParallel" so tp sharding on input_dim
-        # Narrow parameter and load.
-        loaded_weight, shard_dim = self.transpose_weight(
-            loaded_weight, expert_data, shard_dim)
-        shard_size = expert_data.shape[shard_dim]
-        if not load_full:
-            loaded_weight = loaded_weight.narrow(shard_dim,
-                                                 shard_size * tp_rank,
-                                                 shard_size)
-        # w2, down_proj: Load into only logical weight of w2.
-        expert_data.copy_(loaded_weight)

 class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
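
A quick shape check on why the one-shot `self.transpose` flag could be dropped: `process_weights_after_loading` can now transpose unconditionally because `wake_up` restores the trainer layout before any RL reload. Toy sizes below are illustrative only.

```python
import torch

E, I, H = 4, 64, 8                     # experts, intermediate, hidden (toy sizes)
w13 = torch.randn(E, 2 * I, H)         # trainer / checkpoint layout

# process_weights_after_loading transposes unconditionally:
w13_inf = w13.transpose(1, 2).contiguous()      # -> (E, H, 2I), inference layout
assert w13_inf.shape == (E, H, 2 * I)

# Previously, a second call after an RL weight reload had to skip the
# transpose (the old self.transpose flag). Now wake_up() restores the
# trainer layout first, so the unconditional transpose stays correct:
w13_restored = w13_inf.transpose(1, 2).contiguous()  # what wake_up does
assert w13_restored.shape == (E, 2 * I, H)
```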


@@ -176,9 +176,28 @@ class NPUWorker(WorkerBase):
         allocator = CaMemAllocator.get_instance()
         allocator.wake_up(tags=tags)

+        hidden_size = self.vllm_config.model_config.hf_config.hidden_size
+        model = self.model_runner.model
+        for name, param in model.named_parameters():
+            if 'w2_weight' in name and param.shape[2] == hidden_size:
+                parts = name.split('.')
+                param_name = parts[-1]
+                parent_module = model.get_submodule(".".join(parts[:-1]))
+                w2_data = param.transpose(1, 2)
+                w2_data = torch.nn.Parameter(w2_data, requires_grad=False)
+                setattr(parent_module, param_name, w2_data)
+            elif 'w13_weight' in name and param.shape[1] == hidden_size:
+                parts = name.split('.')
+                param_name = parts[-1]
+                parent_module = model.get_submodule(".".join(parts[:-1]))
+                w13_data = param.transpose(1, 2)
+                w13_data = torch.nn.Parameter(w13_data, requires_grad=False)
+                setattr(parent_module, param_name, w13_data)
+
         # Restore the buffers after level 2 sleep
         if len(self._sleep_saved_buffers):
             model = self.model_runner.model
             for name, buffer in model.named_buffers():
                 if name in self._sleep_saved_buffers:
                     buffer.data.copy_(self._sleep_saved_buffers[name].data)
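
A self-contained sketch of the layout restore that `wake_up` performs above. The toy module, names, and sizes are illustrative; the real loop walks the actual model's named parameters.

```python
import torch
from torch import nn

E, I, H = 2, 16, 8  # toy expert / intermediate / hidden sizes

class ToyMoE(nn.Module):
    """Mimics a fused-MoE layer whose weights are in the inference layout."""

    def __init__(self):
        super().__init__()
        self.w13_weight = nn.Parameter(torch.randn(E, H, 2 * I),
                                       requires_grad=False)
        self.w2_weight = nn.Parameter(torch.randn(E, I, H),
                                      requires_grad=False)

model = nn.Sequential(ToyMoE())  # yields dotted names like "0.w13_weight"
hidden_size = H

for name, param in list(model.named_parameters()):
    # Same checks as wake_up: hidden_size on dim 1 (w13) or dim 2 (w2)
    # means the weight is still in the inference layout.
    if (('w13_weight' in name and param.shape[1] == hidden_size)
            or ('w2_weight' in name and param.shape[2] == hidden_size)):
        parts = name.split('.')
        parent = model.get_submodule('.'.join(parts[:-1]))
        setattr(parent, parts[-1],
                nn.Parameter(param.transpose(1, 2), requires_grad=False))

assert model[0].w13_weight.shape == (E, 2 * I, H)  # trainer layout restored
assert model[0].w2_weight.shape == (E, H, I)
```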