[Feature] Comprehensive Hybrid Parallelism Support (#6389)

commit e879d8b7a8 (parent 0998808009)
Author: Cheng Wan
Date: 2025-06-20 14:43:11 -07:00
Committed by: GitHub
14 changed files with 3689 additions and 108 deletions

@@ -46,6 +46,9 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     rank0_log,
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_tp_gather,
 )
 
 logger = logging.getLogger(__name__)
@@ -207,8 +210,9 @@ class CudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
-        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
-        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.enable_two_batch_overlap = (
             model_runner.server_args.enable_two_batch_overlap
         )
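The hunk above replaces the two raw feature flags with three named predicates imported from sglang.srt.utils. Their definitions are not part of this diff; the sketch below is only an assumed, minimal reading. The one thing taken from the diff itself is the relationship between the predicates (the allocation code later asserts that require_gathered_buffer implies one of the two gather modes); the flag conditions here are illustrative guesses, and the real conditions also involve attention-TP and MoE settings.

# Hypothetical sketch -- NOT the actual sglang.srt.utils definitions.
# Only the relationship between the three predicates is taken from the
# diff; the individual flag conditions are assumed for illustration.

def require_mlp_tp_gather(server_args) -> bool:
    # Assumed: MLP inputs must be all-gathered across the whole DP group.
    return server_args.enable_dp_attention and server_args.dp_size > 1

def require_attn_tp_gather(server_args) -> bool:
    # Assumed: tokens are gathered only within the local attention-TP
    # group, so buffers are sized for the local rank alone.
    return server_args.enable_dp_attention and server_args.dp_size == 1

def require_gathered_buffer(server_args) -> bool:
    # Invariant the CudaGraphRunner relies on: a gathered buffer is
    # needed exactly when one of the two gather modes is active.
    return require_mlp_tp_gather(server_args) or require_attn_tp_gather(server_args)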
@@ -299,18 +303,28 @@
         else:
             self.encoder_lens = None
-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
-            self.global_num_tokens_gpu = torch.zeros(
-                (self.dp_size,), dtype=torch.int32
-            )
+        if self.require_gathered_buffer:
+            if self.require_mlp_tp_gather:
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_bs * self.dp_size * self.num_tokens_per_bs,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_bs * self.num_tokens_per_bs,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
 
         # Capture
         try:
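The two allocation branches size gathered_buffer differently: the MLP-TP-gather path reserves one hidden-state row for every token of every DP rank, while the new attention-TP-gather path reserves rows for the local rank's tokens only, with a matching 1-element token-count tensor. A worked sizing example, with all values assumed purely for illustration:

# Assumed toy values -- not taken from the PR.
max_bs = 160            # largest captured batch size
num_tokens_per_bs = 1   # plain decode: one token per request
dp_size = 8
hidden_size = 7168

rows_mlp_tp = max_bs * dp_size * num_tokens_per_bs   # 1280 rows (all DP ranks)
rows_attn_tp = max_bs * num_tokens_per_bs            # 160 rows (local rank only)

print(rows_mlp_tp * hidden_size, rows_attn_tp * hidden_size)
# 9175040 vs 1146880 buffer elements -- a dp_size-fold saving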
@@ -322,7 +336,7 @@
         )
 
     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             total_batch_size = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
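Captured graphs are keyed by batch size, so can_run must map the global token counts back to a batch size; under EAGLE speculative decoding each request contributes num_tokens_per_bs tokens, hence the integer division. A worked illustration with assumed numbers:

# Assumed values: 4 DP ranks, EAGLE verifying 4 tokens per request.
global_num_tokens_cpu = [32, 32, 16, 32]  # token count per DP rank
num_tokens_per_bs = 4

total_batch_size = sum(global_num_tokens_cpu) // num_tokens_per_bs
print(total_batch_size)  # 28 -- runnable only if a captured graph covers it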
@@ -459,7 +473,7 @@
                 {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
             )
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
@@ -472,6 +486,16 @@
             )
             global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
             global_num_tokens = None
             gathered_buffer = None
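The capture-time token-count tensor mirrors the buffer layout chosen in __init__: under require_mlp_tp_gather it holds dp_size entries, one per rank (the diff elides the exact values), while the new require_attn_tp_gather branch records only the single local count. A minimal illustration with assumed values:

import torch

dp_size, num_tokens = 4, 32  # assumed values for illustration

# MLP-TP gather: one entry per DP rank; here each rank is assumed
# padded to the same captured num_tokens.
mlp_counts = torch.tensor([num_tokens] * dp_size, dtype=torch.int32)

# Attention-TP gather: only the local rank's count is recorded.
attn_counts = torch.tensor([num_tokens], dtype=torch.int32)

print(mlp_counts)   # tensor([32, 32, 32, 32], dtype=torch.int32)
print(attn_counts)  # tensor([32], dtype=torch.int32)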
@@ -607,7 +631,7 @@
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             total_batch_size = (
                 sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
@@ -642,7 +666,7 @@
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_gathered_buffer:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
         if enable_num_token_non_padded(self.model_runner.server_args):
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)