From de2dd73831bdd22b5f2946909e6827ece4712ba8 Mon Sep 17 00:00:00 2001 From: Even Zhou Date: Wed, 20 Aug 2025 15:35:10 +0800 Subject: [PATCH] Revert "[feature] Rework Ascend NPU graph support" (#9385) --- .../benchmark_torch_compile_fused_moe.py | 2 +- .../sglang/srt/distributed/parallel_state.py | 14 +- .../srt/layers/attention/ascend_backend.py | 154 +++--------------- python/sglang/srt/mem_cache/memory_pool.py | 2 +- .../{graph_runner.py => cuda_graph_runner.py} | 51 +++--- .../model_executor/cuda_graph_runner_impl.py | 36 ---- .../sglang/srt/model_executor/model_runner.py | 30 ++-- .../srt/model_executor/npu_graph_runner.py | 94 ----------- python/sglang/srt/models/deepseek_v2.py | 2 +- python/sglang/srt/models/glm4_moe.py | 2 +- python/sglang/srt/models/mllama.py | 2 +- python/sglang/srt/models/qwen3.py | 2 +- python/sglang/srt/models/qwen3_moe.py | 2 +- .../eagle_draft_cuda_graph_runner.py | 20 +-- .../eagle_draft_extend_cuda_graph_runner.py | 20 +-- test/srt/ascend/test_ascend_graph_tp1_bf16.py | 95 ----------- test/srt/ascend/test_ascend_graph_tp2_bf16.py | 97 ----------- test/srt/run_suite.py | 2 - 18 files changed, 81 insertions(+), 546 deletions(-) rename python/sglang/srt/model_executor/{graph_runner.py => cuda_graph_runner.py} (96%) delete mode 100644 python/sglang/srt/model_executor/cuda_graph_runner_impl.py delete mode 100644 python/sglang/srt/model_executor/npu_graph_runner.py delete mode 100644 test/srt/ascend/test_ascend_graph_tp1_bf16.py delete mode 100644 test/srt/ascend/test_ascend_graph_tp2_bf16.py diff --git a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py index 1fcea7cd4..2b4faa24b 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py @@ -9,7 +9,7 @@ from transformers import AutoConfig from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( fused_moe as fused_moe_triton, ) -from sglang.srt.model_executor.graph_runner import set_torch_compile_config +from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config def get_model_config(model_name: str, tp_size: int): diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index a8a8d20f6..286618d6b 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -55,7 +55,7 @@ _is_npu = is_npu() @dataclass class GraphCaptureContext: - stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream + stream: torch.cuda.Stream TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) @@ -252,13 +252,9 @@ class GroupCoordinator: if is_cuda_alike(): self.device = torch.device(f"cuda:{local_rank}") - elif _is_npu: - self.device = torch.device(f"npu:{local_rank}") else: self.device = torch.device("cpu") - self.device_module = torch.get_device_module(self.device) - self.use_pynccl = use_pynccl self.use_pymscclpp = use_pymscclpp self.use_custom_allreduce = use_custom_allreduce @@ -406,7 +402,7 @@ class GroupCoordinator: self, graph_capture_context: Optional[GraphCaptureContext] = None ): if graph_capture_context is None: - stream = self.device_module.Stream() + stream = torch.cuda.Stream() graph_capture_context = GraphCaptureContext(stream) else: stream = graph_capture_context.stream @@ -417,11 +413,11 @@ class GroupCoordinator: # ensure all initialization operations complete 
before attempting to # capture the graph on another stream - curr_stream = self.device_module.current_stream() + curr_stream = torch.cuda.current_stream() if curr_stream != stream: stream.wait_stream(curr_stream) - with self.device_module.stream(stream), maybe_ca_context: + with torch.cuda.stream(stream), maybe_ca_context: # In graph mode, we have to be very careful about the collective # operations. The current status is: # allreduce \ Mode | Eager | Graph | @@ -1645,8 +1641,6 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): ) elif hasattr(torch, "xpu") and torch.xpu.is_available(): torch.xpu.empty_cache() - elif hasattr(torch, "npu") and torch.npu.is_available(): - torch.npu.empty_cache() def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py index c1f4c2785..020f04dcd 100644 --- a/python/sglang/srt/layers/attention/ascend_backend.py +++ b/python/sglang/srt/layers/attention/ascend_backend.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional import torch import torch_npu @@ -27,7 +27,6 @@ class ForwardMetadata: # seq len inputs extend_seq_lens_cpu_int: Optional[torch.Tensor] = None seq_lens_cpu_int: Optional[torch.Tensor] = None - seq_lens_cpu_list: Optional[List[int]] = None class AscendAttnBackend(AttentionBackend): @@ -52,7 +51,7 @@ class AscendAttnBackend(AttentionBackend): def __init__(self, model_runner: ModelRunner): super().__init__() - self.forward_metadata = None + self.forward_metadata = ForwardMetadata() self.device = model_runner.device self.gen_attention_mask(128, model_runner.dtype) self.page_size = model_runner.page_size @@ -61,15 +60,9 @@ class AscendAttnBackend(AttentionBackend): self.kv_lora_rank = model_runner.model_config.kv_lora_rank self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim self.native_attn = TorchNativeAttnBackend(model_runner) - self.graph_metadata = {} - self.max_context_len = model_runner.model_config.context_len - self.req_to_token = model_runner.req_to_token_pool.req_to_token - self.graph_mode = False def init_forward_metadata(self, forward_batch: ForwardBatch): """Init the metadata for a forward pass.""" - self.forward_metadata = ForwardMetadata() - self.forward_metadata.block_tables = ( forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : forward_batch.seq_lens.max() @@ -82,63 +75,6 @@ class AscendAttnBackend(AttentionBackend): ) self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int() - self.graph_mode = False - - def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): - self.graph_metadata = { - "block_tables": torch.empty( - (max_bs, self.max_context_len // self.page_size), - dtype=torch.int32, - device=self.device, - ), - } - - def init_forward_metadata_capture_cuda_graph( - self, - bs: int, - num_tokens: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - encoder_lens: Optional[torch.Tensor], - forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], - ): - metadata = ForwardMetadata() - - metadata.block_tables = self.graph_metadata["block_tables"][:bs, :] - metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist() - - self.graph_metadata[bs] = metadata - self.forward_metadata = metadata - - self.graph_mode = True - - def 
init_forward_metadata_replay_cuda_graph( - self, - bs: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - seq_lens_sum: int, - encoder_lens: Optional[torch.Tensor], - forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], - seq_lens_cpu: Optional[torch.Tensor], - ): - metadata = self.graph_metadata[bs] - max_len = seq_lens_cpu[:bs].max().item() - max_seq_pages = (max_len + self.page_size - 1) // self.page_size - - metadata.block_tables[:bs, :max_seq_pages].copy_( - self.req_to_token[req_pool_indices[:bs], :max_len][:, :: self.page_size] - // self.page_size - ) - metadata.block_tables[:bs, max_seq_pages:].fill_(0) - metadata.block_tables[bs:, :].fill_(0) - - self.forward_metadata = metadata - - self.graph_mode = True - def get_cuda_graph_seq_len_fill_value(self): return 1 @@ -231,74 +167,28 @@ class AscendAttnBackend(AttentionBackend): layer, forward_batch.out_cache_loc, k, v ) if not self.use_mla: - if self.graph_mode: - k_cache = forward_batch.token_to_kv_pool.get_key_buffer( - layer.layer_id - ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim) - v_cache = forward_batch.token_to_kv_pool.get_value_buffer( - layer.layer_id - ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim) - query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim) - num_tokens = query.shape[0] - workspace = ( - torch_npu._npu_fused_infer_attention_score_get_max_workspace( - query, - k_cache, - v_cache, - block_table=self.forward_metadata.block_tables, - block_size=self.page_size, - num_heads=layer.tp_q_head_num, - num_key_value_heads=layer.tp_k_head_num, - input_layout="BSH", - scale=layer.scaling, - actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list, - ) - ) - output = torch.empty( - (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim), - dtype=q.dtype, - device=q.device, - ) - softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device) - torch_npu.npu_fused_infer_attention_score.out( - query, - k_cache, - v_cache, - block_table=self.forward_metadata.block_tables, - block_size=self.page_size, - num_heads=layer.tp_q_head_num, - num_key_value_heads=layer.tp_k_head_num, - input_layout="BSH", - scale=layer.scaling, - actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list, - workspace=workspace, - out=[output, softmax_lse], - ) - else: - k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - v_cache = forward_batch.token_to_kv_pool.get_value_buffer( - layer.layer_id - ) + k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) - query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) - num_tokens = query.shape[0] - output = torch.empty( - (num_tokens, layer.tp_q_head_num, layer.v_head_dim), - dtype=query.dtype, - device=query.device, - ) + query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) + num_tokens = query.shape[0] + output = torch.empty( + (num_tokens, layer.tp_q_head_num, layer.v_head_dim), + dtype=query.dtype, + device=query.device, + ) - torch_npu._npu_paged_attention( - query=query, - key_cache=k_cache, - value_cache=v_cache, - num_heads=layer.tp_q_head_num, - num_kv_heads=layer.tp_k_head_num, - scale_value=layer.scaling, - block_table=self.forward_metadata.block_tables, - context_lens=self.forward_metadata.seq_lens_cpu_int, - out=output, - ) + torch_npu._npu_paged_attention( + query=query, + key_cache=k_cache, + value_cache=v_cache, + num_heads=layer.tp_q_head_num, + 
num_kv_heads=layer.tp_k_head_num, + scale_value=layer.scaling, + block_table=self.forward_metadata.block_tables, + context_lens=self.forward_metadata.seq_lens_cpu_int, + out=output, + ) return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim) else: query = q.view(-1, layer.tp_q_head_num, layer.head_dim) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 07d7f5234..1653d4535 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -376,7 +376,7 @@ class MHATokenToKVPool(KVCache): v_scale: Optional[float] = None, layer_id_override: Optional[int] = None, ): - from sglang.srt.model_executor.graph_runner import get_is_capture_mode + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode if layer_id_override is not None: layer_id = layer_id_override diff --git a/python/sglang/srt/model_executor/graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py similarity index 96% rename from python/sglang/srt/model_executor/graph_runner.py rename to python/sglang/srt/model_executor/cuda_graph_runner.py index afcb00b4e..cc87910ac 100644 --- a/python/sglang/srt/model_executor/graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Run the model with device graph and torch.compile.""" +"""Run the model with cuda graph and torch.compile.""" from __future__ import annotations @@ -221,7 +221,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): return capture_bs, compile_bs -# Reuse this memory pool across all device graph runners. +# Reuse this memory pool across all cuda graph runners. 
global_graph_memory_pool = None @@ -234,14 +234,12 @@ def set_global_graph_memory_pool(val): global_graph_memory_pool = val -class GraphRunner: - """A GraphRunner is a base class to run the forward pass of a model with device graph and torch.compile.""" +class CudaGraphRunner: + """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" def __init__(self, model_runner: ModelRunner): # Parse args self.model_runner = model_runner - self.device = model_runner.device - self.device_module = torch.get_device_module(self.device) self.graphs = {} self.output_buffers = {} self.enable_torch_compile = model_runner.server_args.enable_torch_compile @@ -267,7 +265,7 @@ class GraphRunner: # Batch sizes to capture self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) - rank0_log(f"Capture graph bs {self.capture_bs}") + rank0_log(f"Capture cuda graph bs {self.capture_bs}") self.capture_forward_mode = ForwardMode.DECODE self.capture_hidden_mode = CaptureHiddenMode.NULL self.num_tokens_per_bs = 1 @@ -307,15 +305,13 @@ class GraphRunner: self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs) # Graph inputs - with torch.device(self.device): + with torch.device("cuda"): self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) self.seq_lens = torch.full( (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 ) - self.out_cache_loc = torch.zeros( - (self.max_num_token,), dtype=self._cache_loc_dtype() - ) + self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64) self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32) @@ -370,12 +366,12 @@ class GraphRunner: * self.num_tokens_per_bs ), dtype=torch.bool, - device=self.device, + device="cuda", ) self.next_token_logits_buffer = torch.zeros( (self.max_num_token, self.model_runner.model_config.vocab_size), dtype=torch.float, - device=self.device, + device="cuda", ) # Capture @@ -384,12 +380,9 @@ class GraphRunner: self.capture() except RuntimeError as e: raise Exception( - f"Capture device graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}" + f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" ) - def _cache_loc_dtype(self): - return torch.int64 - def can_run(self, forward_batch: ForwardBatch): if self.require_mlp_tp_gather: cuda_graph_bs = ( @@ -509,16 +502,8 @@ class GraphRunner: ) logger.info(log_message) - def _capture_graph(self, graph, pool, stream, run_once_fn): - with self.device_module.graph(graph, pool=pool, stream=stream): - out = run_once_fn() - return out - - def _create_device_graph(self): - pass - def capture_one_batch_size(self, bs: int, forward: Callable): - graph = self._create_device_graph() + graph = torch.cuda.CUDAGraph() stream = self.stream num_tokens = bs * self.num_tokens_per_bs @@ -658,17 +643,19 @@ class GraphRunner: return logits_output_or_pp_proxy_tensors for _ in range(2): - self.device_module.synchronize() + torch.cuda.synchronize() self.model_runner.tp_group.barrier() + run_once() if get_global_graph_memory_pool() is None: - set_global_graph_memory_pool(self.device_module.graph_pool_handle()) + set_global_graph_memory_pool(torch.cuda.graph_pool_handle()) # Set graph pool id globally to be able to use symmetric memory set_graph_pool_id(get_global_graph_memory_pool()) - out = self._capture_graph( - graph, 
get_global_graph_memory_pool(), stream, run_once - ) + with torch.cuda.graph( + graph, pool=get_global_graph_memory_pool(), stream=stream + ): + out = run_once() return graph, out @@ -850,7 +837,7 @@ class GraphRunner: return spec_info -GRAPH_CAPTURE_FAILED_MSG = ( +CUDA_GRAPH_CAPTURE_FAILED_MSG = ( "Possible solutions:\n" "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n" diff --git a/python/sglang/srt/model_executor/cuda_graph_runner_impl.py b/python/sglang/srt/model_executor/cuda_graph_runner_impl.py deleted file mode 100644 index aeca8dcb7..000000000 --- a/python/sglang/srt/model_executor/cuda_graph_runner_impl.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Run the model with cuda graph and torch.compile.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import torch - -from sglang.srt.model_executor.graph_runner import GraphRunner - -if TYPE_CHECKING: - from sglang.srt.model_executor.model_runner import ModelRunner - - -class CudaGraphRunner(GraphRunner): - """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" - - def __init__(self, model_runner: ModelRunner): - # Parse args - super().__init__(model_runner) - - def _create_device_graph(self): - return torch.cuda.CUDAGraph() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 751bb7ded..6665458b8 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -89,11 +89,8 @@ from sglang.srt.mem_cache.memory_pool import ( ReqToTokenPool, SWAKVPool, ) - -# TODO(iforgetmyname): Renaming on the way -from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner +from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors -from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner from sglang.srt.model_loader import get_model from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.utils import set_default_torch_dtype @@ -344,12 +341,9 @@ class ModelRunner: if self.device == "cuda": self.init_cublas() self.init_attention_backend() - self.init_device_graphs() - elif self.device == "npu": - self.init_attention_backend() - self.init_device_graphs() + self.init_cuda_graphs() else: - self.graph_runner = None + self.cuda_graph_runner = None self.cuda_graph_mem_usage = 0 self.init_attention_backend() @@ -923,8 +917,7 @@ class ModelRunner: ) # We need to get device after patch otherwise the device would be wrong - self.device_module = torch.get_device_module(self.device) - infered_device = self.device_module.current_device() + infered_device = 
torch.cuda.current_device() named_tensors = [ (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device)) @@ -1592,9 +1585,9 @@ class ModelRunner: .cuda() ) - def init_device_graphs(self): + def init_cuda_graphs(self): """Capture cuda graphs.""" - self.graph_runner = None + self.cuda_graph_runner = None self.cuda_graph_mem_usage = 0 if not self.is_generation: @@ -1609,9 +1602,8 @@ class ModelRunner: logger.info( f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB" ) - self.graph_runner = ( - CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self) - ) + self.cuda_graph_runner = CudaGraphRunner(self) + after_mem = get_available_gpu_memory(self.device, self.gpu_id) self.cuda_graph_mem_usage = before_mem - after_mem logger.info( @@ -1763,11 +1755,11 @@ class ModelRunner: ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: can_run_cuda_graph = bool( forward_batch.forward_mode.is_cuda_graph() - and self.graph_runner - and self.graph_runner.can_run(forward_batch) + and self.cuda_graph_runner + and self.cuda_graph_runner.can_run(forward_batch) ) if can_run_cuda_graph: - ret = self.graph_runner.replay( + ret = self.cuda_graph_runner.replay( forward_batch, skip_attn_backend_init=skip_attn_backend_init, pp_proxy_tensors=pp_proxy_tensors, diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py deleted file mode 100644 index 582b5b7c6..000000000 --- a/python/sglang/srt/model_executor/npu_graph_runner.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Run the model with npu graph and torch.compile.""" - -from __future__ import annotations - -import logging -import threading -from typing import TYPE_CHECKING - -import torch - -from sglang.srt.model_executor.graph_runner import GraphRunner - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from sglang.srt.model_executor.model_runner import ModelRunner - -from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors - - -class NPUGraphRunner(GraphRunner): - """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile.""" - - def __init__(self, model_runner: ModelRunner): - super().__init__(model_runner) - - def _create_device_graph(self): - return torch.npu.NPUGraph() - - def _capture_graph(self, graph, pool, stream, run_once_fn): - with torch.npu.graph( - graph, - pool=pool, - stream=stream, - auto_dispatch_capture=True, - ): - out = run_once_fn() - return out - - def _update_inputs(self, seq_lens): - self.graphs[self.bs].update( - cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}] - ) - - def _cache_loc_dtype(self): - return torch.int32 - - def replay( - self, - forward_batch: ForwardBatch, - skip_attn_backend_init: bool = False, - pp_proxy_tensors: Optional[PPProxyTensors] = None, - ) -> Union[LogitsProcessorOutput, PPProxyTensors]: - if not skip_attn_backend_init: - self.replay_prepare(forward_batch, pp_proxy_tensors) - else: - # In speculative decoding, these two fields are still needed. - self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) - self.positions[: self.raw_num_token].copy_(forward_batch.positions) - - # Replay - seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs) - thread = threading.Thread(target=self._update_inputs, args=(seq_lens,)) - thread.start() - self.graphs[self.bs].replay() - thread.join() - - output = self.output_buffers[self.bs] - if isinstance(output, LogitsProcessorOutput): - return LogitsProcessorOutput( - next_token_logits=output.next_token_logits[: self.raw_num_token], - hidden_states=( - output.hidden_states[: self.raw_num_token] - if output.hidden_states is not None - else None - ), - ) - else: - assert isinstance(output, PPProxyTensors) - return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()}) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 37274e45b..eeebe1863 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1200,7 +1200,7 @@ class DeepseekV2AttentionMLA(nn.Module): forward_batch: ForwardBatch, zero_allocator: BumpAllocator, ): - from sglang.srt.model_executor.graph_runner import get_is_capture_mode + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode if self.q_lora_rank is not None: if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm: diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index bf6ceaeb8..ab118ad9c 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -68,8 +68,8 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from 
sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_executor.graph_runner import get_is_capture_mode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.deepseek_v2 import ( DeepseekV2DecoderLayer, diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 3ba736c7a..fa294ddcd 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module): positions: torch.Tensor, forward_batch: ForwardBatch, ) -> Union[Tuple, CausalLMOutputWithPast]: - from sglang.srt.model_executor.graph_runner import get_is_capture_mode + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = ( self._batch_image_inputs(forward_batch) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index a73d8764a..042159a50 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -22,8 +22,8 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.utils import PPMissingLayer, get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors -from sglang.srt.model_executor.graph_runner import get_is_capture_mode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP from sglang.srt.models.qwen2 import Qwen2Model diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 26971c119..fcb45b947 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -52,8 +52,8 @@ from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.utils import get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors -from sglang.srt.model_executor.graph_runner import get_is_capture_mode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP from sglang.srt.models.qwen2_moe import Qwen2MoeModel diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 984008f48..e824fb1ae 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -6,22 +6,20 @@ from typing import TYPE_CHECKING, Callable import torch from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len - -# TODO(iforgetmyname): Renaming on the way -from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner -from sglang.srt.model_executor.forward_batch_info import ( - CaptureHiddenMode, - ForwardBatch, - ForwardMode, -) -from sglang.srt.model_executor.graph_runner import ( - GRAPH_CAPTURE_FAILED_MSG, +from sglang.srt.model_executor.cuda_graph_runner import ( + 
CUDA_GRAPH_CAPTURE_FAILED_MSG, + CudaGraphRunner, get_batch_sizes_to_capture, get_global_graph_memory_pool, model_capture_mode, set_global_graph_memory_pool, set_torch_compile_config, ) +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) from sglang.srt.speculative.eagle_utils import EagleDraftInput from sglang.srt.utils import ( require_attn_tp_gather, @@ -123,7 +121,7 @@ class EAGLEDraftCudaGraphRunner: self.capture() except RuntimeError as e: raise Exception( - f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}" + f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" ) def can_run(self, forward_batch: ForwardBatch): diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index a52aea78d..4f4403fee 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -6,16 +6,9 @@ from typing import TYPE_CHECKING, Callable import torch from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len - -# TODO(iforgetmyname): Renaming on the way -from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner -from sglang.srt.model_executor.forward_batch_info import ( - CaptureHiddenMode, - ForwardBatch, - ForwardMode, -) -from sglang.srt.model_executor.graph_runner import ( - GRAPH_CAPTURE_FAILED_MSG, +from sglang.srt.model_executor.cuda_graph_runner import ( + CUDA_GRAPH_CAPTURE_FAILED_MSG, + CudaGraphRunner, LogitsProcessorOutput, get_batch_sizes_to_capture, get_global_graph_memory_pool, @@ -23,6 +16,11 @@ from sglang.srt.model_executor.graph_runner import ( set_global_graph_memory_pool, set_torch_compile_config, ) +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk from sglang.srt.utils import ( require_attn_tp_gather, @@ -151,7 +149,7 @@ class EAGLEDraftExtendCudaGraphRunner: self.capture() except RuntimeError as e: raise Exception( - f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}" + f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" ) def can_run(self, forward_batch: ForwardBatch): diff --git a/test/srt/ascend/test_ascend_graph_tp1_bf16.py b/test/srt/ascend/test_ascend_graph_tp1_bf16.py deleted file mode 100644 index 95c6b7bcf..000000000 --- a/test/srt/ascend/test_ascend_graph_tp1_bf16.py +++ /dev/null @@ -1,95 +0,0 @@ -import unittest -from types import SimpleNamespace -from urllib.parse import urlparse - -from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - is_in_ci, - popen_launch_server, - run_bench_offline_throughput, -) - -TEST_MODEL_MATRIX = { - "Qwen/Qwen2.5-7B-Instruct": { - "accuracy": 0.85, - "latency": 150, - "output_throughput": 30, - }, -} - - -class TestAscendGraphTp1Bf16(CustomTestCase): - - @classmethod - def setUpClass(cls): - cls.models = TEST_MODEL_MATRIX.keys() - cls.base_url = DEFAULT_URL_FOR_TEST - cls.url = urlparse(DEFAULT_URL_FOR_TEST) - cls.common_args = [ - "--trust-remote-code", - "--mem-fraction-static", - 0.8, - "--attention-backend", - "ascend", - ] - - def test_a_gsm8k(self): - for model in self.models: - 
with self.subTest(model=model): - print(f"##=== Testing accuracy: {model} ===##") - - process = popen_launch_server( - model, - self.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - *self.common_args, - ], - ) - - try: - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=1319, - max_new_tokens=512, - parallel=128, - host=f"http://{self.url.hostname}", - port=int(self.url.port), - ) - - metrics = run_eval_few_shot_gsm8k(args) - self.assertGreaterEqual( - metrics["accuracy"], - TEST_MODEL_MATRIX[model]["accuracy"], - ) - finally: - kill_process_tree(process.pid) - - def test_b_throughput(self): - for model in self.models: - with self.subTest(model=model): - print(f"##=== Testing throughput: {model} ===##") - - output_throughput = run_bench_offline_throughput( - model, - [ - *self.common_args, - ], - ) - - print(f"##=== {model} throughput: {output_throughput} ===##") - - if is_in_ci(): - self.assertGreater( - output_throughput, - TEST_MODEL_MATRIX[model]["output_throughput"], - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/srt/ascend/test_ascend_graph_tp2_bf16.py b/test/srt/ascend/test_ascend_graph_tp2_bf16.py deleted file mode 100644 index f7c3c6537..000000000 --- a/test/srt/ascend/test_ascend_graph_tp2_bf16.py +++ /dev/null @@ -1,97 +0,0 @@ -import unittest -from types import SimpleNamespace -from urllib.parse import urlparse - -from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - is_in_ci, - popen_launch_server, - run_bench_offline_throughput, -) - -TEST_MODEL_MATRIX = { - "Qwen/Qwen2.5-7B-Instruct": { - "accuracy": 0.85, - "latency": 180, - "output_throughput": 20, - }, -} - - -class TestAscendGraphTp2Bf16(CustomTestCase): - - @classmethod - def setUpClass(cls): - cls.models = TEST_MODEL_MATRIX.keys() - cls.base_url = DEFAULT_URL_FOR_TEST - cls.url = urlparse(DEFAULT_URL_FOR_TEST) - cls.common_args = [ - "--trust-remote-code", - "--mem-fraction-static", - 0.8, - "--attention-backend", - "ascend", - "--tp-size", - 2, - ] - - def test_a_gsm8k(self): - for model in self.models: - with self.subTest(model=model): - print(f"##=== Testing accuracy: {model} ===##") - - process = popen_launch_server( - model, - self.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - *self.common_args, - ], - ) - - try: - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=1319, - max_new_tokens=512, - parallel=128, - host=f"http://{self.url.hostname}", - port=int(self.url.port), - ) - - metrics = run_eval_few_shot_gsm8k(args) - self.assertGreaterEqual( - metrics["accuracy"], - TEST_MODEL_MATRIX[model]["accuracy"], - ) - finally: - kill_process_tree(process.pid) - - def test_b_throughput(self): - for model in self.models: - with self.subTest(model=model): - print(f"##=== Testing throughput: {model} ===##") - - output_throughput = run_bench_offline_throughput( - model, - [ - *self.common_args, - ], - ) - - print(f"##=== {model} throughput: {output_throughput} ===##") - - if is_in_ci(): - self.assertGreater( - output_throughput, - TEST_MODEL_MATRIX[model]["output_throughput"], - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 4c98dc585..b948bc82e 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -269,11 +269,9 @@ suite_xeon = { 
suite_ascend = { "per-commit-1-ascend-npu": [ TestFile("ascend/test_ascend_tp1_bf16.py", 400), - TestFile("ascend/test_ascend_graph_tp1_bf16.py", 400), ], "per-commit-2-ascend-npu": [ TestFile("ascend/test_ascend_tp2_bf16.py", 400), - TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400), ], "per-commit-4-ascend-npu": [ TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),
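
For reference, below is a minimal sketch of the runner hierarchy this patch removes, reconstructed from the deleted graph_runner.py and npu_graph_runner.py shown above. The torch / torch_npu graph objects are replaced by a hypothetical DummyGraph so the sketch is self-contained and runnable anywhere; it is an illustration of the dispatch pattern, not the original implementation. After the revert, CudaGraphRunner is again a single class that calls torch.cuda.CUDAGraph() directly instead of going through a _create_device_graph() hook.

    # Illustrative sketch only. Pre-revert, the real classes lived in
    # sglang/srt/model_executor/graph_runner.py (base) and
    # npu_graph_runner.py (NPU subclass). DummyGraph stands in for
    # torch.cuda.CUDAGraph / torch.npu.NPUGraph so this runs without torch.


    class DummyGraph:
        """Stand-in for a captured device graph."""

        def __init__(self, kind: str):
            self.kind = kind

        def replay(self):
            print(f"replaying captured {self.kind} graph")


    class GraphRunner:
        """Base runner: device-agnostic capture flow (pre-revert design)."""

        def _create_device_graph(self):
            # Device-specific hook, overridden per backend.
            raise NotImplementedError

        def capture_one_batch_size(self, run_once):
            graph = self._create_device_graph()
            out = run_once()  # capture body; the real code wraps this
                              # in a device graph-capture context
            return graph, out


    class CudaGraphRunner(GraphRunner):
        def _create_device_graph(self):
            return DummyGraph("cuda")  # was: torch.cuda.CUDAGraph()


    class NPUGraphRunner(GraphRunner):
        def _create_device_graph(self):
            return DummyGraph("npu")  # was: torch.npu.NPUGraph()


    if __name__ == "__main__":
        for runner in (CudaGraphRunner(), NPUGraphRunner()):
            graph, _ = runner.capture_one_batch_size(lambda: "logits")
            graph.replay()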