[feature] Rework Ascend NPU graph support (#9350)

Co-authored-by: ronnie_zheng <zl19940307@163.com>
Co-authored-by: yezhifeng (D) <y00897525@china.huawei.com>
Co-authored-by: anon189Ty <Stari_Falcon@outlook.com>
Co-authored-by: Maksim <makcum888e@mail.ru>
Co-authored-by: ssshinigami <44640852+ssshinigami@users.noreply.github.com>
This commit is contained in:
Even Zhou
2025-08-20 11:32:27 +08:00
committed by GitHub
parent f515449582
commit 3680d6f88b
18 changed files with 546 additions and 81 deletions

View File

@@ -0,0 +1,36 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with cuda graph and torch.compile."""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from sglang.srt.model_executor.graph_runner import GraphRunner
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
class CudaGraphRunner(GraphRunner):
    """CUDA flavour of ``GraphRunner``: forward passes are captured into and replayed from CUDA graphs."""

    def __init__(self, model_runner: ModelRunner):
        # Nothing CUDA-specific to configure here; initialization is
        # delegated entirely to GraphRunner.__init__.
        super().__init__(model_runner)

    def _create_device_graph(self):
        """Return a new, empty ``torch.cuda.CUDAGraph`` for the base class to capture into."""
        cuda_graph = torch.cuda.CUDAGraph()
        return cuda_graph

View File

@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with cuda graph and torch.compile."""
"""Run the model with device graph and torch.compile."""
from __future__ import annotations
@@ -221,7 +221,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
return capture_bs, compile_bs
# Reuse this memory pool across all cuda graph runners.
# Reuse this memory pool across all device graph runners.
global_graph_memory_pool = None
@@ -234,12 +234,14 @@ def set_global_graph_memory_pool(val):
global_graph_memory_pool = val
class CudaGraphRunner:
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
class GraphRunner:
"""A GraphRunner is a base class to run the forward pass of a model with device graph and torch.compile."""
def __init__(self, model_runner: ModelRunner):
# Parse args
self.model_runner = model_runner
self.device = model_runner.device
self.device_module = torch.get_device_module(self.device)
self.graphs = {}
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
@@ -265,7 +267,7 @@ class CudaGraphRunner:
# Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
rank0_log(f"Capture cuda graph bs {self.capture_bs}")
rank0_log(f"Capture graph bs {self.capture_bs}")
self.capture_forward_mode = ForwardMode.DECODE
self.capture_hidden_mode = CaptureHiddenMode.NULL
self.num_tokens_per_bs = 1
@@ -305,13 +307,15 @@ class CudaGraphRunner:
self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
# Graph inputs
with torch.device("cuda"):
with torch.device(self.device):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
self.seq_lens = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.out_cache_loc = torch.zeros(
(self.max_num_token,), dtype=self._cache_loc_dtype()
)
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
@@ -366,12 +370,12 @@ class CudaGraphRunner:
* self.num_tokens_per_bs
),
dtype=torch.bool,
device="cuda",
device=self.device,
)
self.next_token_logits_buffer = torch.zeros(
(self.max_num_token, self.model_runner.model_config.vocab_size),
dtype=torch.float,
device="cuda",
device=self.device,
)
# Capture
@@ -380,9 +384,12 @@ class CudaGraphRunner:
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
f"Capture device graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
)
def _cache_loc_dtype(self):
return torch.int64
def can_run(self, forward_batch: ForwardBatch):
if self.require_mlp_tp_gather:
cuda_graph_bs = (
@@ -502,8 +509,16 @@ class CudaGraphRunner:
)
logger.info(log_message)
def _capture_graph(self, graph, pool, stream, run_once_fn):
with self.device_module.graph(graph, pool=pool, stream=stream):
out = run_once_fn()
return out
def _create_device_graph(self):
pass
def capture_one_batch_size(self, bs: int, forward: Callable):
graph = torch.cuda.CUDAGraph()
graph = self._create_device_graph()
stream = self.stream
num_tokens = bs * self.num_tokens_per_bs
@@ -643,19 +658,17 @@ class CudaGraphRunner:
return logits_output_or_pp_proxy_tensors
for _ in range(2):
torch.cuda.synchronize()
self.device_module.synchronize()
self.model_runner.tp_group.barrier()
run_once()
if get_global_graph_memory_pool() is None:
set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
set_global_graph_memory_pool(self.device_module.graph_pool_handle())
# Set graph pool id globally to be able to use symmetric memory
set_graph_pool_id(get_global_graph_memory_pool())
with torch.cuda.graph(
graph, pool=get_global_graph_memory_pool(), stream=stream
):
out = run_once()
out = self._capture_graph(
graph, get_global_graph_memory_pool(), stream, run_once
)
return graph, out
@@ -837,7 +850,7 @@ class CudaGraphRunner:
return spec_info
CUDA_GRAPH_CAPTURE_FAILED_MSG = (
GRAPH_CAPTURE_FAILED_MSG = (
"Possible solutions:\n"
"1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
"2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"

View File

@@ -89,8 +89,11 @@ from sglang.srt.mem_cache.memory_pool import (
ReqToTokenPool,
SWAKVPool,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
# TODO(iforgetmyname): Renaming on the way
from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.utils import set_default_torch_dtype
@@ -341,9 +344,12 @@ class ModelRunner:
if self.device == "cuda":
self.init_cublas()
self.init_attention_backend()
self.init_cuda_graphs()
self.init_device_graphs()
elif self.device == "npu":
self.init_attention_backend()
self.init_device_graphs()
else:
self.cuda_graph_runner = None
self.graph_runner = None
self.cuda_graph_mem_usage = 0
self.init_attention_backend()
@@ -917,7 +923,8 @@ class ModelRunner:
)
# We need to get device after patch otherwise the device would be wrong
infered_device = torch.cuda.current_device()
self.device_module = torch.get_device_module(self.device)
infered_device = self.device_module.current_device()
named_tensors = [
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
@@ -1585,9 +1592,9 @@ class ModelRunner:
.cuda()
)
def init_cuda_graphs(self):
def init_device_graphs(self):
"""Capture cuda graphs."""
self.cuda_graph_runner = None
self.graph_runner = None
self.cuda_graph_mem_usage = 0
if not self.is_generation:
@@ -1602,8 +1609,9 @@ class ModelRunner:
logger.info(
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
)
self.cuda_graph_runner = CudaGraphRunner(self)
self.graph_runner = (
CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
self.cuda_graph_mem_usage = before_mem - after_mem
logger.info(
@@ -1755,11 +1763,11 @@ class ModelRunner:
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
can_run_cuda_graph = bool(
forward_batch.forward_mode.is_cuda_graph()
and self.cuda_graph_runner
and self.cuda_graph_runner.can_run(forward_batch)
and self.graph_runner
and self.graph_runner.can_run(forward_batch)
)
if can_run_cuda_graph:
ret = self.cuda_graph_runner.replay(
ret = self.graph_runner.replay(
forward_batch,
skip_attn_backend_init=skip_attn_backend_init,
pp_proxy_tensors=pp_proxy_tensors,

View File

@@ -0,0 +1,94 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with npu graph and torch.compile."""
from __future__ import annotations
import logging
import threading
from typing import TYPE_CHECKING
import torch
from sglang.srt.model_executor.graph_runner import GraphRunner
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
class NPUGraphRunner(GraphRunner):
    """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile."""

    def __init__(self, model_runner: ModelRunner):
        # No NPU-specific state; initialization is delegated to GraphRunner.
        super().__init__(model_runner)

    def _create_device_graph(self):
        """Return a new, empty ``torch.npu.NPUGraph`` to capture into."""
        return torch.npu.NPUGraph()

    def _capture_graph(self, graph, pool, stream, run_once_fn):
        """Capture ``run_once_fn`` into ``graph`` using ``torch.npu.graph``.

        Overrides the base hook so that ``auto_dispatch_capture=True`` is
        passed to the NPU capture context.

        Returns:
            The value returned by ``run_once_fn`` during capture.
        """
        with torch.npu.graph(
            graph,
            pool=pool,
            stream=stream,
            auto_dispatch_capture=True,
        ):
            out = run_once_fn()
        return out

    def _update_inputs(self, seq_lens):
        """Push ``seq_lens`` into the graph for the current batch size.

        Calls ``update`` on ``self.graphs[self.bs]`` with ``seq_lens`` supplied
        as the ``actual_seq_lengths_kv`` CPU-side input.
        """
        self.graphs[self.bs].update(
            cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}]
        )

    def _cache_loc_dtype(self):
        """Return ``torch.int32`` as the dtype of the ``out_cache_loc`` buffer on NPU."""
        return torch.int32

    def replay(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: PPProxyTensors | None = None,
    ) -> LogitsProcessorOutput | PPProxyTensors:
        """Replay the captured graph for ``forward_batch`` and return its output.

        Args:
            forward_batch: The batch whose inputs are copied into the graph buffers.
            skip_attn_backend_init: When True, ``replay_prepare`` is skipped and
                only ``input_ids`` and ``positions`` are refreshed.
            pp_proxy_tensors: Forwarded to ``replay_prepare`` when it is called.

        Returns:
            A ``LogitsProcessorOutput`` sliced to ``self.raw_num_token`` rows, or
            a ``PPProxyTensors`` whose tensors are sliced to ``self.bs`` rows.
        """
        # These names are imported at module level only under TYPE_CHECKING,
        # but they are needed at runtime below (isinstance / construction).
        # Import them here so the first replay does not raise NameError.
        from sglang.srt.layers.logits_processor import LogitsProcessorOutput
        from sglang.srt.model_executor.forward_batch_info import PPProxyTensors

        if not skip_attn_backend_init:
            self.replay_prepare(forward_batch, pp_proxy_tensors)
        else:
            # In speculative decoding, these two fields are still needed.
            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
            self.positions[: self.raw_num_token].copy_(forward_batch.positions)

        # Replay
        # Pad the host-side sequence lengths with zeros up to the captured
        # batch size so the list has one entry per graph slot.
        seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
        # The input update runs on a helper thread while the graph is replayed.
        thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
        thread.start()
        try:
            self.graphs[self.bs].replay()
        finally:
            # Always join, even if replay raises, so the helper thread is not
            # left running against state the caller may be tearing down.
            thread.join()

        output = self.output_buffers[self.bs]
        if isinstance(output, LogitsProcessorOutput):
            return LogitsProcessorOutput(
                next_token_logits=output.next_token_logits[: self.raw_num_token],
                hidden_states=(
                    output.hidden_states[: self.raw_num_token]
                    if output.hidden_states is not None
                    else None
                ),
            )
        else:
            assert isinstance(output, PPProxyTensors)
            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})