[feature] Rework Ascend NPU graph support (#9350)

Co-authored-by: ronnie_zheng <zl19940307@163.com>
Co-authored-by: yezhifeng (D) <y00897525@china.huawei.com>
Co-authored-by: anon189Ty <Stari_Falcon@outlook.com>
Co-authored-by: Maksim <makcum888e@mail.ru>
Co-authored-by: ssshinigami <44640852+ssshinigami@users.noreply.github.com>
This commit is contained in:
Even Zhou
2025-08-20 11:32:27 +08:00
committed by GitHub
parent f515449582
commit 3680d6f88b
18 changed files with 546 additions and 81 deletions

View File

@@ -6,20 +6,22 @@ from typing import TYPE_CHECKING, Callable
import torch
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
from sglang.srt.model_executor.cuda_graph_runner import (
CUDA_GRAPH_CAPTURE_FAILED_MSG,
CudaGraphRunner,
# TODO(iforgetmyname): Renaming on the way
from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.model_executor.graph_runner import (
GRAPH_CAPTURE_FAILED_MSG,
get_batch_sizes_to_capture,
get_global_graph_memory_pool,
model_capture_mode,
set_global_graph_memory_pool,
set_torch_compile_config,
)
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput
from sglang.srt.utils import (
require_attn_tp_gather,
@@ -121,7 +123,7 @@ class EAGLEDraftCudaGraphRunner:
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
)
def can_run(self, forward_batch: ForwardBatch):

View File

@@ -6,9 +6,16 @@ from typing import TYPE_CHECKING, Callable
import torch
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
from sglang.srt.model_executor.cuda_graph_runner import (
CUDA_GRAPH_CAPTURE_FAILED_MSG,
CudaGraphRunner,
# TODO(iforgetmyname): Renaming on the way
from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.model_executor.graph_runner import (
GRAPH_CAPTURE_FAILED_MSG,
LogitsProcessorOutput,
get_batch_sizes_to_capture,
get_global_graph_memory_pool,
@@ -16,11 +23,6 @@ from sglang.srt.model_executor.cuda_graph_runner import (
set_global_graph_memory_pool,
set_torch_compile_config,
)
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
from sglang.srt.utils import (
require_attn_tp_gather,
@@ -149,7 +151,7 @@ class EAGLEDraftExtendCudaGraphRunner:
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
)
def can_run(self, forward_batch: ForwardBatch):