[Feature] Comprehensive Hybrid Parallelism Support (#6389)
@@ -46,6 +46,9 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     rank0_log,
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_tp_gather,
 )

 logger = logging.getLogger(__name__)
@@ -207,8 +210,9 @@ class CudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
-        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
-        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.enable_two_batch_overlap = (
             model_runner.server_args.enable_two_batch_overlap
         )
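The hunk above swaps the direct `enable_dp_attention` / `enable_sp_layernorm` flags for three predicates computed from the server args, so the runner branches on what kind of gather is needed rather than on which feature flag was set. A minimal, self-contained sketch of the invariant those predicates must satisfy for the branches later in this diff (the `assert self.require_attn_tp_gather` fallback) to hold; the `StubServerArgs` type and its flags are made up for illustration, and the real implementations live in `sglang.srt.utils`:

```python
# Hedged sketch only -- not the sglang.srt.utils implementation. The stub flags
# are hypothetical; the point is the invariant the CudaGraphRunner branches rely on:
#   require_gathered_buffer(args) == require_mlp_tp_gather(args) or require_attn_tp_gather(args)
from dataclasses import dataclass


@dataclass
class StubServerArgs:             # stand-in for sglang's ServerArgs (hypothetical)
    mlp_tp_gather: bool = False   # made-up flag for illustration
    attn_tp_gather: bool = False  # made-up flag for illustration


def require_mlp_tp_gather(server_args: StubServerArgs) -> bool:
    # Gather hidden states from every DP rank before the MLP/MoE layers.
    return server_args.mlp_tp_gather


def require_attn_tp_gather(server_args: StubServerArgs) -> bool:
    # Gather only the local rank's tokens for the attention-TP path.
    return server_args.attn_tp_gather


def require_gathered_buffer(server_args: StubServerArgs) -> bool:
    # A gathered buffer is needed exactly when one of the two gather paths is active.
    return require_mlp_tp_gather(server_args) or require_attn_tp_gather(server_args)
```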
@@ -299,18 +303,28 @@ class CudaGraphRunner:
         else:
             self.encoder_lens = None

-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
-            self.global_num_tokens_gpu = torch.zeros(
-                (self.dp_size,), dtype=torch.int32
-            )
+        if self.require_gathered_buffer:
+            if self.require_mlp_tp_gather:
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_bs * self.dp_size * self.num_tokens_per_bs,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_bs * self.num_tokens_per_bs,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)

         # Capture
         try:
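The two branches above size `gathered_buffer` differently: the MLP-TP-gather path reserves room for the tokens of every DP rank, while the attention-TP-gather path holds only the local rank's tokens, and `global_num_tokens_gpu` shrinks from one entry per DP rank to a single entry. A quick numeric illustration (all sizes are made up, not taken from the PR):

```python
# Illustrative numbers only: how the two buffer shapes above differ.
max_bs = 160            # largest batch size captured as a CUDA graph
num_tokens_per_bs = 1   # 1 for plain decode, >1 for speculative (EAGLE) decode
dp_size = 4
hidden_size = 4096

# require_mlp_tp_gather: tokens from all DP ranks are gathered.
mlp_gather_shape = (max_bs * dp_size * num_tokens_per_bs, hidden_size)  # (640, 4096)

# require_attn_tp_gather: only the local rank's tokens are gathered.
attn_gather_shape = (max_bs * num_tokens_per_bs, hidden_size)           # (160, 4096)

print(mlp_gather_shape, attn_gather_shape)
```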
@@ -322,7 +336,7 @@ class CudaGraphRunner:
             )

     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             total_batch_size = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
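`can_run` decides whether a captured graph exists for the incoming batch. With MLP-TP gather the relevant quantity is the batch size across all DP ranks, and under EAGLE speculative decoding it is recovered by dividing the global token count by the tokens per request. A worked example with made-up numbers:

```python
# Illustrative only: recovering the global batch size in the EAGLE case,
# mirroring the `//` expression in can_run above.
global_num_tokens_cpu = [6, 9, 3, 6]  # tokens per DP rank (hypothetical)
num_tokens_per_bs = 3                 # tokens per request under speculative decoding
total_batch_size = sum(global_num_tokens_cpu) // num_tokens_per_bs
print(total_batch_size)  # 8 requests across all DP ranks
```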
@@ -459,7 +473,7 @@ class CudaGraphRunner:
                 {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
             )

-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
@@ -472,6 +486,16 @@ class CudaGraphRunner:
             )
             global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
             global_num_tokens = None
             gathered_buffer = None
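In both gather branches the capture path hands the graph a slice `self.gathered_buffer[:num_tokens]` of the buffer allocated up front. Slicing from the start returns a view on the same storage, so every capture and replay at a given `num_tokens` sees a stable address, which is what CUDA graphs require. A small sketch of that allocate-once, slice-per-capture pattern (sizes are arbitrary):

```python
# Sketch of the allocate-once, slice-per-capture pattern used above (made-up sizes).
import torch

max_tokens, hidden_size = 640, 4096
gathered_buffer = torch.zeros((max_tokens, hidden_size))  # one-time allocation

def buffer_for_capture(num_tokens: int) -> torch.Tensor:
    # Every capture/replay at this num_tokens reuses the same underlying storage.
    return gathered_buffer[:num_tokens]

buf = buffer_for_capture(160)
assert buf.data_ptr() == gathered_buffer.data_ptr()  # a view, not a copy
```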
@@ -607,7 +631,7 @@ class CudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs

         # Pad
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             total_batch_size = (
                 sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
@@ -642,7 +666,7 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_gathered_buffer:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
         if enable_num_token_non_padded(self.model_runner.server_args):
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)