[Feature] Comprehensive Hybrid Parallelism Support (#6389)

This commit is contained in:
Cheng Wan
2025-06-20 14:43:11 -07:00
committed by GitHub
parent 0998808009
commit e879d8b7a8
14 changed files with 3689 additions and 108 deletions

View File

@@ -20,6 +20,11 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput
from sglang.srt.utils import (
require_attn_tp_gather,
require_gathered_buffer,
require_mlp_tp_gather,
)
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -39,8 +44,9 @@ class EAGLEDraftCudaGraphRunner:
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
self.dp_size = self.model_runner.dp_size
self.tp_size = self.model_runner.tp_size
self.topk = model_runner.server_args.speculative_eagle_topk
@@ -88,8 +94,7 @@ class EAGLEDraftCudaGraphRunner:
dtype=self.model_runner.dtype,
)
if self.enable_dp_attention or self.enable_sp_layernorm:
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
if self.require_gathered_buffer:
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
@@ -97,12 +102,19 @@ class EAGLEDraftCudaGraphRunner:
),
dtype=self.model_runner.dtype,
)
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
else:
assert self.require_attn_tp_gather
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(1,), dtype=torch.int32
)
# Capture
try:
@@ -114,8 +126,7 @@ class EAGLEDraftCudaGraphRunner:
)
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention:
# TODO(ch-wan): check --moe-dense-tp-size and --enable-dp-lm-head
if self.require_mlp_tp_gather:
if not forward_batch.can_run_dp_cuda_graph:
return False
total_batch_size = (
@@ -153,7 +164,7 @@ class EAGLEDraftCudaGraphRunner:
topk_index = self.topk_index[:num_seqs]
hidden_states = self.hidden_states[:num_seqs]
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
@@ -177,6 +188,24 @@ class EAGLEDraftCudaGraphRunner:
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
elif self.require_attn_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[num_tokens],
dtype=torch.int32,
device=self.input_ids.device,
)
)
self.global_num_tokens_for_logprob_gpu.copy_(
torch.tensor(
[num_tokens],
dtype=torch.int32,
device=self.input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
else:
global_num_tokens = None
gathered_buffer = None
@@ -259,7 +288,7 @@ class EAGLEDraftCudaGraphRunner:
raw_num_token = raw_bs * self.num_tokens_per_bs
# Pad
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_mlp_tp_gather:
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
@@ -286,7 +315,7 @@ class EAGLEDraftCudaGraphRunner:
self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_gathered_buffer:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
self.global_num_tokens_for_logprob_gpu.copy_(
forward_batch.global_num_tokens_for_logprob_gpu

View File

@@ -21,6 +21,11 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
from sglang.srt.utils import (
require_attn_tp_gather,
require_gathered_buffer,
require_mlp_tp_gather,
)
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -35,8 +40,9 @@ class EAGLEDraftExtendCudaGraphRunner:
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
self.tp_size = self.model_runner.tp_size
self.dp_size = model_runner.server_args.dp_size
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -92,7 +98,7 @@ class EAGLEDraftExtendCudaGraphRunner:
(self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
)
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_gathered_buffer:
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
@@ -100,13 +106,19 @@ class EAGLEDraftExtendCudaGraphRunner:
),
dtype=self.model_runner.dtype,
)
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
else:
assert self.require_attn_tp_gather
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(1,), dtype=torch.int32
)
# Capture
try:
with model_capture_mode():
@@ -117,7 +129,7 @@ class EAGLEDraftExtendCudaGraphRunner:
)
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_mlp_tp_gather:
if not forward_batch.can_run_dp_cuda_graph:
return False
total_batch_size = (
@@ -160,7 +172,7 @@ class EAGLEDraftExtendCudaGraphRunner:
positions = self.positions[:num_tokens]
hidden_states = self.hidden_states[:num_tokens]
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
@@ -184,6 +196,24 @@ class EAGLEDraftExtendCudaGraphRunner:
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
elif self.require_attn_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[num_tokens],
dtype=torch.int32,
device=self.input_ids.device,
)
)
self.global_num_tokens_for_logprob_gpu.copy_(
torch.tensor(
[num_tokens],
dtype=torch.int32,
device=self.input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
else:
global_num_tokens = None
gathered_buffer = None
@@ -270,7 +300,7 @@ class EAGLEDraftExtendCudaGraphRunner:
# in the batch, which will not be counted as num_seqs
raw_bs = forward_batch.batch_size
num_tokens = forward_batch.input_ids.shape[0]
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_mlp_tp_gather:
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
@@ -299,7 +329,7 @@ class EAGLEDraftExtendCudaGraphRunner:
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_gathered_buffer:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
self.global_num_tokens_for_logprob_gpu.copy_(
forward_batch.global_num_tokens_for_logprob_gpu