[main] adjust the position of warm_up_atb (#2823)

### What this PR does / why we need it?
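Adjust the position of warm_up_atb: the ATB matmul warm-up is moved out of `CustomQwen3MoeModel` and into `NPUWorker` as `_warm_up_atb()`, where it is called after the dummy runs and model capture and before the seed is reset.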

### Does this PR introduce _any_ user-facing change?
No. This PR does not introduce any user-facing change.

### How was this patch tested?
CI passed with the existing tests.

- vLLM version: main
- vLLM main:
b23fb78623

Signed-off-by: huangxialu <huangxialu1@huawei.com>
This commit is contained in:
huangxialu
2025-09-10 14:06:38 +08:00
committed by GitHub
parent 22b425765a
commit 88d7af62be
3 changed files with 21 additions and 8 deletions

View File

@@ -20,7 +20,6 @@
from typing import Optional, Union
import torch
import torch_npu
from torch import nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
@@ -277,11 +276,6 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
# Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache) may cause performance degradation at runtime.
x = torch.rand((2, 4), dtype=torch.float16).npu()
weight = torch.rand((2, 4), dtype=torch.float16).npu()
c = torch.rand((4, 4), dtype=torch.float32).npu()
torch_npu._npu_matmul_add_fp32(x, weight, c)
def forward(
self,

View File

@@ -250,10 +250,19 @@ class NPUWorker(WorkerBase):
self.model_runner._dummy_run(size)
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
# may cause performance degradation at runtime.
self._warm_up_atb()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
NPUPlatform.seed_everything(self.model_config.seed)
def _warm_up_atb(self):
    """Issue one tiny ATB matmul-add call on the NPU as a warm-up.

    Three small random tensors are created on the host, moved to the NPU
    with ``.npu()``, and passed to ``torch_npu._npu_matmul_add_fp32``.
    Nothing is returned; the call is made purely so that the first real
    ATB operation at runtime does not pay the warm-up cost.
    """
    # Both float16 operands share the same (2, 4) shape.
    fp16_shape = (2, 4)
    a_fp16 = torch.rand(fp16_shape, dtype=torch.float16).npu()
    b_fp16 = torch.rand(fp16_shape, dtype=torch.float16).npu()
    # The third operand is a (4, 4) float32 tensor.
    c_fp32 = torch.rand((4, 4), dtype=torch.float32).npu()
    torch_npu._npu_matmul_add_fp32(a_fp16, b_fp16, c_fp32)
def get_model(self) -> nn.Module:
    """Return whatever ``self.model_runner.get_model()`` returns.

    This is a thin delegate: the worker does not hold the model itself
    here, it forwards the request to its model runner.
    """
    runner = self.model_runner
    return runner.get_model()