[feature] Ascend NPU graph support (#8027)
Co-authored-by: ronnie_zheng <zl19940307@163.com> Co-authored-by: yezhifeng (D) <y00897525@china.huawei.com> Co-authored-by: anon189Ty <Stari_Falcon@outlook.com> Co-authored-by: Maksim <makcum888e@mail.ru> Co-authored-by: ssshinigami <44640852+ssshinigami@users.noreply.github.com>
This commit is contained in:
@@ -6,20 +6,20 @@ from typing import TYPE_CHECKING, Callable
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
|
||||
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
||||
CudaGraphRunner,
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.model_executor.graph_runner import (
|
||||
GRAPH_CAPTURE_FAILED_MSG,
|
||||
get_batch_sizes_to_capture,
|
||||
get_global_graph_memory_pool,
|
||||
model_capture_mode,
|
||||
set_global_graph_memory_pool,
|
||||
set_torch_compile_config,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
||||
from sglang.srt.utils import (
|
||||
require_attn_tp_gather,
|
||||
@@ -121,7 +121,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.capture()
|
||||
except RuntimeError as e:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||
f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
|
||||
)
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
|
||||
@@ -6,9 +6,14 @@ from typing import TYPE_CHECKING, Callable
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
|
||||
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
||||
CudaGraphRunner,
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.model_executor.graph_runner import (
|
||||
GRAPH_CAPTURE_FAILED_MSG,
|
||||
LogitsProcessorOutput,
|
||||
get_batch_sizes_to_capture,
|
||||
get_global_graph_memory_pool,
|
||||
@@ -16,11 +21,6 @@ from sglang.srt.model_executor.cuda_graph_runner import (
|
||||
set_global_graph_memory_pool,
|
||||
set_torch_compile_config,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
|
||||
from sglang.srt.utils import (
|
||||
require_attn_tp_gather,
|
||||
@@ -149,7 +149,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.capture()
|
||||
except RuntimeError as e:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||
f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
|
||||
)
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
|
||||
Reference in New Issue
Block a user