[MOE] move weight transpose to wake_up for RL scenarios (#4626)
### What this PR does / why we need it?
In reinforcement learning scenarios, the current inference flow applies a
transpose operation to the MoE weights when they are loaded. For a cleaner
architecture, this PR moves the weight-transpose step into `wake_up`: waking up
restores the checkpoint layout, so updated weights from training can be loaded
directly, and `process_weights_after_loading` then converts them back to the
inference layout.
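A minimal sketch of the resulting RL update flow (the `runmodel` access path and
`process_weights_after_loading` usage come from `examples/offline_external_launcher.py`
below; `new_sd`, `model_config`, and `device` stand in for a hypothetical updated
state dict and the values the example derives):

```python
llm.sleep(level=1)   # free NPU memory while the trainer updates weights
llm.wake_up()        # NPUWorker.wake_up() also restores MoE weights
                     # (w13_weight / w2_weight) to checkpoint layout
runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
runmodel.load_weights(new_sd.items())  # checkpoint-layout weights copy in cleanly
process_weights_after_loading(runmodel, model_config, device)  # back to inference layout
```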
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: lhp-deep <liuhaopeng1@huawei.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
.github/workflows/_e2e_test.yaml (vendored): +1
@@ -205,6 +205,7 @@ jobs:
           pytest -sv tests/e2e/multicard/test_pipeline_parallel.py
           pytest -sv tests/e2e/multicard/test_prefix_caching.py
           pytest -sv tests/e2e/multicard/test_qwen3_moe.py
+          pytest -sv tests/e2e/multicard/test_offline_weight_load.py

   e2e-4-cards:
     name: multicard-4
examples/offline_external_launcher.py

@@ -70,6 +70,9 @@ from safetensors.torch import load_file
 from vllm.utils.mem_constants import GiB_bytes
 from vllm.utils.network_utils import get_open_port

+from vllm.model_executor.model_loader.utils import \
+    process_weights_after_loading
+
 os.environ["VLLM_USE_MODELSCOPE"] = "True"
 os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@@ -219,15 +222,6 @@ def main(
         gpu_memory_utilization = 0.95,
         enable_sleep_mode=enable_sleep_mode,
     )
-    model_path = model
-    runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
-    patch_vllm_moe_model_weight_loader(runmodel)
-    sd = load_and_merge_safetensors(model_path)
-    runmodel.load_weights(sd.items())
-    print('load state dict done')
-    tp_ranks = get_tp_group().ranks
-    print(f'TP RANKS: {tp_ranks}')
-
     outputs = llm.generate(prompts, sampling_params)

     if enable_sleep_mode:
@@ -242,6 +236,20 @@ def main(
             assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes

         llm.wake_up()
+
+        model_path = model
+        runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
+        patch_vllm_moe_model_weight_loader(runmodel)
+        sd = load_and_merge_safetensors(model_path)
+        runmodel.load_weights(sd.items())
+        print('load state dict done')
+        tp_ranks = get_tp_group().ranks
+        print(f'TP RANKS: {tp_ranks}')
+
+        vllm_config = llm.llm_engine.vllm_config.model_config
+        device = next(runmodel.parameters()).device
+        process_weights_after_loading(runmodel, vllm_config, device)
+
         outputs_after_wakeup = llm.generate(prompts, sampling_params)
         if rank == 0:
             # cmp output
tests/e2e/multicard/test_offline_weight_load.py (new file): +74
@@ -0,0 +1,74 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Run `pytest tests/e2e/multicard/test_offline_weight_load.py`.
+"""
+
+import os
+import subprocess
+import sys
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+
+MODELS = ["Qwen/Qwen3-30B-A3B"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
+def test_offline_weight_load_and_sleepmode(model):
+    script = Path(
+        __file__
+    ).parent.parent.parent.parent / "examples" / "offline_external_launcher.py"
+    env = os.environ.copy()
+    cmd = [
+        sys.executable,
+        str(script),
+        "--model",
+        model,
+        "--tp-size",
+        "2",
+        "--node-size",
+        "1",
+        "--node-rank",
+        "0",
+        "--proc-per-node",
+        "2",
+        "--trust-remote-code",
+        "--enable-sleep-mode",
+        "--temperature",
+        "0",
+        "--model-weight-gib",
+        "0.8",
+    ]
+
+    print(f"Running subprocess: {' '.join(cmd)}")
+    proc = subprocess.run(
+        cmd,
+        env=env,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        timeout=600,
+    )
+    output = proc.stdout.decode(errors='ignore')
+
+    print(output)
+
+    assert "Generated text:" in output
+    assert "Sleep and wake up successfully!!" in output
+    assert proc.returncode == 0
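The test drives the example end to end across two cards; locally it can be run
with the same invocation the workflow change above adds:

```
pytest -sv tests/e2e/multicard/test_offline_weight_load.py
```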
@@ -25,8 +25,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
-from vllm_ascend.ops.fused_moe.fused_moe import (
-    AscendFusedMoE, AscendUnquantizedFusedMoEMethod)
+from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
 from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
                                                unified_apply_mlp)
 from vllm_ascend.utils import AscendDeviceType, adapt_patch

@@ -595,39 +594,3 @@ class TestUnifiedApplyMLP(TestBase):
         self.assertTrue(mock_forward_context.with_quant)
         self.assertEqual(result.shape, hidden_states_shape)
         self.assertEqual(result.dtype, torch.bfloat16)
-
-
-class TestLoadWeight(TestBase):
-
-    def test_load_w13_transpose(self):
-        with patch.object(AscendFusedMoE, "__init__",
-                          lambda self, *args, **kwargs: None):
-            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
-
-            expert_data = torch.randn(128, 8)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
-
-            expert_data = torch.randn(8, 128)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
-
-            expert_data = torch.randn(128, 8)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
-
-            expert_data = torch.randn(8, 128)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
-
-    def test_load_w2_transpose(self):
-        with patch.object(AscendFusedMoE, "__init__",
-                          lambda self, *args, **kwargs: None):
-            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
-            expert_data = torch.randn(128, 4)
-            loaded_weight = torch.randn(128, 8)
-            moe._load_w2(expert_data, 1, loaded_weight, 0)
-
-            expert_data = torch.randn(4, 128)
-            loaded_weight = torch.randn(128, 8)
-            moe._load_w2(expert_data, 1, loaded_weight, 0)
@@ -281,9 +281,22 @@ class TestNPUWorker(TestBase):
         mock_allocator = MagicMock()
         mock_allocator_class.get_instance.return_value = mock_allocator

+        mock_hidden_size = MagicMock()
+        mock_hf_config = MagicMock()
+        mock_hf_config.hidden_size = mock_hidden_size
+        mock_model_config = MagicMock()
+        mock_model_config.hf_config = mock_hf_config
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config = mock_model_config
+
+        mock_model_runner = MagicMock()
+        mock_model_runner.model = MagicMock()
+
         # Create worker mock
         with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
             worker = NPUWorker()
+            worker.model_runner = mock_model_runner
+            worker.vllm_config = mock_vllm_config
             worker._sleep_saved_buffers = {}
             # Test wake_up method
             worker.wake_up(tags=["test_tag"])
vllm_ascend/ops/fused_moe/fused_moe.py

@@ -56,29 +56,18 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

         super().__init__(moe=moe)
         self.dynamic_eplb = get_ascend_config().dynamic_eplb
-        self.transpose = True

     def process_weights_after_loading(self, layer):
         super(UnquantizedFusedMoEMethod,
               self).process_weights_after_loading(layer)
-        if self.transpose:
-            w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
-                1, 2).contiguous()
-            layer.w13_weight = torch.nn.Parameter(w13_data,
-                                                  requires_grad=False)
-
-            w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
-                1, 2).contiguous()
-            layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
-
-            self.transpose = False
-        else:
-            w13_data = self._maybe_pad_weight(layer.w13_weight.data)
-            layer.w13_weight = torch.nn.Parameter(w13_data,
-                                                  requires_grad=False)
-
-            w2_data = self._maybe_pad_weight(layer.w2_weight.data)
-            layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
+        w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
+
+        w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

         if get_ascend_device_type() != AscendDeviceType._310P and is_enable_nz(
         ):
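With `wake_up` (see the worker diff below) returning weights to checkpoint
layout before any reload, the `self.transpose` flag becomes unnecessary and the
transpose here is unconditional. A toy illustration of the two layouts, with
arbitrary sizes:

```python
import torch

num_experts, hidden, intermediate = 4, 8, 16
# checkpoint / training layout for gate_up (w13): (E, 2 * intermediate, hidden)
w13 = torch.randn(num_experts, 2 * intermediate, hidden)
# inference layout produced by process_weights_after_loading: (E, hidden, 2 * intermediate)
w13_inference = w13.transpose(1, 2).contiguous()
assert w13_inference.shape == (num_experts, hidden, 2 * intermediate)
```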
@@ -389,61 +378,6 @@ class AscendFusedMoE(FusedMoE):

         return final_hidden_states

-    def transpose_weight(self, loaded_weight, expert_data, shard_dim):
-        # Ensure training and inference weight shapes match during RL weight updates
-        if (len(loaded_weight.shape) >= 2 and len(expert_data.shape) >= 2 and \
-                loaded_weight.shape[1] != expert_data.shape[1] and \
-                loaded_weight.shape[0] != expert_data.shape[0]
-            ):
-            shard_dim = int(not shard_dim)
-            loaded_weight = loaded_weight.transpose(0, 1).contiguous()
-        return loaded_weight, shard_dim
-
-    def _load_w13(self,
-                  expert_data: torch.Tensor,
-                  shard_dim: int,
-                  shard_id: str,
-                  loaded_weight: torch.Tensor,
-                  tp_rank: int,
-                  load_full: bool = False):
-        # Index the loaded weight for tp sharding.
-        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
-        loaded_weight, shard_dim = self.transpose_weight(
-            loaded_weight, expert_data, shard_dim)
-        shard_size = expert_data.shape[shard_dim] // 2
-        if not load_full:
-            loaded_weight = loaded_weight.narrow(shard_dim,
-                                                 shard_size * tp_rank,
-                                                 shard_size)
-        # Narrow parameter and load.
-        # w1, gate_proj: Load into first logical weight of w13.
-        if shard_id == "w1":
-            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
-        # w3, up_proj: Load into second logical weight of w13.
-        else:
-            assert shard_id == "w3"
-            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
-        expert_data.copy_(loaded_weight)
-
-    def _load_w2(self,
-                 expert_data: torch.Tensor,
-                 shard_dim: int,
-                 loaded_weight: torch.Tensor,
-                 tp_rank: int,
-                 load_full: bool = False):
-        # Index the loaded weight for tp sharding.
-        # down_proj: "RowParallel" so tp sharding on input_dim
-        # Narrow parameter and load.
-        loaded_weight, shard_dim = self.transpose_weight(
-            loaded_weight, expert_data, shard_dim)
-        shard_size = expert_data.shape[shard_dim]
-        if not load_full:
-            loaded_weight = loaded_weight.narrow(shard_dim,
-                                                 shard_size * tp_rank,
-                                                 shard_size)
-        # w2, down_proj: Load into only logical weight of w2.
-        expert_data.copy_(loaded_weight)
-

 class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
@@ -176,9 +176,28 @@ class NPUWorker(WorkerBase):
         allocator = CaMemAllocator.get_instance()
         allocator.wake_up(tags=tags)

+        hidden_size = self.vllm_config.model_config.hf_config.hidden_size
+        model = self.model_runner.model
+        for name, param in model.named_parameters():
+            if 'w2_weight' in name and param.shape[2] == hidden_size:
+                parts = name.split('.')
+                param_name = parts[-1]
+                parent_module = model.get_submodule(".".join(parts[:-1]))
+
+                w2_data = param.transpose(1, 2)
+                w2_data = torch.nn.Parameter(w2_data, requires_grad=False)
+                setattr(parent_module, param_name, w2_data)
+            elif 'w13_weight' in name and param.shape[1] == hidden_size:
+                parts = name.split('.')
+                param_name = parts[-1]
+                parent_module = model.get_submodule(".".join(parts[:-1]))
+
+                w13_data = param.transpose(1, 2)
+                w13_data = torch.nn.Parameter(w13_data, requires_grad=False)
+                setattr(parent_module, param_name, w13_data)
+
         # Restore the buffers after level 2 sleep
         if len(self._sleep_saved_buffers):
-            model = self.model_runner.model
             for name, buffer in model.named_buffers():
                 if name in self._sleep_saved_buffers:
                     buffer.data.copy_(self._sleep_saved_buffers[name].data)
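A quick sanity check of the shape test used in the loop above (toy sizes; a
`w2_weight` whose last dimension equals `hidden_size` is still in inference
layout and gets transposed back to checkpoint layout):

```python
import torch

hidden_size = 8
w2 = torch.randn(4, 16, hidden_size)   # (E, intermediate, hidden): inference layout
assert w2.shape[2] == hidden_size      # detected -> transpose back
w2_checkpoint = w2.transpose(1, 2)     # (E, hidden, intermediate): checkpoint layout
assert w2_checkpoint.shape == (4, hidden_size, 16)
```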