@@ -9,7 +9,7 @@ from transformers import AutoConfig
|
|||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||||
fused_moe as fused_moe_triton,
|
fused_moe as fused_moe_triton,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.graph_runner import set_torch_compile_config
|
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
|
||||||
|
|
||||||
|
|
||||||
def get_model_config(model_name: str, tp_size: int):
|
def get_model_config(model_name: str, tp_size: int):
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ _is_npu = is_npu()
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GraphCaptureContext:
|
class GraphCaptureContext:
|
||||||
stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
|
stream: torch.cuda.Stream
|
||||||
|
|
||||||
|
|
||||||
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
|
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
|
||||||
@@ -252,13 +252,9 @@ class GroupCoordinator:
|
|||||||
|
|
||||||
if is_cuda_alike():
|
if is_cuda_alike():
|
||||||
self.device = torch.device(f"cuda:{local_rank}")
|
self.device = torch.device(f"cuda:{local_rank}")
|
||||||
elif _is_npu:
|
|
||||||
self.device = torch.device(f"npu:{local_rank}")
|
|
||||||
else:
|
else:
|
||||||
self.device = torch.device("cpu")
|
self.device = torch.device("cpu")
|
||||||
|
|
||||||
self.device_module = torch.get_device_module(self.device)
|
|
||||||
|
|
||||||
self.use_pynccl = use_pynccl
|
self.use_pynccl = use_pynccl
|
||||||
self.use_pymscclpp = use_pymscclpp
|
self.use_pymscclpp = use_pymscclpp
|
||||||
self.use_custom_allreduce = use_custom_allreduce
|
self.use_custom_allreduce = use_custom_allreduce
|
||||||
@@ -406,7 +402,7 @@ class GroupCoordinator:
|
|||||||
self, graph_capture_context: Optional[GraphCaptureContext] = None
|
self, graph_capture_context: Optional[GraphCaptureContext] = None
|
||||||
):
|
):
|
||||||
if graph_capture_context is None:
|
if graph_capture_context is None:
|
||||||
stream = self.device_module.Stream()
|
stream = torch.cuda.Stream()
|
||||||
graph_capture_context = GraphCaptureContext(stream)
|
graph_capture_context = GraphCaptureContext(stream)
|
||||||
else:
|
else:
|
||||||
stream = graph_capture_context.stream
|
stream = graph_capture_context.stream
|
||||||
@@ -417,11 +413,11 @@ class GroupCoordinator:
|
|||||||
|
|
||||||
# ensure all initialization operations complete before attempting to
|
# ensure all initialization operations complete before attempting to
|
||||||
# capture the graph on another stream
|
# capture the graph on another stream
|
||||||
curr_stream = self.device_module.current_stream()
|
curr_stream = torch.cuda.current_stream()
|
||||||
if curr_stream != stream:
|
if curr_stream != stream:
|
||||||
stream.wait_stream(curr_stream)
|
stream.wait_stream(curr_stream)
|
||||||
|
|
||||||
with self.device_module.stream(stream), maybe_ca_context:
|
with torch.cuda.stream(stream), maybe_ca_context:
|
||||||
# In graph mode, we have to be very careful about the collective
|
# In graph mode, we have to be very careful about the collective
|
||||||
# operations. The current status is:
|
# operations. The current status is:
|
||||||
# allreduce \ Mode | Eager | Graph |
|
# allreduce \ Mode | Eager | Graph |
|
||||||
@@ -1645,8 +1641,6 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
|
|||||||
)
|
)
|
||||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
torch.xpu.empty_cache()
|
torch.xpu.empty_cache()
|
||||||
elif hasattr(torch, "npu") and torch.npu.is_available():
|
|
||||||
torch.npu.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
|
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
@@ -27,7 +27,6 @@ class ForwardMetadata:
|
|||||||
# seq len inputs
|
# seq len inputs
|
||||||
extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
|
extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
|
||||||
seq_lens_cpu_int: Optional[torch.Tensor] = None
|
seq_lens_cpu_int: Optional[torch.Tensor] = None
|
||||||
seq_lens_cpu_list: Optional[List[int]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class AscendAttnBackend(AttentionBackend):
|
class AscendAttnBackend(AttentionBackend):
|
||||||
@@ -52,7 +51,7 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
def __init__(self, model_runner: ModelRunner):
|
def __init__(self, model_runner: ModelRunner):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.forward_metadata = None
|
self.forward_metadata = ForwardMetadata()
|
||||||
self.device = model_runner.device
|
self.device = model_runner.device
|
||||||
self.gen_attention_mask(128, model_runner.dtype)
|
self.gen_attention_mask(128, model_runner.dtype)
|
||||||
self.page_size = model_runner.page_size
|
self.page_size = model_runner.page_size
|
||||||
@@ -61,15 +60,9 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
|
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
|
||||||
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
|
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
|
||||||
self.native_attn = TorchNativeAttnBackend(model_runner)
|
self.native_attn = TorchNativeAttnBackend(model_runner)
|
||||||
self.graph_metadata = {}
|
|
||||||
self.max_context_len = model_runner.model_config.context_len
|
|
||||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
|
||||||
self.graph_mode = False
|
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Init the metadata for a forward pass."""
|
"""Init the metadata for a forward pass."""
|
||||||
self.forward_metadata = ForwardMetadata()
|
|
||||||
|
|
||||||
self.forward_metadata.block_tables = (
|
self.forward_metadata.block_tables = (
|
||||||
forward_batch.req_to_token_pool.req_to_token[
|
forward_batch.req_to_token_pool.req_to_token[
|
||||||
forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
|
forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
|
||||||
@@ -82,63 +75,6 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
||||||
|
|
||||||
self.graph_mode = False
|
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
|
||||||
self.graph_metadata = {
|
|
||||||
"block_tables": torch.empty(
|
|
||||||
(max_bs, self.max_context_len // self.page_size),
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
|
||||||
self,
|
|
||||||
bs: int,
|
|
||||||
num_tokens: int,
|
|
||||||
req_pool_indices: torch.Tensor,
|
|
||||||
seq_lens: torch.Tensor,
|
|
||||||
encoder_lens: Optional[torch.Tensor],
|
|
||||||
forward_mode: ForwardMode,
|
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
|
||||||
):
|
|
||||||
metadata = ForwardMetadata()
|
|
||||||
|
|
||||||
metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
|
|
||||||
metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
|
|
||||||
|
|
||||||
self.graph_metadata[bs] = metadata
|
|
||||||
self.forward_metadata = metadata
|
|
||||||
|
|
||||||
self.graph_mode = True
|
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
|
||||||
self,
|
|
||||||
bs: int,
|
|
||||||
req_pool_indices: torch.Tensor,
|
|
||||||
seq_lens: torch.Tensor,
|
|
||||||
seq_lens_sum: int,
|
|
||||||
encoder_lens: Optional[torch.Tensor],
|
|
||||||
forward_mode: ForwardMode,
|
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
|
||||||
seq_lens_cpu: Optional[torch.Tensor],
|
|
||||||
):
|
|
||||||
metadata = self.graph_metadata[bs]
|
|
||||||
max_len = seq_lens_cpu[:bs].max().item()
|
|
||||||
max_seq_pages = (max_len + self.page_size - 1) // self.page_size
|
|
||||||
|
|
||||||
metadata.block_tables[:bs, :max_seq_pages].copy_(
|
|
||||||
self.req_to_token[req_pool_indices[:bs], :max_len][:, :: self.page_size]
|
|
||||||
// self.page_size
|
|
||||||
)
|
|
||||||
metadata.block_tables[:bs, max_seq_pages:].fill_(0)
|
|
||||||
metadata.block_tables[bs:, :].fill_(0)
|
|
||||||
|
|
||||||
self.forward_metadata = metadata
|
|
||||||
|
|
||||||
self.graph_mode = True
|
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
@@ -231,74 +167,28 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
layer, forward_batch.out_cache_loc, k, v
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
)
|
)
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
if self.graph_mode:
|
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
||||||
layer.layer_id
|
|
||||||
).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
|
|
||||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
|
|
||||||
layer.layer_id
|
|
||||||
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
|
|
||||||
query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
|
|
||||||
num_tokens = query.shape[0]
|
|
||||||
workspace = (
|
|
||||||
torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
|
||||||
query,
|
|
||||||
k_cache,
|
|
||||||
v_cache,
|
|
||||||
block_table=self.forward_metadata.block_tables,
|
|
||||||
block_size=self.page_size,
|
|
||||||
num_heads=layer.tp_q_head_num,
|
|
||||||
num_key_value_heads=layer.tp_k_head_num,
|
|
||||||
input_layout="BSH",
|
|
||||||
scale=layer.scaling,
|
|
||||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
output = torch.empty(
|
|
||||||
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
|
|
||||||
dtype=q.dtype,
|
|
||||||
device=q.device,
|
|
||||||
)
|
|
||||||
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
|
|
||||||
torch_npu.npu_fused_infer_attention_score.out(
|
|
||||||
query,
|
|
||||||
k_cache,
|
|
||||||
v_cache,
|
|
||||||
block_table=self.forward_metadata.block_tables,
|
|
||||||
block_size=self.page_size,
|
|
||||||
num_heads=layer.tp_q_head_num,
|
|
||||||
num_key_value_heads=layer.tp_k_head_num,
|
|
||||||
input_layout="BSH",
|
|
||||||
scale=layer.scaling,
|
|
||||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
|
|
||||||
workspace=workspace,
|
|
||||||
out=[output, softmax_lse],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
|
||||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
|
|
||||||
layer.layer_id
|
|
||||||
)
|
|
||||||
|
|
||||||
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||||
num_tokens = query.shape[0]
|
num_tokens = query.shape[0]
|
||||||
output = torch.empty(
|
output = torch.empty(
|
||||||
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
|
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
|
||||||
dtype=query.dtype,
|
dtype=query.dtype,
|
||||||
device=query.device,
|
device=query.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch_npu._npu_paged_attention(
|
torch_npu._npu_paged_attention(
|
||||||
query=query,
|
query=query,
|
||||||
key_cache=k_cache,
|
key_cache=k_cache,
|
||||||
value_cache=v_cache,
|
value_cache=v_cache,
|
||||||
num_heads=layer.tp_q_head_num,
|
num_heads=layer.tp_q_head_num,
|
||||||
num_kv_heads=layer.tp_k_head_num,
|
num_kv_heads=layer.tp_k_head_num,
|
||||||
scale_value=layer.scaling,
|
scale_value=layer.scaling,
|
||||||
block_table=self.forward_metadata.block_tables,
|
block_table=self.forward_metadata.block_tables,
|
||||||
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
||||||
out=output,
|
out=output,
|
||||||
)
|
)
|
||||||
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
|
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
else:
|
else:
|
||||||
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
@@ -330,6 +220,3 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
out=attn_output,
|
out=attn_output,
|
||||||
)
|
)
|
||||||
return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
|
return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
|
||||||
return 0
|
|
||||||
|
|||||||
@@ -376,7 +376,7 @@ class MHATokenToKVPool(KVCache):
|
|||||||
v_scale: Optional[float] = None,
|
v_scale: Optional[float] = None,
|
||||||
layer_id_override: Optional[int] = None,
|
layer_id_override: Optional[int] = None,
|
||||||
):
|
):
|
||||||
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
|
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||||
|
|
||||||
if layer_id_override is not None:
|
if layer_id_override is not None:
|
||||||
layer_id = layer_id_override
|
layer_id = layer_id_override
|
||||||
|
|||||||
@@ -15,22 +15,833 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
import bisect
|
||||||
|
import gc
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import tqdm
|
||||||
|
from torch.profiler import ProfilerActivity, profile
|
||||||
|
|
||||||
from sglang.srt.model_executor.graph_runner import GraphRunner
|
from sglang.srt.custom_op import CustomOp
|
||||||
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
|
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
||||||
|
set_graph_pool_id,
|
||||||
|
)
|
||||||
|
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
||||||
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
DpPaddingMode,
|
||||||
|
get_attention_tp_rank,
|
||||||
|
get_attention_tp_size,
|
||||||
|
set_dp_buffer_len,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
|
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
CaptureHiddenMode,
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
PPProxyTensors,
|
||||||
|
enable_num_token_non_padded,
|
||||||
|
)
|
||||||
|
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
||||||
|
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
|
||||||
|
from sglang.srt.utils import (
|
||||||
|
empty_context,
|
||||||
|
get_available_gpu_memory,
|
||||||
|
get_device_memory_capacity,
|
||||||
|
rank0_log,
|
||||||
|
require_attn_tp_gather,
|
||||||
|
require_gathered_buffer,
|
||||||
|
require_mlp_sync,
|
||||||
|
require_mlp_tp_gather,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
# Detect whether the current forward pass is in capture mode
|
||||||
|
is_capture_mode = False
|
||||||
|
|
||||||
class CudaGraphRunner(GraphRunner):
|
|
||||||
|
def get_is_capture_mode():
    """Return the current value of the module-level ``is_capture_mode`` flag."""
    capturing = is_capture_mode
    return capturing
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def model_capture_mode():
    """Flag the enclosed code as running in capture mode.

    Sets the module-level ``is_capture_mode`` flag to ``True`` for the
    duration of the ``with`` body and resets it to ``False`` on exit,
    whether the body finishes normally or raises.
    """
    global is_capture_mode
    is_capture_mode = True
    try:
        yield
    finally:
        # Reset on every exit path: without the finally, an exception raised
        # inside the with-body would leave the flag stuck at True.
        is_capture_mode = False
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def freeze_gc(enable_cudagraph_gc: bool):
    """
    Tune the garbage collector for the duration of a CUDA graph capture.

    A full collection always runs first. When ``enable_cudagraph_gc`` is
    false, the objects that survive it are frozen (excluded from future
    collections) while the body runs and unfrozen again on exit.
    """
    gc.collect()

    if enable_cudagraph_gc:
        # GC stays fully active during capture; nothing to undo afterwards.
        yield
        return

    gc.freeze()
    try:
        yield
    finally:
        gc.unfreeze()
|
||||||
|
|
||||||
|
|
||||||
|
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
||||||
|
for sub in model._modules.values():
|
||||||
|
if isinstance(sub, CustomOp):
|
||||||
|
if reverse:
|
||||||
|
sub.leave_torch_compile()
|
||||||
|
else:
|
||||||
|
sub.enter_torch_compile(num_tokens=num_tokens)
|
||||||
|
if isinstance(sub, torch.nn.Module):
|
||||||
|
_to_torch(sub, reverse, num_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def patch_model(
    model: torch.nn.Module,
    enable_compile: bool,
    num_tokens: int,
    tp_group: GroupCoordinator,
):
    """Patch the model to make it compatible with torch.compile.

    Yields the callable to use as the forward function: when
    ``enable_compile`` is true, ``model.forward`` wrapped in ``torch.no_grad()``
    and compiled with ``torch.compile``; otherwise ``model.forward`` unchanged.

    Args:
        model: The model to patch.
        enable_compile: Whether to switch the model into torch.compile mode.
        num_tokens: Passed to ``_to_torch`` when entering/leaving compile mode.
        tp_group: Group whose ``ca_comm`` attribute is saved on entry and
            written back on exit (only touched when ``enable_compile`` is true).
    """
    if not enable_compile:
        yield model.forward
        return

    # Take the backup before anything can fail: the restore in `finally` must
    # never write a placeholder None over the live communicator.
    backup_ca_comm = tp_group.ca_comm
    try:
        _to_torch(model, reverse=False, num_tokens=num_tokens)
        # Use custom-allreduce here.
        # We found the custom allreduce is much faster than the built-in allreduce in torch,
        # even with ENABLE_INTRA_NODE_COMM=1.
        # tp_group.ca_comm = None
        yield torch.compile(
            torch.no_grad()(model.forward),
            mode=os.environ.get(
                "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
            ),
            dynamic=False,
        )
    finally:
        _to_torch(model, reverse=True, num_tokens=num_tokens)
        tp_group.ca_comm = backup_ca_comm
|
||||||
|
|
||||||
|
|
||||||
|
def set_torch_compile_config():
    """Set the inductor/dynamo options used for torch.compile, then apply the torch.compile monkey patch."""
    import torch._dynamo.config as dynamo_cfg
    import torch._inductor.config as inductor_cfg

    inductor_cfg.coordinate_descent_tuning = True
    inductor_cfg.triton.unique_kernel_names = True
    # Experimental feature to reduce compilation times, will be on by default in future
    inductor_cfg.fx_graph_cache = True

    # FIXME: tmp workaround
    cache_limit = 1024
    dynamo_cfg.accumulated_cache_size_limit = cache_limit
    # Only set the per-function limit when the config exposes that option.
    if hasattr(dynamo_cfg, "cache_size_limit"):
        dynamo_cfg.cache_size_limit = cache_limit

    monkey_patch_torch_compile()
|
||||||
|
|
||||||
|
|
||||||
|
def get_batch_sizes_to_capture(model_runner: ModelRunner):
    """Choose the batch sizes to capture as graphs and the subset to torch.compile.

    Uses ``server_args.cuda_graph_bs`` when it is set; otherwise builds a
    default list (smaller when a speculative algorithm is configured) and
    extends it based on the reported device memory capacity. The list is then
    filtered by the required batch-size multiple, ``cuda_graph_max_bs`` and
    the request pool size.

    Args:
        model_runner: Provides ``server_args`` and ``req_to_token_pool.size``.

    Returns:
        ``(capture_bs, compile_bs)``: the sorted, de-duplicated batch sizes to
        capture, and those not larger than ``torch_compile_max_bs`` (empty
        unless ``enable_torch_compile`` is set).

    Raises:
        AssertionError: If no positive batch size survives the filtering.
    """
    server_args = model_runner.server_args
    pool_size = model_runner.req_to_token_pool.size

    if server_args.cuda_graph_bs is not None:
        # Work on a copy: the in-place `+=` extensions below would otherwise
        # mutate the list stored on server_args.
        capture_bs = list(server_args.cuda_graph_bs)
    else:
        if server_args.speculative_algorithm is None:
            if server_args.disable_cuda_graph_padding:
                capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
            else:
                capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
        else:
            # Since speculative decoding requires more cuda graph memory, we
            # capture less.
            capture_bs = (
                list(range(1, 9))
                + list(range(10, 33, 2))
                + list(range(40, 64, 8))
                + list(range(80, 161, 16))
            )

        gpu_mem = get_device_memory_capacity()
        if gpu_mem is not None:
            if gpu_mem > 90 * 1024:  # H200, H20
                capture_bs += list(range(160, 257, 8))
            if gpu_mem > 160 * 1000:  # B200, MI300
                capture_bs += list(range(256, 513, 16))

    # `capture_bs and ...` keeps an empty list away from max(); the assert at
    # the end then reports the problem with a readable message.
    if capture_bs and max(capture_bs) > pool_size:
        # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
        # is very small. We add more values here to make sure we capture the maximum bs.
        capture_bs += [pool_size]

    # Every captured batch size must be a multiple of mul_base.
    mul_base = 1

    if server_args.enable_two_batch_overlap:
        mul_base *= 2

    if require_gathered_buffer(server_args):
        mul_base *= get_attention_tp_size()

    capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]

    if server_args.cuda_graph_max_bs:
        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
        if capture_bs and max(capture_bs) < server_args.cuda_graph_max_bs:
            capture_bs += list(
                range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
            )
    capture_bs = [bs for bs in capture_bs if bs <= pool_size]
    capture_bs = sorted(set(capture_bs))
    assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
    compile_bs = (
        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
        if server_args.enable_torch_compile
        else []
    )
    return capture_bs, compile_bs
|
||||||
|
|
||||||
|
|
||||||
|
# Reuse this memory pool across all cuda graph runners.
|
||||||
|
global_graph_memory_pool = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_graph_memory_pool():
    """Return the module-level ``global_graph_memory_pool`` value."""
    pool = global_graph_memory_pool
    return pool
|
||||||
|
|
||||||
|
|
||||||
|
def set_global_graph_memory_pool(val):
    """Store ``val`` as the module-level ``global_graph_memory_pool``."""
    global global_graph_memory_pool
    global_graph_memory_pool = val
|
||||||
|
|
||||||
|
|
||||||
|
class CudaGraphRunner:
|
||||||
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
|
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
|
||||||
|
|
||||||
def __init__(self, model_runner: ModelRunner):
|
def __init__(self, model_runner: ModelRunner):
|
||||||
# Parse args
|
# Parse args
|
||||||
super().__init__(model_runner)
|
self.model_runner = model_runner
|
||||||
|
self.graphs = {}
|
||||||
|
self.output_buffers = {}
|
||||||
|
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||||
|
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||||
|
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||||
|
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
||||||
|
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
||||||
|
self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
|
||||||
|
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
||||||
|
self.enable_two_batch_overlap = (
|
||||||
|
model_runner.server_args.enable_two_batch_overlap
|
||||||
|
)
|
||||||
|
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
|
||||||
|
self.enable_profile_cuda_graph = (
|
||||||
|
model_runner.server_args.enable_profile_cuda_graph
|
||||||
|
)
|
||||||
|
self.tp_size = model_runner.server_args.tp_size
|
||||||
|
self.dp_size = model_runner.server_args.dp_size
|
||||||
|
self.pp_size = model_runner.server_args.pp_size
|
||||||
|
|
||||||
def _create_device_graph(self):
|
self.attn_tp_size = get_attention_tp_size()
|
||||||
return torch.cuda.CUDAGraph()
|
self.attn_tp_rank = get_attention_tp_rank()
|
||||||
|
|
||||||
|
# Batch sizes to capture
|
||||||
|
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
||||||
|
rank0_log(f"Capture cuda graph bs {self.capture_bs}")
|
||||||
|
self.capture_forward_mode = ForwardMode.DECODE
|
||||||
|
self.capture_hidden_mode = CaptureHiddenMode.NULL
|
||||||
|
self.num_tokens_per_bs = 1
|
||||||
|
if model_runner.spec_algorithm.is_eagle():
|
||||||
|
if self.model_runner.is_draft_worker:
|
||||||
|
raise RuntimeError("This should not happen")
|
||||||
|
else:
|
||||||
|
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
|
||||||
|
self.num_tokens_per_bs = (
|
||||||
|
self.model_runner.server_args.speculative_num_draft_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
|
||||||
|
if model_runner.server_args.enable_return_hidden_states:
|
||||||
|
self.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
|
|
||||||
|
# Attention backend
|
||||||
|
self.max_bs = max(self.capture_bs)
|
||||||
|
self.max_num_token = self.max_bs * self.num_tokens_per_bs
|
||||||
|
self.model_runner.attn_backend.init_cuda_graph_state(
|
||||||
|
self.max_bs, self.max_num_token
|
||||||
|
)
|
||||||
|
self.seq_len_fill_value = (
|
||||||
|
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||||
|
)
|
||||||
|
|
||||||
|
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
|
||||||
|
self.encoder_len_fill_value = 0
|
||||||
|
self.seq_lens_cpu = torch.full(
|
||||||
|
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.enable_torch_compile:
|
||||||
|
set_torch_compile_config()
|
||||||
|
|
||||||
|
if self.model_runner.server_args.enable_lora:
|
||||||
|
self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
|
||||||
|
|
||||||
|
# Graph inputs
|
||||||
|
with torch.device("cuda"):
|
||||||
|
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||||
|
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
|
||||||
|
self.seq_lens = torch.full(
|
||||||
|
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||||
|
)
|
||||||
|
self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||||
|
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||||
|
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
|
||||||
|
self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
|
||||||
|
self.tbo_plugin = TboCudaGraphRunnerPlugin()
|
||||||
|
|
||||||
|
# pipeline parallelism
|
||||||
|
if self.pp_size > 1:
|
||||||
|
self.pp_proxy_tensors = {
|
||||||
|
"hidden_states": torch.zeros(
|
||||||
|
(self.max_bs, self.model_runner.model_config.hidden_size),
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
),
|
||||||
|
"residual": torch.zeros(
|
||||||
|
(self.max_bs, self.model_runner.model_config.hidden_size),
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Speculative_inference
|
||||||
|
if model_runner.spec_algorithm.is_eagle3():
|
||||||
|
self.model_runner.model.set_eagle3_layers_to_capture()
|
||||||
|
|
||||||
|
if self.is_encoder_decoder:
|
||||||
|
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
|
||||||
|
self.encoder_lens = torch.full(
|
||||||
|
(self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.encoder_lens = None
|
||||||
|
|
||||||
|
if self.require_gathered_buffer:
|
||||||
|
if self.require_mlp_tp_gather:
|
||||||
|
self.global_num_tokens_gpu = torch.zeros(
|
||||||
|
(self.dp_size,), dtype=torch.int32
|
||||||
|
)
|
||||||
|
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||||
|
(self.dp_size,), dtype=torch.int32
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert self.require_attn_tp_gather
|
||||||
|
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
|
||||||
|
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||||
|
(1,), dtype=torch.int32
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.global_num_tokens_gpu = None
|
||||||
|
self.global_num_tokens_for_logprob_gpu = None
|
||||||
|
|
||||||
|
self.custom_mask = torch.ones(
|
||||||
|
(
|
||||||
|
(self.seq_lens.sum().item() + self.max_num_token)
|
||||||
|
* self.num_tokens_per_bs
|
||||||
|
),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
self.next_token_logits_buffer = torch.zeros(
|
||||||
|
(self.max_num_token, self.model_runner.model_config.vocab_size),
|
||||||
|
dtype=torch.float,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Capture
|
||||||
|
try:
|
||||||
|
with model_capture_mode():
|
||||||
|
self.capture()
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise Exception(
|
||||||
|
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def can_run(self, forward_batch: ForwardBatch):
    """Tell whether ``forward_batch`` can be executed with a captured graph.

    The batch qualifies when its batch size is supported, encoder lengths
    (for encoder-decoder models) are all positive, two-batch-overlap
    constraints hold, and the requested hidden-state capture mode is either
    NULL or equal to the mode the graphs were captured with.
    """
    # Batch size used for the graph lookup.
    if not self.require_mlp_tp_gather:
        cuda_graph_bs = forward_batch.batch_size
    else:
        cuda_graph_bs = max(forward_batch.global_num_tokens_cpu)
        if self.model_runner.spec_algorithm.is_eagle():
            cuda_graph_bs = cuda_graph_bs // self.num_tokens_per_bs

    # Without padding only exactly-captured sizes work; with padding any size
    # up to the largest captured one does.
    if self.disable_padding:
        is_bs_supported = cuda_graph_bs in self.graphs
    else:
        is_bs_supported = cuda_graph_bs <= self.max_bs

    if self.require_mlp_sync:
        is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph

    # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
    # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
    # because the full_text_row_masked_out_mask tensor will always be ones
    if self.is_encoder_decoder:
        is_encoder_lens_supported = torch.all(forward_batch.encoder_lens > 0)
    else:
        is_encoder_lens_supported = True

    if self.enable_two_batch_overlap:
        is_tbo_supported = forward_batch.can_run_tbo
    else:
        is_tbo_supported = True

    # The stricter of the batch-level and spec-info-level capture modes wins.
    spec_hidden_mode = getattr(forward_batch.spec_info, "capture_hidden_mode", None)
    if spec_hidden_mode is None:
        spec_hidden_mode = CaptureHiddenMode.NULL
    requested_capture_hidden_mode = max(
        forward_batch.capture_hidden_mode, spec_hidden_mode
    )
    capture_hidden_mode_matches = (
        requested_capture_hidden_mode == CaptureHiddenMode.NULL
        or requested_capture_hidden_mode == self.capture_hidden_mode
    )

    return (
        is_bs_supported
        and is_encoder_lens_supported
        and is_tbo_supported
        and capture_hidden_mode_matches
    )
|
||||||
|
|
||||||
|
def capture(self) -> None:
    """Capture a graph for every batch size in ``self.capture_bs``.

    Results are stored in ``self.graphs`` and ``self.output_buffers`` keyed
    by batch size. When ``self.enable_profile_cuda_graph`` is set, the whole
    capture runs under the torch profiler and a summary is logged afterwards.
    """
    if self.enable_profile_cuda_graph:
        profile_context = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
        )
    else:
        profile_context = empty_context()

    # Trigger CUDA graph capture for specific shapes.
    # Capture the large shapes first so that the smaller shapes
    # can reuse the memory pool allocated for the large shapes.
    with freeze_gc(self.model_runner.server_args.enable_cudagraph_gc):
        with graph_capture() as graph_capture_context:
            with profile_context as prof:
                self.stream = graph_capture_context.stream
                avail_mem = get_available_gpu_memory(
                    self.model_runner.device,
                    self.model_runner.gpu_id,
                    empty_cache=False,
                )

                # Reverse the order to enable better memory sharing across cuda graphs.
                # Only TP rank 0 shows a progress bar.
                if get_tensor_model_parallel_rank() == 0:
                    capture_range = tqdm.tqdm(list(reversed(self.capture_bs)))
                else:
                    capture_range = reversed(self.capture_bs)

                for bs in capture_range:
                    if get_tensor_model_parallel_rank() == 0:
                        avail_mem = get_available_gpu_memory(
                            self.model_runner.device,
                            self.model_runner.gpu_id,
                            empty_cache=False,
                        )
                        capture_range.set_description(
                            f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
                        )

                    with patch_model(
                        self.model_runner.model,
                        bs in self.compile_bs,
                        num_tokens=bs * self.num_tokens_per_bs,
                        tp_group=self.model_runner.tp_group,
                    ) as forward:
                        captured = self.capture_one_batch_size(bs, forward)
                        self.graphs[bs], self.output_buffers[bs] = captured

                    # Save gemlite cache after each capture
                    save_gemlite_cache()

    if self.enable_profile_cuda_graph:
        by_cuda_time = prof.key_averages(group_by_input_shape=True).table(
            sort_by="cuda_time_total", row_limit=10
        )
        by_cpu_time = prof.key_averages(group_by_input_shape=True).table(
            sort_by="cpu_time_total", row_limit=10
        )
        log_message = (
            "Sorted by CUDA Time:\n"
            + by_cuda_time
            + "\n\nSorted by CPU Time:\n"
            + by_cpu_time
        )
        logger.info(log_message)
|
||||||
|
|
||||||
|
def capture_one_batch_size(self, bs: int, forward: Callable):
    """Warm up and capture a CUDA graph for a single batch size.

    Args:
        bs: Batch size to capture.
        forward: Callable invoked as
            ``forward(input_ids, positions, forward_batch, **kwargs)``.

    Returns:
        ``(graph, out)``: the captured ``torch.cuda.CUDAGraph`` and the value
        ``forward`` returned while the graph was being recorded.
    """
    cuda_graph = torch.cuda.CUDAGraph()
    capture_stream = self.stream
    num_tokens = bs * self.num_tokens_per_bs

    # Graph inputs: views over the preallocated buffers.
    input_ids = self.input_ids[:num_tokens]
    req_pool_indices = self.req_pool_indices[:bs]
    seq_lens = self.seq_lens[:bs]
    out_cache_loc = self.out_cache_loc[:num_tokens]
    positions = self.positions[:num_tokens]
    encoder_lens = self.encoder_lens[:bs] if self.is_encoder_decoder else None
    mrope_positions = self.mrope_positions[:, :bs]
    next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
    self.num_token_non_padded[...] = num_tokens

    # pipeline parallelism
    if self.pp_size > 1:
        pp_proxy_tensors = PPProxyTensors(
            {name: buf[:num_tokens] for name, buf in self.pp_proxy_tensors.items()}
        )

    # Number of entries in the global token-count buffers (None = not gathered).
    if self.require_mlp_tp_gather:
        gathered_ranks = self.dp_size
    elif self.require_attn_tp_gather:
        gathered_ranks = 1
    else:
        gathered_ranks = None

    if gathered_ranks is None:
        global_dp_buffer_len = None
    else:
        for counter in (
            self.global_num_tokens_gpu,
            self.global_num_tokens_for_logprob_gpu,
        ):
            counter.copy_(
                torch.tensor(
                    [num_tokens] * gathered_ranks,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
        global_dp_buffer_len = num_tokens * gathered_ranks

    spec_info = self.get_spec_info(num_tokens)
    if self.capture_hidden_mode != CaptureHiddenMode.FULL:
        if spec_info:
            self.capture_hidden_mode = spec_info.capture_hidden_mode
        else:
            self.capture_hidden_mode = CaptureHiddenMode.NULL

    # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
    # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
    lora_ids = [None] * bs if self.model_runner.server_args.enable_lora else None

    forward_batch = ForwardBatch(
        forward_mode=self.capture_forward_mode,
        batch_size=bs,
        input_ids=input_ids,
        req_pool_indices=req_pool_indices,
        seq_lens=seq_lens,
        next_token_logits_buffer=next_token_logits_buffer,
        orig_seq_lens=seq_lens,
        req_to_token_pool=self.model_runner.req_to_token_pool,
        token_to_kv_pool=self.model_runner.token_to_kv_pool,
        attn_backend=self.model_runner.attn_backend,
        out_cache_loc=out_cache_loc,
        seq_lens_sum=seq_lens.sum().item(),
        encoder_lens=encoder_lens,
        return_logprob=False,
        positions=positions,
        global_num_tokens_gpu=self.global_num_tokens_gpu,
        global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
        dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
        global_dp_buffer_len=global_dp_buffer_len,
        mrope_positions=mrope_positions,
        spec_algorithm=self.model_runner.spec_algorithm,
        spec_info=spec_info,
        capture_hidden_mode=self.capture_hidden_mode,
        num_token_non_padded=self.num_token_non_padded,
        global_forward_mode=self.capture_forward_mode,
        lora_ids=lora_ids,
    )
    self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

    if lora_ids is not None:
        self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

    # Attention backend
    self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
        bs,
        num_tokens,
        req_pool_indices,
        seq_lens,
        encoder_lens,
        forward_batch.forward_mode,
        forward_batch.spec_info,
    )

    # Run and capture
    def run_once():
        # Clean intermediate result cache for DP attention
        forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
        set_dp_buffer_len(global_dp_buffer_len, num_tokens)

        extra_kwargs = {}
        # Only pass proxy tensors to forwards that declare the parameter.
        if (
            self.pp_size > 1
            and "pp_proxy_tensors" in inspect.signature(forward).parameters
        ):
            extra_kwargs["pp_proxy_tensors"] = PPProxyTensors(
                {name: t.clone() for name, t in pp_proxy_tensors.tensors.items()}
            )

        return forward(
            input_ids,
            forward_batch.positions,
            forward_batch,
            **extra_kwargs,
        )

    # Two eager warm-up runs before recording.
    for _ in range(2):
        torch.cuda.synchronize()
        self.model_runner.tp_group.barrier()

        run_once()

    if get_global_graph_memory_pool() is None:
        set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
    # Set graph pool id globally to be able to use symmetric memory
    set_graph_pool_id(get_global_graph_memory_pool())
    with torch.cuda.graph(
        cuda_graph, pool=get_global_graph_memory_pool(), stream=capture_stream
    ):
        out = run_once()

    return cuda_graph, out
|
||||||
|
|
||||||
|
def recapture_if_needed(self, forward_batch: ForwardBatch):
    """Re-capture all graphs when the required hidden-state mode changes.

    The required mode is the maximum of the modes requested by the forward
    batch, by its ``spec_info`` (if any), and by the
    ``enable_return_hidden_states`` server argument.
    (If we have FULL, we can emulate LAST or NULL; if we have LAST, we can
    emulate NULL.)

    Args:
        forward_batch: The batch about to be replayed.
    """
    # A missing spec_info, a missing attribute, or an attribute explicitly set
    # to None all mean "no requirement". This mirrors the handling in
    # ``can_run`` and keeps ``max`` from ever comparing against None.
    spec_mode = getattr(forward_batch.spec_info, "capture_hidden_mode", None)
    if spec_mode is None:
        spec_mode = CaptureHiddenMode.NULL

    if self.model_runner.server_args.enable_return_hidden_states:
        return_hidden_states_mode = CaptureHiddenMode.FULL
    else:
        return_hidden_states_mode = CaptureHiddenMode.NULL

    # Determine the highest capture_hidden_mode required.
    required_capture_hidden_mode = max(
        forward_batch.capture_hidden_mode,
        spec_mode,
        return_hidden_states_mode,
    )

    # If the current hidden mode is no longer aligned with the required hidden
    # mode, set it to what is required and re-capture.
    if self.capture_hidden_mode != required_capture_hidden_mode:
        self.capture_hidden_mode = required_capture_hidden_mode
        self.capture()
|
||||||
|
|
||||||
|
def replay_prepare(
    self,
    forward_batch: ForwardBatch,
    pp_proxy_tensors: Optional[PPProxyTensors] = None,
):
    """Copy a batch into the static graph buffers ahead of a replay.

    Picks the smallest captured batch size that covers the batch, resets the
    padded regions when padding is needed, copies inputs into the persistent
    buffers, and refreshes the attention backend's replay metadata. The
    chosen sizes are stored in ``self.raw_bs``, ``self.raw_num_token`` and
    ``self.bs``.
    """
    self.recapture_if_needed(forward_batch)

    raw_bs = forward_batch.batch_size
    raw_num_token = raw_bs * self.num_tokens_per_bs

    # Pad
    if not self.require_mlp_tp_gather:
        lookup_bs = raw_bs
    else:
        max_num_tokens = max(forward_batch.global_num_tokens_cpu)
        if self.model_runner.spec_algorithm.is_eagle():
            lookup_bs = max_num_tokens / self.num_tokens_per_bs
        else:
            lookup_bs = max_num_tokens
    bs = self.capture_bs[bisect.bisect_left(self.capture_bs, lookup_bs)]
    is_padded = bs != raw_bs
    if is_padded:
        self.seq_lens.fill_(self.seq_len_fill_value)
        self.out_cache_loc.zero_()

    # Common inputs
    self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
    self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
    self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
    self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
    self.positions[:raw_num_token].copy_(forward_batch.positions)

    if forward_batch.seq_lens_cpu is None:
        seq_lens_cpu = None
    else:
        if is_padded:
            self.seq_lens_cpu.fill_(self.seq_len_fill_value)
        self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
        seq_lens_cpu = self.seq_lens_cpu[:bs]

    if pp_proxy_tensors:
        for key in self.pp_proxy_tensors.keys():
            incoming = pp_proxy_tensors[key]
            self.pp_proxy_tensors[key][: incoming.shape[0]].copy_(incoming)

    if self.is_encoder_decoder:
        self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
    if forward_batch.mrope_positions is not None:
        self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
    if self.require_gathered_buffer:
        padded_num_token = bs * self.num_tokens_per_bs
        self.global_num_tokens_gpu.fill_(padded_num_token)
        self.global_num_tokens_for_logprob_gpu.fill_(padded_num_token)
    if enable_num_token_non_padded(self.model_runner.server_args):
        num_token_non_padded = forward_batch.num_token_non_padded
        if not self.require_gathered_buffer:
            self.num_token_non_padded.copy_(num_token_non_padded)
        else:
            # Clamp to this attention-TP rank's share of the tokens.
            tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
            rank_offset = tokens_per_rank * self.attn_tp_rank
            self.num_token_non_padded.copy_(
                torch.clamp(
                    num_token_non_padded - rank_offset,
                    min=0,
                    max=tokens_per_rank,
                )
            )
    if self.enable_two_batch_overlap:
        self.tbo_plugin.replay_prepare(
            forward_mode=self.capture_forward_mode,
            bs=bs,
            num_token_non_padded=len(forward_batch.input_ids),
            spec_info=forward_batch.spec_info,
        )
    if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
        forward_batch.spec_info.custom_mask = self.custom_mask

    # Attention backend
    padded_seq_lens_sum = (
        forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value
    )
    self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
        bs,
        self.req_pool_indices[:bs],
        self.seq_lens[:bs],
        padded_seq_lens_sum,
        self.encoder_lens[:bs] if self.is_encoder_decoder else None,
        self.capture_forward_mode,
        forward_batch.spec_info,
        seq_lens_cpu=seq_lens_cpu,
    )

    # Store fields
    self.raw_bs, self.raw_num_token, self.bs = raw_bs, raw_num_token, bs
|
||||||
|
|
||||||
|
def replay(
    self,
    forward_batch: ForwardBatch,
    skip_attn_backend_init: bool = False,
    pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
    """Replay the captured graph for ``forward_batch`` and return its output.

    When ``skip_attn_backend_init`` is true, ``replay_prepare`` is skipped
    and only ``input_ids`` and ``positions`` are refreshed, using the sizes
    stored by an earlier ``replay_prepare`` call.

    Returns:
        A ``LogitsProcessorOutput`` sliced to ``self.raw_num_token`` rows, or
        a ``PPProxyTensors`` whose tensors are sliced to ``self.bs`` rows.
    """
    if skip_attn_backend_init:
        # In speculative decoding, these two fields are still needed.
        self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
        self.positions[: self.raw_num_token].copy_(forward_batch.positions)
    else:
        self.replay_prepare(forward_batch, pp_proxy_tensors)

    # Replay
    self.graphs[self.bs].replay()

    output = self.output_buffers[self.bs]
    if not isinstance(output, LogitsProcessorOutput):
        assert isinstance(output, PPProxyTensors)
        return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})

    # Drop the padded rows before handing the result back.
    logits = output.next_token_logits[: self.raw_num_token]
    hidden = output.hidden_states
    if hidden is not None:
        hidden = hidden[: self.raw_num_token]
    return LogitsProcessorOutput(next_token_logits=logits, hidden_states=hidden)
|
||||||
|
|
||||||
|
def get_spec_info(self, num_tokens: int):
    """Build the placeholder speculative-decoding input used during capture.

    Returns ``None`` unless the speculative algorithm is EAGLE. For EAGLE an
    ``EagleVerifyInput`` is returned whose only populated tensor is
    ``custom_mask``. Raises ``RuntimeError`` on a draft worker.
    ``num_tokens`` is accepted but not used by this implementation.
    """
    if not self.model_runner.spec_algorithm.is_eagle():
        return None

    from sglang.srt.speculative.eagle_utils import EagleVerifyInput

    if self.model_runner.is_draft_worker:
        raise RuntimeError("This should not happen.")

    return EagleVerifyInput(
        draft_token=None,
        custom_mask=self.custom_mask,
        positions=None,
        retrive_index=None,
        retrive_next_token=None,
        retrive_next_sibling=None,
        retrive_cum_len=None,
        spec_steps=self.model_runner.server_args.speculative_num_steps,
        topk=self.model_runner.server_args.speculative_eagle_topk,
        draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
        capture_hidden_mode=CaptureHiddenMode.FULL,
        seq_lens_sum=None,
        seq_lens_cpu=None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Troubleshooting hints shown to the user when CUDA graph capture fails.
# Each entry is one output line; the trailing empty entry yields the final newline.
CUDA_GRAPH_CAPTURE_FAILED_MSG = "\n".join(
    [
        "Possible solutions:",
        "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)",
        "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)",
        "3. disable torch compile by not using --enable-torch-compile",
        "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)",
        "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose ",
        "",
    ]
)
|
||||||
|
|||||||
@@ -1,860 +0,0 @@
|
|||||||
# Copyright 2023-2024 SGLang Team
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Run the model with device graph and torch.compile."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import bisect
|
|
||||||
import gc
|
|
||||||
import inspect
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import tqdm
|
|
||||||
from torch.profiler import ProfilerActivity, profile
|
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
|
||||||
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
|
||||||
set_graph_pool_id,
|
|
||||||
)
|
|
||||||
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
|
||||||
from sglang.srt.layers.dp_attention import (
|
|
||||||
DpPaddingMode,
|
|
||||||
get_attention_tp_rank,
|
|
||||||
get_attention_tp_size,
|
|
||||||
set_dp_buffer_len,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
|
||||||
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
|
||||||
CaptureHiddenMode,
|
|
||||||
ForwardBatch,
|
|
||||||
ForwardMode,
|
|
||||||
PPProxyTensors,
|
|
||||||
enable_num_token_non_padded,
|
|
||||||
)
|
|
||||||
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
|
||||||
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
|
|
||||||
from sglang.srt.utils import (
|
|
||||||
empty_context,
|
|
||||||
get_available_gpu_memory,
|
|
||||||
get_device_memory_capacity,
|
|
||||||
rank0_log,
|
|
||||||
require_attn_tp_gather,
|
|
||||||
require_gathered_buffer,
|
|
||||||
require_mlp_sync,
|
|
||||||
require_mlp_tp_gather,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
||||||
|
|
||||||
# Detect whether the current forward pass is in capture mode
is_capture_mode = False


def get_is_capture_mode():
    """Return True while a ``model_capture_mode()`` block is active."""
    return is_capture_mode


@contextmanager
def model_capture_mode():
    """Mark the enclosed block as running in graph-capture mode.

    Sets the module-level ``is_capture_mode`` flag on entry and clears it on
    exit. The flag is cleared even when the body raises, so a failed capture
    cannot leave the process stuck reporting capture mode.
    """
    global is_capture_mode
    is_capture_mode = True
    try:
        yield
    finally:
        is_capture_mode = False
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
def freeze_gc(enable_cudagraph_gc: bool):
    """
    Reduce garbage-collector work while CUDA graphs are being captured.

    A full collection always runs first. Unless ``enable_cudagraph_gc`` is
    true, the surviving objects are then frozen for the duration of the block
    and unfrozen again on exit, including when the block raises.
    """
    gc.collect()
    if enable_cudagraph_gc:
        # GC stays active during capture: nothing to freeze or undo.
        yield
        return

    gc.freeze()
    try:
        yield
    finally:
        gc.unfreeze()
|
|
||||||
|
|
||||||
|
|
||||||
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
|
||||||
for sub in model._modules.values():
|
|
||||||
if isinstance(sub, CustomOp):
|
|
||||||
if reverse:
|
|
||||||
sub.leave_torch_compile()
|
|
||||||
else:
|
|
||||||
sub.enter_torch_compile(num_tokens=num_tokens)
|
|
||||||
if isinstance(sub, torch.nn.Module):
|
|
||||||
_to_torch(sub, reverse, num_tokens)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
def patch_model(
    model: torch.nn.Module,
    enable_compile: bool,
    num_tokens: int,
    tp_group: GroupCoordinator,
):
    """Patch the model to make it compatible with torch.compile.

    Yields the callable to use as the forward function: a ``torch.compile``d
    ``model.forward`` (wrapped in ``torch.no_grad()``) when ``enable_compile``
    is true, otherwise ``model.forward`` unchanged. On exit the model's custom
    ops leave torch.compile mode and ``tp_group.ca_comm`` is restored.

    Args:
        model: The model whose forward pass will be captured.
        enable_compile: Whether to compile the forward pass.
        num_tokens: Token count forwarded to ``_to_torch``.
        tp_group: Group whose ``ca_comm`` attribute is preserved across the block.
    """
    if not enable_compile:
        yield model.forward
        return

    # Snapshot before touching the model, so the restore below always writes
    # the real communicator and never a placeholder if `_to_torch` raises.
    backup_ca_comm = tp_group.ca_comm
    try:
        _to_torch(model, reverse=False, num_tokens=num_tokens)
        # Use custom-allreduce here.
        # We found the custom allreduce is much faster than the built-in allreduce in torch,
        # even with ENABLE_INTRA_NODE_COMM=1.
        # tp_group.ca_comm = None
        yield torch.compile(
            torch.no_grad()(model.forward),
            mode=os.environ.get(
                "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
            ),
            dynamic=False,
        )
    finally:
        _to_torch(model, reverse=True, num_tokens=num_tokens)
        tp_group.ca_comm = backup_ca_comm
|
|
||||||
|
|
||||||
|
|
||||||
def set_torch_compile_config():
    """Set the inductor and dynamo options used for torch.compile, then apply the project's compile patch."""
    import torch._dynamo.config
    import torch._inductor.config

    inductor_cfg = torch._inductor.config
    dynamo_cfg = torch._dynamo.config

    inductor_cfg.coordinate_descent_tuning = True
    inductor_cfg.triton.unique_kernel_names = True
    # Experimental feature to reduce compilation times, will be on by default in future
    inductor_cfg.fx_graph_cache = True

    # FIXME: tmp workaround
    dynamo_cfg.accumulated_cache_size_limit = 1024
    if hasattr(dynamo_cfg, "cache_size_limit"):
        dynamo_cfg.cache_size_limit = 1024

    monkey_patch_torch_compile()
|
|
||||||
|
|
||||||
|
|
||||||
def get_batch_sizes_to_capture(model_runner: ModelRunner):
    """Compute the batch sizes to capture and the subset to torch.compile.

    Starts from ``server_args.cuda_graph_bs`` when given, otherwise from a
    built-in schedule that depends on speculative decoding, padding and device
    memory. The list is then filtered by the required multiple (two-batch
    overlap and/or attention TP size), by ``cuda_graph_max_bs`` and by the
    request-pool size.

    Args:
        model_runner: Provides ``server_args`` and ``req_to_token_pool.size``.

    Returns:
        ``(capture_bs, compile_bs)``: a sorted, de-duplicated list of batch
        sizes to capture, and those no larger than ``torch_compile_max_bs``
        (empty when torch.compile is disabled).

    Raises:
        AssertionError: If no positive batch size survives the filters.
    """
    server_args = model_runner.server_args
    pool_size = model_runner.req_to_token_pool.size

    if server_args.cuda_graph_bs is not None:
        # Work on a copy: the in-place `+=` below must not modify the list
        # owned by server_args.
        capture_bs = list(server_args.cuda_graph_bs)
    else:
        if server_args.speculative_algorithm is None:
            if server_args.disable_cuda_graph_padding:
                capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
            else:
                capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
        else:
            # Since speculative decoding requires more cuda graph memory, we
            # capture less.
            capture_bs = (
                list(range(1, 9))
                + list(range(10, 33, 2))
                + list(range(40, 64, 8))
                + list(range(80, 161, 16))
            )

        gpu_mem = get_device_memory_capacity()
        if gpu_mem is not None:
            if gpu_mem > 90 * 1024:  # H200, H20
                capture_bs += list(range(160, 257, 8))
            if gpu_mem > 160 * 1000:  # B200, MI300
                capture_bs += list(range(256, 513, 16))

    if capture_bs and max(capture_bs) > pool_size:
        # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
        # is very small. We add more values here to make sure we capture the maximum bs.
        capture_bs += [pool_size]

    mul_base = 1

    if server_args.enable_two_batch_overlap:
        mul_base *= 2

    if require_gathered_buffer(server_args):
        mul_base *= get_attention_tp_size()

    capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]

    if server_args.cuda_graph_max_bs:
        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
        # Guard against an emptied list so the assertion below reports the
        # problem instead of `max()` raising a bare ValueError.
        if capture_bs and max(capture_bs) < server_args.cuda_graph_max_bs:
            capture_bs += list(
                range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
            )
    capture_bs = [bs for bs in capture_bs if bs <= pool_size]
    capture_bs = sorted(set(capture_bs))
    assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
    compile_bs = (
        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
        if server_args.enable_torch_compile
        else []
    )
    return capture_bs, compile_bs
