Piecewise CUDA Graph Support & Torch Compile Backend (#10062)
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
This commit is contained in:
@@ -108,8 +108,15 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
)
|
||||
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
PPProxyTensors,
|
||||
)
|
||||
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
|
||||
from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
|
||||
PiecewiseCudaGraphRunner,
|
||||
)
|
||||
from sglang.srt.model_loader import get_model
|
||||
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
|
||||
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
|
||||
@@ -307,6 +314,26 @@ class ModelRunner:
|
||||
self._model_update_group = {}
|
||||
self._weights_send_group = {}
|
||||
|
||||
if (
|
||||
self.server_args.enable_piecewise_cuda_graph
|
||||
and self.can_run_piecewise_cuda_graph()
|
||||
):
|
||||
self.attention_layers = []
|
||||
for layer in self.model.model.layers:
|
||||
if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
|
||||
self.attention_layers.append(layer.self_attn.attn)
|
||||
if len(self.attention_layers) < self.model_config.num_hidden_layers:
|
||||
# TODO(yuwei): support Non-Standard GQA
|
||||
log_info_on_rank0(
|
||||
logger,
|
||||
"Disable piecewise CUDA graph because some layers do not apply Standard GQA",
|
||||
)
|
||||
self.piecewise_cuda_graph_runner = None
|
||||
else:
|
||||
self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
|
||||
else:
|
||||
self.piecewise_cuda_graph_runner = None
|
||||
|
||||
def initialize(self, min_per_gpu_memory: float):
|
||||
server_args = self.server_args
|
||||
|
||||
@@ -692,6 +719,7 @@ class ModelRunner:
|
||||
pipeline_model_parallel_size=self.pp_size,
|
||||
expert_model_parallel_size=self.moe_ep_size,
|
||||
duplicate_tp_group=self.server_args.enable_pdmux,
|
||||
torch_compile=self.server_args.enable_piecewise_cuda_graph,
|
||||
)
|
||||
initialize_dp_attention(
|
||||
server_args=self.server_args,
|
||||
@@ -1411,6 +1439,27 @@ class ModelRunner:
|
||||
f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
|
||||
)
|
||||
|
||||
def can_run_piecewise_cuda_graph(self):
    """Decide whether piecewise CUDA graph capture can be used for this runner.

    Checks a sequence of disqualifying conditions; on the first one that
    holds, logs the reason (rank-0 only) and returns False. Returns True
    only when no blocker applies.
    """
    # Each entry pairs a blocking condition with the exact reason logged.
    blockers = (
        (
            self.server_args.disable_cuda_graph,
            "Disable piecewise CUDA graph because disable_cuda_graph is set",
        ),
        (
            self.server_args.enable_torch_compile,
            "Disable piecewise CUDA graph because piecewise_cuda_graph has conflict with torch compile",
        ),
        # TODO(yuwei): support PP
        (
            self.pp_size > 1,
            "Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP",
        ),
    )
    for blocked, reason in blockers:
        if blocked:
            log_info_on_rank0(logger, reason)
            return False
    return True
|
||||
|
||||
def init_memory_pool(
|
||||
self,
|
||||
total_gpu_memory: int,
|
||||
@@ -1932,6 +1981,11 @@ class ModelRunner:
|
||||
kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
|
||||
if not self.is_generation:
|
||||
kwargs["get_embedding"] = True
|
||||
|
||||
if self.piecewise_cuda_graph_runner is not None:
|
||||
if self.piecewise_cuda_graph_runner.can_run(forward_batch):
|
||||
return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)
|
||||
|
||||
return self.model.forward(
|
||||
forward_batch.input_ids,
|
||||
forward_batch.positions,
|
||||
|
||||
Reference in New Issue
Block a user