Revert "[feature] Rework Ascend NPU graph support" (#9385)
This commit is contained in:
@@ -55,7 +55,7 @@ _is_npu = is_npu()
|
||||
|
||||
@dataclass
|
||||
class GraphCaptureContext:
|
||||
stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
|
||||
stream: torch.cuda.Stream
|
||||
|
||||
|
||||
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
|
||||
@@ -252,13 +252,9 @@ class GroupCoordinator:
|
||||
|
||||
if is_cuda_alike():
|
||||
self.device = torch.device(f"cuda:{local_rank}")
|
||||
elif _is_npu:
|
||||
self.device = torch.device(f"npu:{local_rank}")
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
self.device_module = torch.get_device_module(self.device)
|
||||
|
||||
self.use_pynccl = use_pynccl
|
||||
self.use_pymscclpp = use_pymscclpp
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
@@ -406,7 +402,7 @@ class GroupCoordinator:
|
||||
self, graph_capture_context: Optional[GraphCaptureContext] = None
|
||||
):
|
||||
if graph_capture_context is None:
|
||||
stream = self.device_module.Stream()
|
||||
stream = torch.cuda.Stream()
|
||||
graph_capture_context = GraphCaptureContext(stream)
|
||||
else:
|
||||
stream = graph_capture_context.stream
|
||||
@@ -417,11 +413,11 @@ class GroupCoordinator:
|
||||
|
||||
# ensure all initialization operations complete before attempting to
|
||||
# capture the graph on another stream
|
||||
curr_stream = self.device_module.current_stream()
|
||||
curr_stream = torch.cuda.current_stream()
|
||||
if curr_stream != stream:
|
||||
stream.wait_stream(curr_stream)
|
||||
|
||||
with self.device_module.stream(stream), maybe_ca_context:
|
||||
with torch.cuda.stream(stream), maybe_ca_context:
|
||||
# In graph mode, we have to be very careful about the collective
|
||||
# operations. The current status is:
|
||||
# allreduce \ Mode | Eager | Graph |
|
||||
@@ -1645,8 +1641,6 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
|
||||
)
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif hasattr(torch, "npu") and torch.npu.is_available():
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
@@ -27,7 +27,6 @@ class ForwardMetadata:
|
||||
# seq len inputs
|
||||
extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
|
||||
seq_lens_cpu_int: Optional[torch.Tensor] = None
|
||||
seq_lens_cpu_list: Optional[List[int]] = None
|
||||
|
||||
|
||||
class AscendAttnBackend(AttentionBackend):
|
||||
@@ -52,7 +51,7 @@ class AscendAttnBackend(AttentionBackend):
|
||||
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
super().__init__()
|
||||
self.forward_metadata = None
|
||||
self.forward_metadata = ForwardMetadata()
|
||||
self.device = model_runner.device
|
||||
self.gen_attention_mask(128, model_runner.dtype)
|
||||
self.page_size = model_runner.page_size
|
||||
@@ -61,15 +60,9 @@ class AscendAttnBackend(AttentionBackend):
|
||||
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
|
||||
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
|
||||
self.native_attn = TorchNativeAttnBackend(model_runner)
|
||||
self.graph_metadata = {}
|
||||
self.max_context_len = model_runner.model_config.context_len
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.graph_mode = False
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Init the metadata for a forward pass."""
|
||||
self.forward_metadata = ForwardMetadata()
|
||||
|
||||
self.forward_metadata.block_tables = (
|
||||
forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
|
||||
@@ -82,63 +75,6 @@ class AscendAttnBackend(AttentionBackend):
|
||||
)
|
||||
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
||||
|
||||
self.graph_mode = False
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
self.graph_metadata = {
|
||||
"block_tables": torch.empty(
|
||||
(max_bs, self.max_context_len // self.page_size),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
num_tokens: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
metadata = ForwardMetadata()
|
||||
|
||||
metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
|
||||
metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
|
||||
|
||||
self.graph_metadata[bs] = metadata
|
||||
self.forward_metadata = metadata
|
||||
|
||||
self.graph_mode = True
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
metadata = self.graph_metadata[bs]
|
||||
max_len = seq_lens_cpu[:bs].max().item()
|
||||
max_seq_pages = (max_len + self.page_size - 1) // self.page_size
|
||||
|
||||
metadata.block_tables[:bs, :max_seq_pages].copy_(
|
||||
self.req_to_token[req_pool_indices[:bs], :max_len][:, :: self.page_size]
|
||||
// self.page_size
|
||||
)
|
||||
metadata.block_tables[:bs, max_seq_pages:].fill_(0)
|
||||
metadata.block_tables[bs:, :].fill_(0)
|
||||
|
||||
self.forward_metadata = metadata
|
||||
|
||||
self.graph_mode = True
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 1
|
||||
|
||||
@@ -231,74 +167,28 @@ class AscendAttnBackend(AttentionBackend):
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
if not self.use_mla:
|
||||
if self.graph_mode:
|
||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
||||
layer.layer_id
|
||||
).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
|
||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
|
||||
layer.layer_id
|
||||
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
|
||||
query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||
num_tokens = query.shape[0]
|
||||
workspace = (
|
||||
torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||
query,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
block_size=self.page_size,
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_key_value_heads=layer.tp_k_head_num,
|
||||
input_layout="BSH",
|
||||
scale=layer.scaling,
|
||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
|
||||
)
|
||||
)
|
||||
output = torch.empty(
|
||||
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
)
|
||||
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
query,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
block_size=self.page_size,
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_key_value_heads=layer.tp_k_head_num,
|
||||
input_layout="BSH",
|
||||
scale=layer.scaling,
|
||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
|
||||
workspace=workspace,
|
||||
out=[output, softmax_lse],
|
||||
)
|
||||
else:
|
||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
|
||||
layer.layer_id
|
||||
)
|
||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
||||
|
||||
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||
num_tokens = query.shape[0]
|
||||
output = torch.empty(
|
||||
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
)
|
||||
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||
num_tokens = query.shape[0]
|
||||
output = torch.empty(
|
||||
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
)
|
||||
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=k_cache,
|
||||
value_cache=v_cache,
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_kv_heads=layer.tp_k_head_num,
|
||||
scale_value=layer.scaling,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
||||
out=output,
|
||||
)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=k_cache,
|
||||
value_cache=v_cache,
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_kv_heads=layer.tp_k_head_num,
|
||||
scale_value=layer.scaling,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
||||
out=output,
|
||||
)
|
||||
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
|
||||
else:
|
||||
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
|
||||
@@ -376,7 +376,7 @@ class MHATokenToKVPool(KVCache):
|
||||
v_scale: Optional[float] = None,
|
||||
layer_id_override: Optional[int] = None,
|
||||
):
|
||||
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
|
||||
if layer_id_override is not None:
|
||||
layer_id = layer_id_override
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Run the model with device graph and torch.compile."""
|
||||
"""Run the model with cuda graph and torch.compile."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -221,7 +221,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||
return capture_bs, compile_bs
|
||||
|
||||
|
||||
# Reuse this memory pool across all device graph runners.
|
||||
# Reuse this memory pool across all cuda graph runners.
|
||||
global_graph_memory_pool = None
|
||||
|
||||
|
||||
@@ -234,14 +234,12 @@ def set_global_graph_memory_pool(val):
|
||||
global_graph_memory_pool = val
|
||||
|
||||
|
||||
class GraphRunner:
|
||||
"""A GraphRunner is a base class to run the forward pass of a model with device graph and torch.compile."""
|
||||
class CudaGraphRunner:
|
||||
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
|
||||
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
# Parse args
|
||||
self.model_runner = model_runner
|
||||
self.device = model_runner.device
|
||||
self.device_module = torch.get_device_module(self.device)
|
||||
self.graphs = {}
|
||||
self.output_buffers = {}
|
||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||
@@ -267,7 +265,7 @@ class GraphRunner:
|
||||
|
||||
# Batch sizes to capture
|
||||
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
||||
rank0_log(f"Capture graph bs {self.capture_bs}")
|
||||
rank0_log(f"Capture cuda graph bs {self.capture_bs}")
|
||||
self.capture_forward_mode = ForwardMode.DECODE
|
||||
self.capture_hidden_mode = CaptureHiddenMode.NULL
|
||||
self.num_tokens_per_bs = 1
|
||||
@@ -307,15 +305,13 @@ class GraphRunner:
|
||||
self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
|
||||
|
||||
# Graph inputs
|
||||
with torch.device(self.device):
|
||||
with torch.device("cuda"):
|
||||
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
|
||||
self.seq_lens = torch.full(
|
||||
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||
)
|
||||
self.out_cache_loc = torch.zeros(
|
||||
(self.max_num_token,), dtype=self._cache_loc_dtype()
|
||||
)
|
||||
self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
|
||||
self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
|
||||
@@ -370,12 +366,12 @@ class GraphRunner:
|
||||
* self.num_tokens_per_bs
|
||||
),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
device="cuda",
|
||||
)
|
||||
self.next_token_logits_buffer = torch.zeros(
|
||||
(self.max_num_token, self.model_runner.model_config.vocab_size),
|
||||
dtype=torch.float,
|
||||
device=self.device,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
# Capture
|
||||
@@ -384,12 +380,9 @@ class GraphRunner:
|
||||
self.capture()
|
||||
except RuntimeError as e:
|
||||
raise Exception(
|
||||
f"Capture device graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
|
||||
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||
)
|
||||
|
||||
def _cache_loc_dtype(self):
|
||||
return torch.int64
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.require_mlp_tp_gather:
|
||||
cuda_graph_bs = (
|
||||
@@ -509,16 +502,8 @@ class GraphRunner:
|
||||
)
|
||||
logger.info(log_message)
|
||||
|
||||
def _capture_graph(self, graph, pool, stream, run_once_fn):
|
||||
with self.device_module.graph(graph, pool=pool, stream=stream):
|
||||
out = run_once_fn()
|
||||
return out
|
||||
|
||||
def _create_device_graph(self):
|
||||
pass
|
||||
|
||||
def capture_one_batch_size(self, bs: int, forward: Callable):
|
||||
graph = self._create_device_graph()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
stream = self.stream
|
||||
num_tokens = bs * self.num_tokens_per_bs
|
||||
|
||||
@@ -658,17 +643,19 @@ class GraphRunner:
|
||||
return logits_output_or_pp_proxy_tensors
|
||||
|
||||
for _ in range(2):
|
||||
self.device_module.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
self.model_runner.tp_group.barrier()
|
||||
|
||||
run_once()
|
||||
|
||||
if get_global_graph_memory_pool() is None:
|
||||
set_global_graph_memory_pool(self.device_module.graph_pool_handle())
|
||||
set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
|
||||
# Set graph pool id globally to be able to use symmetric memory
|
||||
set_graph_pool_id(get_global_graph_memory_pool())
|
||||
out = self._capture_graph(
|
||||
graph, get_global_graph_memory_pool(), stream, run_once
|
||||
)
|
||||
with torch.cuda.graph(
|
||||
graph, pool=get_global_graph_memory_pool(), stream=stream
|
||||
):
|
||||
out = run_once()
|
||||
|
||||
return graph, out
|
||||
|
||||
@@ -850,7 +837,7 @@ class GraphRunner:
|
||||
return spec_info
|
||||
|
||||
|
||||
GRAPH_CAPTURE_FAILED_MSG = (
|
||||
CUDA_GRAPH_CAPTURE_FAILED_MSG = (
|
||||
"Possible solutions:\n"
|
||||
"1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
|
||||
"2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
|
||||
@@ -1,36 +0,0 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Run the model with cuda graph and torch.compile."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.model_executor.graph_runner import GraphRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
|
||||
class CudaGraphRunner(GraphRunner):
|
||||
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
|
||||
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
# Parse args
|
||||
super().__init__(model_runner)
|
||||
|
||||
def _create_device_graph(self):
|
||||
return torch.cuda.CUDAGraph()
|
||||
@@ -89,11 +89,8 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
ReqToTokenPool,
|
||||
SWAKVPool,
|
||||
)
|
||||
|
||||
# TODO(iforgetmyname): Renaming on the way
|
||||
from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
|
||||
from sglang.srt.model_loader import get_model
|
||||
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
|
||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||
@@ -344,12 +341,9 @@ class ModelRunner:
|
||||
if self.device == "cuda":
|
||||
self.init_cublas()
|
||||
self.init_attention_backend()
|
||||
self.init_device_graphs()
|
||||
elif self.device == "npu":
|
||||
self.init_attention_backend()
|
||||
self.init_device_graphs()
|
||||
self.init_cuda_graphs()
|
||||
else:
|
||||
self.graph_runner = None
|
||||
self.cuda_graph_runner = None
|
||||
self.cuda_graph_mem_usage = 0
|
||||
self.init_attention_backend()
|
||||
|
||||
@@ -923,8 +917,7 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
# We need to get device after patch otherwise the device would be wrong
|
||||
self.device_module = torch.get_device_module(self.device)
|
||||
infered_device = self.device_module.current_device()
|
||||
infered_device = torch.cuda.current_device()
|
||||
|
||||
named_tensors = [
|
||||
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
|
||||
@@ -1592,9 +1585,9 @@ class ModelRunner:
|
||||
.cuda()
|
||||
)
|
||||
|
||||
def init_device_graphs(self):
|
||||
def init_cuda_graphs(self):
|
||||
"""Capture cuda graphs."""
|
||||
self.graph_runner = None
|
||||
self.cuda_graph_runner = None
|
||||
self.cuda_graph_mem_usage = 0
|
||||
|
||||
if not self.is_generation:
|
||||
@@ -1609,9 +1602,8 @@ class ModelRunner:
|
||||
logger.info(
|
||||
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
|
||||
)
|
||||
self.graph_runner = (
|
||||
CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
|
||||
)
|
||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||
|
||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
self.cuda_graph_mem_usage = before_mem - after_mem
|
||||
logger.info(
|
||||
@@ -1763,11 +1755,11 @@ class ModelRunner:
|
||||
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
|
||||
can_run_cuda_graph = bool(
|
||||
forward_batch.forward_mode.is_cuda_graph()
|
||||
and self.graph_runner
|
||||
and self.graph_runner.can_run(forward_batch)
|
||||
and self.cuda_graph_runner
|
||||
and self.cuda_graph_runner.can_run(forward_batch)
|
||||
)
|
||||
if can_run_cuda_graph:
|
||||
ret = self.graph_runner.replay(
|
||||
ret = self.cuda_graph_runner.replay(
|
||||
forward_batch,
|
||||
skip_attn_backend_init=skip_attn_backend_init,
|
||||
pp_proxy_tensors=pp_proxy_tensors,
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Run the model with npu graph and torch.compile."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.model_executor.graph_runner import GraphRunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
|
||||
|
||||
class NPUGraphRunner(GraphRunner):
|
||||
"""A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile."""
|
||||
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
super().__init__(model_runner)
|
||||
|
||||
def _create_device_graph(self):
|
||||
return torch.npu.NPUGraph()
|
||||
|
||||
def _capture_graph(self, graph, pool, stream, run_once_fn):
|
||||
with torch.npu.graph(
|
||||
graph,
|
||||
pool=pool,
|
||||
stream=stream,
|
||||
auto_dispatch_capture=True,
|
||||
):
|
||||
out = run_once_fn()
|
||||
return out
|
||||
|
||||
def _update_inputs(self, seq_lens):
|
||||
self.graphs[self.bs].update(
|
||||
cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}]
|
||||
)
|
||||
|
||||
def _cache_loc_dtype(self):
|
||||
return torch.int32
|
||||
|
||||
def replay(
|
||||
self,
|
||||
forward_batch: ForwardBatch,
|
||||
skip_attn_backend_init: bool = False,
|
||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
|
||||
if not skip_attn_backend_init:
|
||||
self.replay_prepare(forward_batch, pp_proxy_tensors)
|
||||
else:
|
||||
# In speculative decoding, these two fields are still needed.
|
||||
self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
|
||||
self.positions[: self.raw_num_token].copy_(forward_batch.positions)
|
||||
|
||||
# Replay
|
||||
seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
|
||||
thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
|
||||
thread.start()
|
||||
self.graphs[self.bs].replay()
|
||||
thread.join()
|
||||
|
||||
output = self.output_buffers[self.bs]
|
||||
if isinstance(output, LogitsProcessorOutput):
|
||||
return LogitsProcessorOutput(
|
||||
next_token_logits=output.next_token_logits[: self.raw_num_token],
|
||||
hidden_states=(
|
||||
output.hidden_states[: self.raw_num_token]
|
||||
if output.hidden_states is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
else:
|
||||
assert isinstance(output, PPProxyTensors)
|
||||
return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
|
||||
@@ -1200,7 +1200,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
zero_allocator: BumpAllocator,
|
||||
):
|
||||
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
|
||||
|
||||
@@ -68,8 +68,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.deepseek_v2 import (
|
||||
DeepseekV2DecoderLayer,
|
||||
|
||||
@@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
|
||||
batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
|
||||
self._batch_image_inputs(forward_batch)
|
||||
|
||||
@@ -22,8 +22,8 @@ from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
|
||||
from sglang.srt.models.qwen2 import Qwen2Model
|
||||
|
||||
@@ -52,8 +52,8 @@ from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.utils import get_layer_id
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
|
||||
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
|
||||
|
||||
@@ -6,22 +6,20 @@ from typing import TYPE_CHECKING, Callable
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
|
||||
|
||||
# TODO(iforgetmyname): Renaming on the way
|
||||
from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.model_executor.graph_runner import (
|
||||
GRAPH_CAPTURE_FAILED_MSG,
|
||||
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
||||
CudaGraphRunner,
|
||||
get_batch_sizes_to_capture,
|
||||
get_global_graph_memory_pool,
|
||||
model_capture_mode,
|
||||
set_global_graph_memory_pool,
|
||||
set_torch_compile_config,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
||||
from sglang.srt.utils import (
|
||||
require_attn_tp_gather,
|
||||
@@ -123,7 +121,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.capture()
|
||||
except RuntimeError as e:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
|
||||
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||
)
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
|
||||
@@ -6,16 +6,9 @@ from typing import TYPE_CHECKING, Callable
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
|
||||
|
||||
# TODO(iforgetmyname): Renaming on the way
|
||||
from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.model_executor.graph_runner import (
|
||||
GRAPH_CAPTURE_FAILED_MSG,
|
||||
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
||||
CudaGraphRunner,
|
||||
LogitsProcessorOutput,
|
||||
get_batch_sizes_to_capture,
|
||||
get_global_graph_memory_pool,
|
||||
@@ -23,6 +16,11 @@ from sglang.srt.model_executor.graph_runner import (
|
||||
set_global_graph_memory_pool,
|
||||
set_torch_compile_config,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
|
||||
from sglang.srt.utils import (
|
||||
require_attn_tp_gather,
|
||||
@@ -151,7 +149,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.capture()
|
||||
except RuntimeError as e:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
|
||||
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||
)
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
|
||||
Reference in New Issue
Block a user