[Feature] Comprehensive Hybrid Parallelism Support (#6389)
This commit is contained in:
@@ -20,6 +20,11 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
||||
from sglang.srt.utils import (
|
||||
require_attn_tp_gather,
|
||||
require_gathered_buffer,
|
||||
require_mlp_tp_gather,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
||||
@@ -39,8 +44,9 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
||||
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
|
||||
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
||||
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
||||
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
||||
self.dp_size = self.model_runner.dp_size
|
||||
self.tp_size = self.model_runner.tp_size
|
||||
self.topk = model_runner.server_args.speculative_eagle_topk
|
||||
@@ -88,8 +94,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
|
||||
if self.require_gathered_buffer:
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
@@ -97,12 +102,19 @@ class EAGLEDraftCudaGraphRunner:
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
if self.require_mlp_tp_gather:
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
assert self.require_attn_tp_gather
|
||||
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(1,), dtype=torch.int32
|
||||
)
|
||||
|
||||
# Capture
|
||||
try:
|
||||
@@ -114,8 +126,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
)
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.enable_dp_attention:
|
||||
# TODO(ch-wan): check --moe-dense-tp-size and --enable-dp-lm-head
|
||||
if self.require_mlp_tp_gather:
|
||||
if not forward_batch.can_run_dp_cuda_graph:
|
||||
return False
|
||||
total_batch_size = (
|
||||
@@ -153,7 +164,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
topk_index = self.topk_index[:num_seqs]
|
||||
hidden_states = self.hidden_states[:num_seqs]
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_mlp_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
@@ -177,6 +188,24 @@ class EAGLEDraftCudaGraphRunner:
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
elif self.require_attn_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[num_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
torch.tensor(
|
||||
[num_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
else:
|
||||
global_num_tokens = None
|
||||
gathered_buffer = None
|
||||
@@ -259,7 +288,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
raw_num_token = raw_bs * self.num_tokens_per_bs
|
||||
|
||||
# Pad
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_mlp_tp_gather:
|
||||
total_batch_size = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
@@ -286,7 +315,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
|
||||
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_gathered_buffer:
|
||||
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
forward_batch.global_num_tokens_for_logprob_gpu
|
||||
|
||||
@@ -21,6 +21,11 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
|
||||
from sglang.srt.utils import (
|
||||
require_attn_tp_gather,
|
||||
require_gathered_buffer,
|
||||
require_mlp_tp_gather,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
||||
@@ -35,8 +40,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.output_buffers = {}
|
||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
||||
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
|
||||
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
||||
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
||||
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
||||
self.tp_size = self.model_runner.tp_size
|
||||
self.dp_size = model_runner.server_args.dp_size
|
||||
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
||||
@@ -92,7 +98,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
(self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
|
||||
)
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_gathered_buffer:
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
@@ -100,13 +106,19 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
|
||||
if self.require_mlp_tp_gather:
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
assert self.require_attn_tp_gather
|
||||
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(1,), dtype=torch.int32
|
||||
)
|
||||
# Capture
|
||||
try:
|
||||
with model_capture_mode():
|
||||
@@ -117,7 +129,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
)
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_mlp_tp_gather:
|
||||
if not forward_batch.can_run_dp_cuda_graph:
|
||||
return False
|
||||
total_batch_size = (
|
||||
@@ -160,7 +172,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
positions = self.positions[:num_tokens]
|
||||
hidden_states = self.hidden_states[:num_tokens]
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_mlp_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
@@ -184,6 +196,24 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
elif self.require_attn_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[num_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
torch.tensor(
|
||||
[num_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
else:
|
||||
global_num_tokens = None
|
||||
gathered_buffer = None
|
||||
@@ -270,7 +300,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
# in the batch, which will not be counted as num_seqs
|
||||
raw_bs = forward_batch.batch_size
|
||||
num_tokens = forward_batch.input_ids.shape[0]
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_mlp_tp_gather:
|
||||
total_batch_size = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
@@ -299,7 +329,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
|
||||
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_gathered_buffer:
|
||||
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
forward_batch.global_num_tokens_for_logprob_gpu
|
||||
|
||||
Reference in New Issue
Block a user