|
|
||||||
|
|
||||||
|
|
||||||
# Reuse this memory pool across all device graph runners.
global_graph_memory_pool = None


def get_global_graph_memory_pool():
    """Return the shared graph memory pool handle, or ``None`` if none has been set."""
    return global_graph_memory_pool


def set_global_graph_memory_pool(val):
    """Store ``val`` as the shared graph memory pool handle."""
    global global_graph_memory_pool
    global_graph_memory_pool = val
|
|
||||||
|
|
||||||
|
|
||||||
class GraphRunner:
|
|
||||||
"""A GraphRunner is a base class to run the forward pass of a model with device graph and torch.compile."""
|
|
||||||
|
|
||||||
def __init__(self, model_runner: ModelRunner):
    """Allocate the static graph input buffers and capture all graphs.

    Args:
        model_runner: Supplies the model, server arguments, attention
            backend and device used for capture.

    Raises:
        RuntimeError: If constructed on a draft worker with an EAGLE
            speculative algorithm.
        Exception: If graph capture fails with a ``RuntimeError``; the
            original error is chained as the cause.
    """
    # Parse args
    self.model_runner = model_runner
    self.device = model_runner.device
    self.device_module = torch.get_device_module(self.device)
    self.graphs = {}
    self.output_buffers = {}
    self.enable_torch_compile = model_runner.server_args.enable_torch_compile
    self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
    self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
    self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
    self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
    self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
    self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
    self.enable_two_batch_overlap = (
        model_runner.server_args.enable_two_batch_overlap
    )
    self.speculative_algorithm = model_runner.server_args.speculative_algorithm
    self.enable_profile_cuda_graph = (
        model_runner.server_args.enable_profile_cuda_graph
    )
    self.tp_size = model_runner.server_args.tp_size
    self.dp_size = model_runner.server_args.dp_size
    self.pp_size = model_runner.server_args.pp_size

    self.attn_tp_size = get_attention_tp_size()
    self.attn_tp_rank = get_attention_tp_rank()

    # Batch sizes to capture
    self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
    rank0_log(f"Capture graph bs {self.capture_bs}")
    self.capture_forward_mode = ForwardMode.DECODE
    self.capture_hidden_mode = CaptureHiddenMode.NULL
    self.num_tokens_per_bs = 1
    if model_runner.spec_algorithm.is_eagle():
        if self.model_runner.is_draft_worker:
            raise RuntimeError("This should not happen")
        else:
            self.capture_forward_mode = ForwardMode.TARGET_VERIFY
            self.num_tokens_per_bs = (
                self.model_runner.server_args.speculative_num_draft_tokens
            )

    # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
    if model_runner.server_args.enable_return_hidden_states:
        self.capture_hidden_mode = CaptureHiddenMode.FULL

    # Attention backend
    self.max_bs = max(self.capture_bs)
    self.max_num_token = self.max_bs * self.num_tokens_per_bs
    self.model_runner.attn_backend.init_cuda_graph_state(
        self.max_bs, self.max_num_token
    )
    self.seq_len_fill_value = (
        self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
    )

    # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
    self.encoder_len_fill_value = 0
    # Created outside the `torch.device` block below, so it uses the default device.
    self.seq_lens_cpu = torch.full(
        (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
    )

    if self.enable_torch_compile:
        set_torch_compile_config()

    if self.model_runner.server_args.enable_lora:
        self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)

    # Graph inputs: allocated once at maximum size on `self.device`.
    with torch.device(self.device):
        self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
        self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
        self.seq_lens = torch.full(
            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
        )
        self.out_cache_loc = torch.zeros(
            (self.max_num_token,), dtype=self._cache_loc_dtype()
        )
        self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
        self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
        self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
        self.tbo_plugin = TboCudaGraphRunnerPlugin()

        # pipeline parallelism
        if self.pp_size > 1:
            self.pp_proxy_tensors = {
                "hidden_states": torch.zeros(
                    (self.max_bs, self.model_runner.model_config.hidden_size),
                    dtype=torch.bfloat16,
                ),
                "residual": torch.zeros(
                    (self.max_bs, self.model_runner.model_config.hidden_size),
                    dtype=torch.bfloat16,
                ),
            }

        # Speculative_inference
        if model_runner.spec_algorithm.is_eagle3():
            self.model_runner.model.set_eagle3_layers_to_capture()

        if self.is_encoder_decoder:
            # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
            self.encoder_lens = torch.full(
                (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
            )
        else:
            self.encoder_lens = None

        if self.require_gathered_buffer:
            if self.require_mlp_tp_gather:
                self.global_num_tokens_gpu = torch.zeros(
                    (self.dp_size,), dtype=torch.int32
                )
                self.global_num_tokens_for_logprob_gpu = torch.zeros(
                    (self.dp_size,), dtype=torch.int32
                )
            else:
                assert self.require_attn_tp_gather
                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                self.global_num_tokens_for_logprob_gpu = torch.zeros(
                    (1,), dtype=torch.int32
                )
        else:
            self.global_num_tokens_gpu = None
            self.global_num_tokens_for_logprob_gpu = None

        self.custom_mask = torch.ones(
            (
                (self.seq_lens.sum().item() + self.max_num_token)
                * self.num_tokens_per_bs
            ),
            dtype=torch.bool,
            device=self.device,
        )
        self.next_token_logits_buffer = torch.zeros(
            (self.max_num_token, self.model_runner.model_config.vocab_size),
            dtype=torch.float,
            device=self.device,
        )

    # Capture
    try:
        with model_capture_mode():
            self.capture()
    except RuntimeError as e:
        # Chain the original error so its traceback is reported as the cause.
        raise Exception(
            f"Capture device graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
        ) from e
|
|
||||||
|
|
||||||
def _cache_loc_dtype(self):
|
|
||||||
return torch.int64
|
|
||||||
|
|
||||||
def can_run(self, forward_batch: ForwardBatch):
|
|
||||||
if self.require_mlp_tp_gather:
|
|
||||||
cuda_graph_bs = (
|
|
||||||
max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
|
||||||
if self.model_runner.spec_algorithm.is_eagle()
|
|
||||||
else max(forward_batch.global_num_tokens_cpu)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cuda_graph_bs = forward_batch.batch_size
|
|
||||||
|
|
||||||
is_bs_supported = (
|
|
||||||
cuda_graph_bs in self.graphs
|
|
||||||
if self.disable_padding
|
|
||||||
else cuda_graph_bs <= self.max_bs
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.require_mlp_sync:
|
|
||||||
is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
|
|
||||||
|
|
||||||
# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
|
|
||||||
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
|
|
||||||
# because the full_text_row_masked_out_mask tensor will always be ones
|
|
||||||
is_encoder_lens_supported = (
|
|
||||||
torch.all(forward_batch.encoder_lens > 0)
|
|
||||||
if self.is_encoder_decoder
|
|
||||||
else True
|
|
||||||
)
|
|
||||||
|
|
||||||
requested_capture_hidden_mode = max(
|
|
||||||
forward_batch.capture_hidden_mode,
|
|
||||||
(
|
|
||||||
forward_batch.spec_info.capture_hidden_mode
|
|
||||||
if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
|
|
||||||
is not None
|
|
||||||
else CaptureHiddenMode.NULL
|
|
||||||
),
|
|
||||||
)
|
|
||||||
capture_hidden_mode_matches = (
|
|
||||||
requested_capture_hidden_mode == CaptureHiddenMode.NULL
|
|
||||||
or requested_capture_hidden_mode == self.capture_hidden_mode
|
|
||||||
)
|
|
||||||
is_tbo_supported = (
|
|
||||||
forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
|
||||||
is_bs_supported
|
|
||||||
and is_encoder_lens_supported
|
|
||||||
and is_tbo_supported
|
|
||||||
and capture_hidden_mode_matches
|
|
||||||
)
|
|
||||||
|
|
||||||
def capture(self) -> None:
|
|
||||||
profile_context = empty_context()
|
|
||||||
if self.enable_profile_cuda_graph:
|
|
||||||
profile_context = profile(
|
|
||||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
|
||||||
record_shapes=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Trigger CUDA graph capture for specific shapes.
|
|
||||||
# Capture the large shapes first so that the smaller shapes
|
|
||||||
# can reuse the memory pool allocated for the large shapes.
|
|
||||||
with freeze_gc(
|
|
||||||
self.model_runner.server_args.enable_cudagraph_gc
|
|
||||||
), graph_capture() as graph_capture_context:
|
|
||||||
with profile_context as prof:
|
|
||||||
self.stream = graph_capture_context.stream
|
|
||||||
avail_mem = get_available_gpu_memory(
|
|
||||||
self.model_runner.device,
|
|
||||||
self.model_runner.gpu_id,
|
|
||||||
empty_cache=False,
|
|
||||||
)
|
|
||||||
# Reverse the order to enable better memory sharing across cuda graphs.
|
|
||||||
capture_range = (
|
|
||||||
tqdm.tqdm(list(reversed(self.capture_bs)))
|
|
||||||
if get_tensor_model_parallel_rank() == 0
|
|
||||||
else reversed(self.capture_bs)
|
|
||||||
)
|
|
||||||
for i, bs in enumerate(capture_range):
|
|
||||||
if get_tensor_model_parallel_rank() == 0:
|
|
||||||
avail_mem = get_available_gpu_memory(
|
|
||||||
self.model_runner.device,
|
|
||||||
self.model_runner.gpu_id,
|
|
||||||
empty_cache=False,
|
|
||||||
)
|
|
||||||
capture_range.set_description(
|
|
||||||
f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch_model(
|
|
||||||
self.model_runner.model,
|
|
||||||
bs in self.compile_bs,
|
|
||||||
num_tokens=bs * self.num_tokens_per_bs,
|
|
||||||
tp_group=self.model_runner.tp_group,
|
|
||||||
) as forward:
|
|
||||||
(
|
|
||||||
graph,
|
|
||||||
output_buffers,
|
|
||||||
) = self.capture_one_batch_size(bs, forward)
|
|
||||||
self.graphs[bs] = graph
|
|
||||||
self.output_buffers[bs] = output_buffers
|
|
||||||
|
|
||||||
# Save gemlite cache after each capture
|
|
||||||
save_gemlite_cache()
|
|
||||||
|
|
||||||
if self.enable_profile_cuda_graph:
|
|
||||||
log_message = (
|
|
||||||
"Sorted by CUDA Time:\n"
|
|
||||||
+ prof.key_averages(group_by_input_shape=True).table(
|
|
||||||
sort_by="cuda_time_total", row_limit=10
|
|
||||||
)
|
|
||||||
+ "\n\nSorted by CPU Time:\n"
|
|
||||||
+ prof.key_averages(group_by_input_shape=True).table(
|
|
||||||
sort_by="cpu_time_total", row_limit=10
|
|
||||||
)
|
|
||||||
)
|
|
||||||
logger.info(log_message)
|
|
||||||
|
|
||||||
def _capture_graph(self, graph, pool, stream, run_once_fn):
|
|
||||||
with self.device_module.graph(graph, pool=pool, stream=stream):
|
|
||||||
out = run_once_fn()
|
|
||||||
return out
|
|
||||||
|
|
||||||
def _create_device_graph(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def capture_one_batch_size(self, bs: int, forward: Callable):
|
|
||||||
graph = self._create_device_graph()
|
|
||||||
stream = self.stream
|
|
||||||
num_tokens = bs * self.num_tokens_per_bs
|
|
||||||
|
|
||||||
# Graph inputs
|
|
||||||
input_ids = self.input_ids[:num_tokens]
|
|
||||||
req_pool_indices = self.req_pool_indices[:bs]
|
|
||||||
seq_lens = self.seq_lens[:bs]
|
|
||||||
out_cache_loc = self.out_cache_loc[:num_tokens]
|
|
||||||
positions = self.positions[:num_tokens]
|
|
||||||
if self.is_encoder_decoder:
|
|
||||||
encoder_lens = self.encoder_lens[:bs]
|
|
||||||
else:
|
|
||||||
encoder_lens = None
|
|
||||||
mrope_positions = self.mrope_positions[:, :bs]
|
|
||||||
next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
|
|
||||||
self.num_token_non_padded[...] = num_tokens
|
|
||||||
|
|
||||||
# pipeline parallelism
|
|
||||||
if self.pp_size > 1:
|
|
||||||
pp_proxy_tensors = PPProxyTensors(
|
|
||||||
{k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.require_mlp_tp_gather:
|
|
||||||
self.global_num_tokens_gpu.copy_(
|
|
||||||
torch.tensor(
|
|
||||||
[num_tokens] * self.dp_size,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=input_ids.device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
|
||||||
torch.tensor(
|
|
||||||
[num_tokens] * self.dp_size,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=input_ids.device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
global_dp_buffer_len = num_tokens * self.dp_size
|
|
||||||
elif self.require_attn_tp_gather:
|
|
||||||
self.global_num_tokens_gpu.copy_(
|
|
||||||
torch.tensor(
|
|
||||||
[num_tokens],
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=input_ids.device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
|
||||||
torch.tensor(
|
|
||||||
[num_tokens],
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=input_ids.device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
global_dp_buffer_len = num_tokens
|
|
||||||
else:
|
|
||||||
global_dp_buffer_len = None
|
|
||||||
|
|
||||||
spec_info = self.get_spec_info(num_tokens)
|
|
||||||
if self.capture_hidden_mode != CaptureHiddenMode.FULL:
|
|
||||||
self.capture_hidden_mode = (
|
|
||||||
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.model_runner.server_args.enable_lora:
|
|
||||||
# It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
|
|
||||||
# `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
|
|
||||||
lora_ids = [None] * bs
|
|
||||||
else:
|
|
||||||
lora_ids = None
|
|
||||||
|
|
||||||
forward_batch = ForwardBatch(
|
|
||||||
forward_mode=self.capture_forward_mode,
|
|
||||||
batch_size=bs,
|
|
||||||
input_ids=input_ids,
|
|
||||||
req_pool_indices=req_pool_indices,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
next_token_logits_buffer=next_token_logits_buffer,
|
|
||||||
orig_seq_lens=seq_lens,
|
|
||||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
|
||||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
|
||||||
attn_backend=self.model_runner.attn_backend,
|
|
||||||
out_cache_loc=out_cache_loc,
|
|
||||||
seq_lens_sum=seq_lens.sum().item(),
|
|
||||||
encoder_lens=encoder_lens,
|
|
||||||
return_logprob=False,
|
|
||||||
positions=positions,
|
|
||||||
global_num_tokens_gpu=self.global_num_tokens_gpu,
|
|
||||||
global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
|
|
||||||
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
|
|
||||||
global_dp_buffer_len=global_dp_buffer_len,
|
|
||||||
mrope_positions=mrope_positions,
|
|
||||||
spec_algorithm=self.model_runner.spec_algorithm,
|
|
||||||
spec_info=spec_info,
|
|
||||||
capture_hidden_mode=self.capture_hidden_mode,
|
|
||||||
num_token_non_padded=self.num_token_non_padded,
|
|
||||||
global_forward_mode=self.capture_forward_mode,
|
|
||||||
lora_ids=lora_ids,
|
|
||||||
)
|
|
||||||
self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
|
|
||||||
|
|
||||||
if lora_ids is not None:
|
|
||||||
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
|
|
||||||
|
|
||||||
# Attention backend
|
|
||||||
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
|
||||||
bs,
|
|
||||||
num_tokens,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
encoder_lens,
|
|
||||||
forward_batch.forward_mode,
|
|
||||||
forward_batch.spec_info,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run and capture
|
|
||||||
def run_once():
|
|
||||||
# Clean intermediate result cache for DP attention
|
|
||||||
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
|
||||||
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
|
|
||||||
|
|
||||||
kwargs = {}
|
|
||||||
if (
|
|
||||||
self.pp_size > 1
|
|
||||||
and "pp_proxy_tensors" in inspect.signature(forward).parameters
|
|
||||||
):
|
|
||||||
kwargs["pp_proxy_tensors"] = PPProxyTensors(
|
|
||||||
{k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
|
|
||||||
)
|
|
||||||
|
|
||||||
logits_output_or_pp_proxy_tensors = forward(
|
|
||||||
input_ids,
|
|
||||||
forward_batch.positions,
|
|
||||||
forward_batch,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
return logits_output_or_pp_proxy_tensors
|
|
||||||
|
|
||||||
for _ in range(2):
|
|
||||||
self.device_module.synchronize()
|
|
||||||
self.model_runner.tp_group.barrier()
|
|
||||||
run_once()
|
|
||||||
|
|
||||||
if get_global_graph_memory_pool() is None:
|
|
||||||
set_global_graph_memory_pool(self.device_module.graph_pool_handle())
|
|
||||||
# Set graph pool id globally to be able to use symmetric memory
|
|
||||||
set_graph_pool_id(get_global_graph_memory_pool())
|
|
||||||
out = self._capture_graph(
|
|
||||||
graph, get_global_graph_memory_pool(), stream, run_once
|
|
||||||
)
|
|
||||||
|
|
||||||
return graph, out
|
|
||||||
|
|
||||||
def recapture_if_needed(self, forward_batch: ForwardBatch):
|
|
||||||
|
|
||||||
# If the required capture_hidden_mode changes, we need to recapture the graph
|
|
||||||
|
|
||||||
# These are the different factors that can influence the capture_hidden_mode
|
|
||||||
capture_hidden_mode_required_by_forward_batch = (
|
|
||||||
forward_batch.capture_hidden_mode
|
|
||||||
)
|
|
||||||
capture_hidden_mode_required_by_spec_info = getattr(
|
|
||||||
forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
|
||||||
)
|
|
||||||
capture_hidden_mode_required_for_returning_hidden_states = (
|
|
||||||
CaptureHiddenMode.FULL
|
|
||||||
if self.model_runner.server_args.enable_return_hidden_states
|
|
||||||
else CaptureHiddenMode.NULL
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine the highest capture_hidden_mode required
|
|
||||||
# (If we have FULL, we can emulate LAST or NULL)
|
|
||||||
# (If we have LAST, we can emulate NULL)
|
|
||||||
required_capture_hidden_mode = max(
|
|
||||||
capture_hidden_mode_required_by_forward_batch,
|
|
||||||
capture_hidden_mode_required_by_spec_info,
|
|
||||||
capture_hidden_mode_required_for_returning_hidden_states,
|
|
||||||
)
|
|
||||||
|
|
||||||
# If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
|
|
||||||
if self.capture_hidden_mode != required_capture_hidden_mode:
|
|
||||||
self.capture_hidden_mode = required_capture_hidden_mode
|
|
||||||
self.capture()
|
|
||||||
|
|
||||||
def replay_prepare(
|
|
||||||
self,
|
|
||||||
forward_batch: ForwardBatch,
|
|
||||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
|
||||||
):
|
|
||||||
self.recapture_if_needed(forward_batch)
|
|
||||||
|
|
||||||
raw_bs = forward_batch.batch_size
|
|
||||||
raw_num_token = raw_bs * self.num_tokens_per_bs
|
|
||||||
|
|
||||||
# Pad
|
|
||||||
if self.require_mlp_tp_gather:
|
|
||||||
max_num_tokens = max(forward_batch.global_num_tokens_cpu)
|
|
||||||
max_batch_size = (
|
|
||||||
max_num_tokens / self.num_tokens_per_bs
|
|
||||||
if self.model_runner.spec_algorithm.is_eagle()
|
|
||||||
else max_num_tokens
|
|
||||||
)
|
|
||||||
index = bisect.bisect_left(self.capture_bs, max_batch_size)
|
|
||||||
else:
|
|
||||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
|
||||||
bs = self.capture_bs[index]
|
|
||||||
if bs != raw_bs:
|
|
||||||
self.seq_lens.fill_(self.seq_len_fill_value)
|
|
||||||
self.out_cache_loc.zero_()
|
|
||||||
|
|
||||||
# Common inputs
|
|
||||||
self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
|
|
||||||
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
|
||||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
|
||||||
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
|
||||||
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
|
||||||
|
|
||||||
seq_lens_cpu = None
|
|
||||||
if forward_batch.seq_lens_cpu is not None:
|
|
||||||
if bs != raw_bs:
|
|
||||||
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
|
||||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
|
||||||
seq_lens_cpu = self.seq_lens_cpu[:bs]
|
|
||||||
|
|
||||||
if pp_proxy_tensors:
|
|
||||||
for key in self.pp_proxy_tensors.keys():
|
|
||||||
dim = pp_proxy_tensors[key].shape[0]
|
|
||||||
self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key])
|
|
||||||
|
|
||||||
if self.is_encoder_decoder:
|
|
||||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
|
||||||
if forward_batch.mrope_positions is not None:
|
|
||||||
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
|
||||||
if self.require_gathered_buffer:
|
|
||||||
self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
|
|
||||||
self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
|
|
||||||
if enable_num_token_non_padded(self.model_runner.server_args):
|
|
||||||
num_token_non_padded = forward_batch.num_token_non_padded
|
|
||||||
if self.require_gathered_buffer:
|
|
||||||
tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
|
|
||||||
num_local_token_non_padded = torch.clamp(
|
|
||||||
num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
|
|
||||||
min=0,
|
|
||||||
max=tokens_per_rank,
|
|
||||||
)
|
|
||||||
self.num_token_non_padded.copy_(num_local_token_non_padded)
|
|
||||||
else:
|
|
||||||
self.num_token_non_padded.copy_(num_token_non_padded)
|
|
||||||
if self.enable_two_batch_overlap:
|
|
||||||
self.tbo_plugin.replay_prepare(
|
|
||||||
forward_mode=self.capture_forward_mode,
|
|
||||||
bs=bs,
|
|
||||||
num_token_non_padded=len(forward_batch.input_ids),
|
|
||||||
spec_info=forward_batch.spec_info,
|
|
||||||
)
|
|
||||||
if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
|
|
||||||
forward_batch.spec_info.custom_mask = self.custom_mask
|
|
||||||
# Attention backend
|
|
||||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
|
||||||
bs,
|
|
||||||
self.req_pool_indices[:bs],
|
|
||||||
self.seq_lens[:bs],
|
|
||||||
forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
|
|
||||||
self.encoder_lens[:bs] if self.is_encoder_decoder else None,
|
|
||||||
self.capture_forward_mode,
|
|
||||||
forward_batch.spec_info,
|
|
||||||
seq_lens_cpu=seq_lens_cpu,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store fields
|
|
||||||
self.raw_bs = raw_bs
|
|
||||||
self.raw_num_token = raw_num_token
|
|
||||||
self.bs = bs
|
|
||||||
|
|
||||||
def replay(
|
|
||||||
self,
|
|
||||||
forward_batch: ForwardBatch,
|
|
||||||
skip_attn_backend_init: bool = False,
|
|
||||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
|
||||||
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
|
|
||||||
if not skip_attn_backend_init:
|
|
||||||
self.replay_prepare(forward_batch, pp_proxy_tensors)
|
|
||||||
else:
|
|
||||||
# In speculative decoding, these two fields are still needed.
|
|
||||||
self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
|
|
||||||
self.positions[: self.raw_num_token].copy_(forward_batch.positions)
|
|
||||||
|
|
||||||
# Replay
|
|
||||||
self.graphs[self.bs].replay()
|
|
||||||
|
|
||||||
output = self.output_buffers[self.bs]
|
|
||||||
if isinstance(output, LogitsProcessorOutput):
|
|
||||||
return LogitsProcessorOutput(
|
|
||||||
next_token_logits=output.next_token_logits[: self.raw_num_token],
|
|
||||||
hidden_states=(
|
|
||||||
output.hidden_states[: self.raw_num_token]
|
|
||||||
if output.hidden_states is not None
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert isinstance(output, PPProxyTensors)
|
|
||||||
return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
|
|
||||||
|
|
||||||
def get_spec_info(self, num_tokens: int):
|
|
||||||
spec_info = None
|
|
||||||
if self.model_runner.spec_algorithm.is_eagle():
|
|
||||||
from sglang.srt.speculative.eagle_utils import EagleVerifyInput
|
|
||||||
|
|
||||||
if self.model_runner.is_draft_worker:
|
|
||||||
raise RuntimeError("This should not happen.")
|
|
||||||
else:
|
|
||||||
spec_info = EagleVerifyInput(
|
|
||||||
draft_token=None,
|
|
||||||
custom_mask=self.custom_mask,
|
|
||||||
positions=None,
|
|
||||||
retrive_index=None,
|
|
||||||
retrive_next_token=None,
|
|
||||||
retrive_next_sibling=None,
|
|
||||||
retrive_cum_len=None,
|
|
||||||
spec_steps=self.model_runner.server_args.speculative_num_steps,
|
|
||||||
topk=self.model_runner.server_args.speculative_eagle_topk,
|
|
||||||
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
|
|
||||||
capture_hidden_mode=CaptureHiddenMode.FULL,
|
|
||||||
seq_lens_sum=None,
|
|
||||||
seq_lens_cpu=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
return spec_info
|
|
||||||
|
|
||||||
|
|
||||||
GRAPH_CAPTURE_FAILED_MSG = (
|
|
||||||
"Possible solutions:\n"
|
|
||||||
"1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
|
|
||||||
"2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
|
|
||||||
"3. disable torch compile by not using --enable-torch-compile\n"
|
|
||||||
"4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
|
|
||||||
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
|
||||||
)
|
|
||||||
@@ -91,7 +91,6 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
|
|
||||||
from sglang.srt.model_loader import get_model
|
from sglang.srt.model_loader import get_model
|
||||||
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
|
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
|
||||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||||
@@ -342,12 +341,9 @@ class ModelRunner:
|
|||||||
if self.device == "cuda":
|
if self.device == "cuda":
|
||||||
self.init_cublas()
|
self.init_cublas()
|
||||||
self.init_attention_backend()
|
self.init_attention_backend()
|
||||||
self.init_device_graphs()
|
self.init_cuda_graphs()
|
||||||
elif self.device == "npu":
|
|
||||||
self.init_attention_backend()
|
|
||||||
self.init_device_graphs()
|
|
||||||
else:
|
else:
|
||||||
self.graph_runner = None
|
self.cuda_graph_runner = None
|
||||||
self.cuda_graph_mem_usage = 0
|
self.cuda_graph_mem_usage = 0
|
||||||
self.init_attention_backend()
|
self.init_attention_backend()
|
||||||
|
|
||||||
@@ -921,8 +917,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# We need to get device after patch otherwise the device would be wrong
|
# We need to get device after patch otherwise the device would be wrong
|
||||||
self.device_module = torch.get_device_module(self.device)
|
infered_device = torch.cuda.current_device()
|
||||||
infered_device = self.device_module.current_device()
|
|
||||||
|
|
||||||
named_tensors = [
|
named_tensors = [
|
||||||
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
|
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
|
||||||
@@ -1590,9 +1585,9 @@ class ModelRunner:
|
|||||||
.cuda()
|
.cuda()
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_device_graphs(self):
|
def init_cuda_graphs(self):
|
||||||
"""Capture cuda graphs."""
|
"""Capture cuda graphs."""
|
||||||
self.graph_runner = None
|
self.cuda_graph_runner = None
|
||||||
self.cuda_graph_mem_usage = 0
|
self.cuda_graph_mem_usage = 0
|
||||||
|
|
||||||
if not self.is_generation:
|
if not self.is_generation:
|
||||||
@@ -1607,9 +1602,8 @@ class ModelRunner:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
|
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
|
||||||
)
|
)
|
||||||
self.graph_runner = (
|
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||||
CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
|
|
||||||
)
|
|
||||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
self.cuda_graph_mem_usage = before_mem - after_mem
|
self.cuda_graph_mem_usage = before_mem - after_mem
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -1761,11 +1755,11 @@ class ModelRunner:
|
|||||||
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
|
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
|
||||||
can_run_cuda_graph = bool(
|
can_run_cuda_graph = bool(
|
||||||
forward_batch.forward_mode.is_cuda_graph()
|
forward_batch.forward_mode.is_cuda_graph()
|
||||||
and self.graph_runner
|
and self.cuda_graph_runner
|
||||||
and self.graph_runner.can_run(forward_batch)
|
and self.cuda_graph_runner.can_run(forward_batch)
|
||||||
)
|
)
|
||||||
if can_run_cuda_graph:
|
if can_run_cuda_graph:
|
||||||
ret = self.graph_runner.replay(
|
ret = self.cuda_graph_runner.replay(
|
||||||
forward_batch,
|
forward_batch,
|
||||||
skip_attn_backend_init=skip_attn_backend_init,
|
skip_attn_backend_init=skip_attn_backend_init,
|
||||||
pp_proxy_tensors=pp_proxy_tensors,
|
pp_proxy_tensors=pp_proxy_tensors,
|
||||||
|
|||||||
@@ -1,94 +0,0 @@
|
|||||||
# Copyright 2023-2024 SGLang Team
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Run the model with npu graph and torch.compile."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.model_executor.graph_runner import GraphRunner
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
|
||||||
|
|
||||||
|
|
||||||
class NPUGraphRunner(GraphRunner):
|
|
||||||
"""A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile."""
|
|
||||||
|
|
||||||
def __init__(self, model_runner: ModelRunner):
|
|
||||||
super().__init__(model_runner)
|
|
||||||
|
|
||||||
def _create_device_graph(self):
|
|
||||||
return torch.npu.NPUGraph()
|
|
||||||
|
|
||||||
def _capture_graph(self, graph, pool, stream, run_once_fn):
|
|
||||||
with torch.npu.graph(
|
|
||||||
graph,
|
|
||||||
pool=pool,
|
|
||||||
stream=stream,
|
|
||||||
auto_dispatch_capture=True,
|
|
||||||
):
|
|
||||||
out = run_once_fn()
|
|
||||||
return out
|
|
||||||
|
|
||||||
def _update_inputs(self, seq_lens):
|
|
||||||
self.graphs[self.bs].update(
|
|
||||||
cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _cache_loc_dtype(self):
|
|
||||||
return torch.int32
|
|
||||||
|
|
||||||
def replay(
|
|
||||||
self,
|
|
||||||
forward_batch: ForwardBatch,
|
|
||||||
skip_attn_backend_init: bool = False,
|
|
||||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
|
||||||
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
|
|
||||||
if not skip_attn_backend_init:
|
|
||||||
self.replay_prepare(forward_batch, pp_proxy_tensors)
|
|
||||||
else:
|
|
||||||
# In speculative decoding, these two fields are still needed.
|
|
||||||
self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
|
|
||||||
self.positions[: self.raw_num_token].copy_(forward_batch.positions)
|
|
||||||
|
|
||||||
# Replay
|
|
||||||
seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
|
|
||||||
thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
|
|
||||||
thread.start()
|
|
||||||
self.graphs[self.bs].replay()
|
|
||||||
thread.join()
|
|
||||||
|
|
||||||
output = self.output_buffers[self.bs]
|
|
||||||
if isinstance(output, LogitsProcessorOutput):
|
|
||||||
return LogitsProcessorOutput(
|
|
||||||
next_token_logits=output.next_token_logits[: self.raw_num_token],
|
|
||||||
hidden_states=(
|
|
||||||
output.hidden_states[: self.raw_num_token]
|
|
||||||
if output.hidden_states is not None
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert isinstance(output, PPProxyTensors)
|
|
||||||
return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
|
|
||||||
@@ -1200,7 +1200,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
zero_allocator: BumpAllocator,
|
zero_allocator: BumpAllocator,
|
||||||
):
|
):
|
||||||
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
|
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||||
|
|
||||||
if self.q_lora_rank is not None:
|
if self.q_lora_rank is not None:
|
||||||
if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
|
if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
|
||||||
|
|||||||
@@ -68,8 +68,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
|
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.deepseek_v2 import (
|
from sglang.srt.models.deepseek_v2 import (
|
||||||
DeepseekV2DecoderLayer,
|
DeepseekV2DecoderLayer,
|
||||||
|
|||||||
@@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
|
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||||
|
|
||||||
batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
|
batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
|
||||||
self._batch_image_inputs(forward_batch)
|
self._batch_image_inputs(forward_batch)
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ from sglang.srt.layers.radix_attention import RadixAttention
|
|||||||
from sglang.srt.layers.rotary_embedding import get_rope
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
|
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
|
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
|
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
|
||||||
from sglang.srt.models.qwen2 import Qwen2Model
|
from sglang.srt.models.qwen2 import Qwen2Model
|
||||||
|
|||||||
@@ -52,8 +52,8 @@ from sglang.srt.layers.rotary_embedding import get_rope
|
|||||||
from sglang.srt.layers.utils import get_layer_id
|
from sglang.srt.layers.utils import get_layer_id
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
|
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
|
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
|
||||||
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
|
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
|
||||||
|
|||||||
@@ -6,20 +6,20 @@ from typing import TYPE_CHECKING, Callable
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
|
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
||||||
CaptureHiddenMode,
|
CudaGraphRunner,
|
||||||
ForwardBatch,
|
|
||||||
ForwardMode,
|
|
||||||
)
|
|
||||||
from sglang.srt.model_executor.graph_runner import (
|
|
||||||
GRAPH_CAPTURE_FAILED_MSG,
|
|
||||||
get_batch_sizes_to_capture,
|
get_batch_sizes_to_capture,
|
||||||
get_global_graph_memory_pool,
|
get_global_graph_memory_pool,
|
||||||
model_capture_mode,
|
model_capture_mode,
|
||||||
set_global_graph_memory_pool,
|
set_global_graph_memory_pool,
|
||||||
set_torch_compile_config,
|
set_torch_compile_config,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
CaptureHiddenMode,
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
)
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
require_attn_tp_gather,
|
require_attn_tp_gather,
|
||||||
@@ -121,7 +121,7 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
self.capture()
|
self.capture()
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
|
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def can_run(self, forward_batch: ForwardBatch):
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
|
|||||||
@@ -6,14 +6,9 @@ from typing import TYPE_CHECKING, Callable
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
|
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
||||||
CaptureHiddenMode,
|
CudaGraphRunner,
|
||||||
ForwardBatch,
|
|
||||||
ForwardMode,
|
|
||||||
)
|
|
||||||
from sglang.srt.model_executor.graph_runner import (
|
|
||||||
GRAPH_CAPTURE_FAILED_MSG,
|
|
||||||
LogitsProcessorOutput,
|
LogitsProcessorOutput,
|
||||||
get_batch_sizes_to_capture,
|
get_batch_sizes_to_capture,
|
||||||
get_global_graph_memory_pool,
|
get_global_graph_memory_pool,
|
||||||
@@ -21,6 +16,11 @@ from sglang.srt.model_executor.graph_runner import (
|
|||||||
set_global_graph_memory_pool,
|
set_global_graph_memory_pool,
|
||||||
set_torch_compile_config,
|
set_torch_compile_config,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
CaptureHiddenMode,
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
)
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
require_attn_tp_gather,
|
require_attn_tp_gather,
|
||||||
@@ -149,7 +149,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
self.capture()
|
self.capture()
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
|
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def can_run(self, forward_batch: ForwardBatch):
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
|
|||||||
@@ -229,17 +229,6 @@ suite_amd = {
|
|||||||
TestFile("test_wave_attention_kernels.py", 2),
|
TestFile("test_wave_attention_kernels.py", 2),
|
||||||
TestFile("test_wave_attention_backend.py", 150),
|
TestFile("test_wave_attention_backend.py", 150),
|
||||||
],
|
],
|
||||||
"per-commit-1-ascend-npu": [
|
|
||||||
TestFile("test_ascend_tp1_bf16.py", 400),
|
|
||||||
TestFile("test_ascend_graph_tp1_bf16.py", 400),
|
|
||||||
],
|
|
||||||
"per-commit-2-ascend-npu": [
|
|
||||||
TestFile("test_ascend_tp2_bf16.py", 400),
|
|
||||||
TestFile("test_ascend_graph_tp2_bf16.py", 400),
|
|
||||||
],
|
|
||||||
"per-commit-4-ascend-npu": [
|
|
||||||
TestFile("test_ascend_mla_w8a8int8.py", 400),
|
|
||||||
],
|
|
||||||
"per-commit-2-gpu-amd": [
|
"per-commit-2-gpu-amd": [
|
||||||
TestFile("lora/test_lora_tp.py", 116),
|
TestFile("lora/test_lora_tp.py", 116),
|
||||||
TestFile("rl/test_update_weights_from_distributed.py", 103),
|
TestFile("rl/test_update_weights_from_distributed.py", 103),
|
||||||
|
|||||||
@@ -1,95 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from sglang.srt.utils import kill_process_tree
|
|
||||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
|
||||||
from sglang.test.test_utils import (
|
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
DEFAULT_URL_FOR_TEST,
|
|
||||||
CustomTestCase,
|
|
||||||
is_in_ci,
|
|
||||||
popen_launch_server,
|
|
||||||
run_bench_offline_throughput,
|
|
||||||
)
|
|
||||||
|
|
||||||
TEST_MODEL_MATRIX = {
|
|
||||||
"Qwen/Qwen2.5-7B-Instruct": {
|
|
||||||
"accuracy": 0.85,
|
|
||||||
"latency": 150,
|
|
||||||
"output_throughput": 30,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TestAscendGraphTp1Bf16(CustomTestCase):
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls.models = TEST_MODEL_MATRIX.keys()
|
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
|
||||||
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
|
|
||||||
cls.common_args = [
|
|
||||||
"--trust-remote-code",
|
|
||||||
"--mem-fraction-static",
|
|
||||||
0.8,
|
|
||||||
"--attention-backend",
|
|
||||||
"ascend",
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_a_gsm8k(self):
|
|
||||||
for model in self.models:
|
|
||||||
with self.subTest(model=model):
|
|
||||||
print(f"##=== Testing accuracy: {model} ===##")
|
|
||||||
|
|
||||||
process = popen_launch_server(
|
|
||||||
model,
|
|
||||||
self.base_url,
|
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
other_args=[
|
|
||||||
*self.common_args,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
args = SimpleNamespace(
|
|
||||||
num_shots=5,
|
|
||||||
data_path=None,
|
|
||||||
num_questions=1319,
|
|
||||||
max_new_tokens=512,
|
|
||||||
parallel=128,
|
|
||||||
host=f"http://{self.url.hostname}",
|
|
||||||
port=int(self.url.port),
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics = run_eval_few_shot_gsm8k(args)
|
|
||||||
self.assertGreaterEqual(
|
|
||||||
metrics["accuracy"],
|
|
||||||
TEST_MODEL_MATRIX[model]["accuracy"],
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
kill_process_tree(process.pid)
|
|
||||||
|
|
||||||
def test_b_throughput(self):
|
|
||||||
for model in self.models:
|
|
||||||
with self.subTest(model=model):
|
|
||||||
print(f"##=== Testing throughput: {model} ===##")
|
|
||||||
|
|
||||||
output_throughput = run_bench_offline_throughput(
|
|
||||||
model,
|
|
||||||
[
|
|
||||||
*self.common_args,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"##=== {model} throughput: {output_throughput} ===##")
|
|
||||||
|
|
||||||
if is_in_ci():
|
|
||||||
self.assertGreater(
|
|
||||||
output_throughput,
|
|
||||||
TEST_MODEL_MATRIX[model]["output_throughput"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,97 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from sglang.srt.utils import kill_process_tree
|
|
||||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
|
||||||
from sglang.test.test_utils import (
|
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
DEFAULT_URL_FOR_TEST,
|
|
||||||
CustomTestCase,
|
|
||||||
is_in_ci,
|
|
||||||
popen_launch_server,
|
|
||||||
run_bench_offline_throughput,
|
|
||||||
)
|
|
||||||
|
|
||||||
TEST_MODEL_MATRIX = {
|
|
||||||
"Qwen/Qwen2.5-7B-Instruct": {
|
|
||||||
"accuracy": 0.85,
|
|
||||||
"latency": 180,
|
|
||||||
"output_throughput": 20,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TestAscendGraphTp2Bf16(CustomTestCase):
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls.models = TEST_MODEL_MATRIX.keys()
|
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
|
||||||
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
|
|
||||||
cls.common_args = [
|
|
||||||
"--trust-remote-code",
|
|
||||||
"--mem-fraction-static",
|
|
||||||
0.8,
|
|
||||||
"--attention-backend",
|
|
||||||
"ascend",
|
|
||||||
"--tp-size",
|
|
||||||
2,
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_a_gsm8k(self):
|
|
||||||
for model in self.models:
|
|
||||||
with self.subTest(model=model):
|
|
||||||
print(f"##=== Testing accuracy: {model} ===##")
|
|
||||||
|
|
||||||
process = popen_launch_server(
|
|
||||||
model,
|
|
||||||
self.base_url,
|
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
other_args=[
|
|
||||||
*self.common_args,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
args = SimpleNamespace(
|
|
||||||
num_shots=5,
|
|
||||||
data_path=None,
|
|
||||||
num_questions=1319,
|
|
||||||
max_new_tokens=512,
|
|
||||||
parallel=128,
|
|
||||||
host=f"http://{self.url.hostname}",
|
|
||||||
port=int(self.url.port),
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics = run_eval_few_shot_gsm8k(args)
|
|
||||||
self.assertGreaterEqual(
|
|
||||||
metrics["accuracy"],
|
|
||||||
TEST_MODEL_MATRIX[model]["accuracy"],
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
kill_process_tree(process.pid)
|
|
||||||
|
|
||||||
def test_b_throughput(self):
|
|
||||||
for model in self.models:
|
|
||||||
with self.subTest(model=model):
|
|
||||||
print(f"##=== Testing throughput: {model} ===##")
|
|
||||||
|
|
||||||
output_throughput = run_bench_offline_throughput(
|
|
||||||
model,
|
|
||||||
[
|
|
||||||
*self.common_args,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"##=== {model} throughput: {output_throughput} ===##")
|
|
||||||
|
|
||||||
if is_in_ci():
|
|
||||||
self.assertGreater(
|
|
||||||
output_throughput,
|
|
||||||
TEST_MODEL_MATRIX[model]["output_throughput"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
Reference in New Issue
Block a user