[MOE] move weight transpose to wakeup for RL scenarios (#4626)

### What this PR does / why we need it?
In reinforcement learning scenarios, the inference path currently applies a
transpose operation to the weights. For a cleaner architecture, this PR moves
the weight transpose step into the wake_up path, so it runs as part of
post-load weight processing after `wake_up()` rather than inside inference.
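A minimal sketch of the flow this enables in an RL loop. The attribute paths and `process_weights_after_loading` call mirror the test diff below; the model path, checkpoint path, and sleep level are illustrative assumptions, not part of this PR:

```python
# Hedged sketch: refresh weights after wake_up in an RL rollout loop.
# Attribute paths follow the test diff below; paths are hypothetical.
from safetensors.torch import load_file
from vllm import LLM, SamplingParams
from vllm.model_executor.model_loader.utils import \
    process_weights_after_loading

llm = LLM(model="path/to/moe-model", enable_sleep_mode=True)
llm.sleep(level=1)   # free device memory while the trainer updates weights
# ... training step produces an updated checkpoint ...
llm.wake_up()        # re-allocate buffers; weights must be (re)loaded

runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
sd = load_file("updated/model.safetensors")  # hypothetical checkpoint path
runmodel.load_weights(sd.items())
# With this PR, the MoE weight transpose happens here, as part of
# post-load processing after wake_up, instead of inside inference:
process_weights_after_loading(
    runmodel,
    llm.llm_engine.vllm_config.model_config,
    next(runmodel.parameters()).device,
)
outputs = llm.generate(["hello"], SamplingParams(max_tokens=8))
```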

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: lhp-deep <liuhaopeng1@huawei.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
Author: lhp-deep
Date: 2025-12-08 20:34:52 +08:00
Committed by: GitHub
Parent: 58db21f56a
Commit: b230e7e987
7 changed files with 132 additions and 120 deletions


@@ -70,6 +70,9 @@ from safetensors.torch import load_file
 from vllm.utils.mem_constants import GiB_bytes
 from vllm.utils.network_utils import get_open_port
+from vllm.model_executor.model_loader.utils import \
+    process_weights_after_loading

 os.environ["VLLM_USE_MODELSCOPE"] = "True"
 os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@@ -219,15 +222,6 @@ def main(
         gpu_memory_utilization = 0.95,
         enable_sleep_mode=enable_sleep_mode,
     )
-    model_path = model
-    runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
-    patch_vllm_moe_model_weight_loader(runmodel)
-    sd = load_and_merge_safetensors(model_path)
-    runmodel.load_weights(sd.items())
-    print('load state dict done')
-    tp_ranks = get_tp_group().ranks
-    print(f'TP RANKS: {tp_ranks}')
     outputs = llm.generate(prompts, sampling_params)
     if enable_sleep_mode:
@@ -242,6 +236,20 @@ def main(
         assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
         llm.wake_up()
+        model_path = model
+        runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
+        patch_vllm_moe_model_weight_loader(runmodel)
+        sd = load_and_merge_safetensors(model_path)
+        runmodel.load_weights(sd.items())
+        print('load state dict done')
+        tp_ranks = get_tp_group().ranks
+        print(f'TP RANKS: {tp_ranks}')
+        vllm_config = llm.llm_engine.vllm_config.model_config
+        device = next(runmodel.parameters()).device
+        process_weights_after_loading(runmodel, vllm_config, device)
         outputs_after_wakeup = llm.generate(prompts, sampling_params)
     if rank == 0:
         # cmp output
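The comparison step is truncated above; it presumably checks that generations before sleep and after wake_up agree, verifying that the relocated transpose reproduces the original weights. A minimal sketch of such a check (the actual assertion is not shown in this excerpt, so treat this as an assumption):

```python
# Hypothetical reconstruction of the truncated comparison: outputs
# generated before sleep should match those after wake_up + reload.
for before, after in zip(outputs, outputs_after_wakeup):
    assert before.outputs[0].text == after.outputs[0].text
```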