[MOE] Move weight transpose to wakeup for RL scenarios (#4626)
### What this PR does / why we need it?
In reinforcement learning (RL) scenarios, the current inference path
applies a transpose operation to the weights. For a cleaner architecture,
this PR moves the weight transpose step to wake-up.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: lhp-deep <liuhaopeng1@huawei.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -70,6 +70,9 @@ from safetensors.torch import load_file
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
from vllm.model_executor.model_loader.utils import \
|
||||
process_weights_after_loading
|
||||
|
||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
@@ -219,15 +222,6 @@ def main(
|
||||
gpu_memory_utilization = 0.95,
|
||||
enable_sleep_mode=enable_sleep_mode,
|
||||
)
|
||||
model_path = model
|
||||
runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
|
||||
patch_vllm_moe_model_weight_loader(runmodel)
|
||||
sd = load_and_merge_safetensors(model_path)
|
||||
runmodel.load_weights(sd.items())
|
||||
print('load state dict done')
|
||||
tp_ranks = get_tp_group().ranks
|
||||
print(f'TP RANKS: {tp_ranks}')
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
if enable_sleep_mode:
|
||||
@@ -242,6 +236,20 @@ def main(
|
||||
assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
|
||||
|
||||
llm.wake_up()
|
||||
|
||||
model_path = model
|
||||
runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
|
||||
patch_vllm_moe_model_weight_loader(runmodel)
|
||||
sd = load_and_merge_safetensors(model_path)
|
||||
runmodel.load_weights(sd.items())
|
||||
print('load state dict done')
|
||||
tp_ranks = get_tp_group().ranks
|
||||
print(f'TP RANKS: {tp_ranks}')
|
||||
|
||||
vllm_config = llm.llm_engine.vllm_config.model_config
|
||||
device = next(runmodel.parameters()).device
|
||||
process_weights_after_loading(runmodel, vllm_config, device)
|
||||
|
||||
outputs_after_wakeup = llm.generate(prompts, sampling_params)
|
||||
if rank == 0:
|
||||
# cmp output
|
||||
|
||||
Reference in New Issue
Block a user