[Feature] Comprehensive Hybrid Parallelism Support (#6389)
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
     configure_logger,
     get_bool_env_var,
     kill_process_tree,
+    require_mlp_sync,
+    require_mlp_tp_gather,
     set_gpu_proc_affinity,
     suppress_other_loggers,
 )
@@ -243,7 +245,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
-    _maybe_prepare_dp_attn_batch(batch, model_runner)
+    _maybe_prepare_mlp_sync_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output, _ = model_runner.forward(forward_batch)
@@ -255,7 +257,7 @@ def extend(reqs, model_runner):
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
-    _maybe_prepare_dp_attn_batch(batch, model_runner)
+    _maybe_prepare_mlp_sync_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output, _ = model_runner.forward(forward_batch)
@@ -263,18 +265,18 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits
 
 
-def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
-    if model_runner.server_args.enable_dp_attention:
-        Scheduler.prepare_dp_attn_batch_raw(
+def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
+    if require_mlp_sync(model_runner.server_args):
+        Scheduler.prepare_mlp_sync_batch_raw(
             batch,
             dp_size=model_runner.server_args.dp_size,
             attn_tp_size=1,
-            moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
             tp_cpu_group=model_runner.tp_group.cpu_group,
             get_idle_batch=None,
             disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
             spec_algorithm=SpeculativeAlgorithm.NONE,
             speculative_num_draft_tokens=None,
+            require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
        )
 
 
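
For orientation, here is the benchmark helper as it reads once the hunks above are applied. This is reconstructed from the diff itself (indentation and the surrounding imports are assumed); note that `moe_dense_tp_size` is no longer passed and a precomputed `require_mlp_tp_gather` flag is forwarded instead:

```python
# Reconstructed from the hunk above; all names come from the diff itself.
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
    if require_mlp_sync(model_runner.server_args):
        Scheduler.prepare_mlp_sync_batch_raw(
            batch,
            dp_size=model_runner.server_args.dp_size,
            attn_tp_size=1,  # fixed to 1 in this standalone script
            tp_cpu_group=model_runner.tp_group.cpu_group,
            get_idle_batch=None,
            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
            spec_algorithm=SpeculativeAlgorithm.NONE,
            speculative_num_draft_tokens=None,
            require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
        )
```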
@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.utils import require_mlp_sync
 
 logger = logging.getLogger(__name__)
 
@@ -649,10 +650,7 @@ class SchedulerDisaggregationDecodeMixin:
             batch = self.get_next_disagg_decode_batch_to_run()
             self.cur_batch = batch
 
-            prepare_dp_attn_flag = (
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            )
+            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
 
             if batch:
                 # Generate fake extend output.
@@ -661,14 +659,14 @@ class SchedulerDisaggregationDecodeMixin:
                     self.stream_output(
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
-                    if prepare_dp_attn_flag:
+                    if prepare_mlp_sync_flag:
                         self._prepare_idle_batch_and_run(None)
                 else:
-                    if prepare_dp_attn_flag:
-                        self.prepare_dp_attn_batch(batch)
+                    if prepare_mlp_sync_flag:
+                        self.prepare_mlp_sync_batch(batch)
                     result = self.run_batch(batch)
                     self.process_batch_result(batch, result)
-            elif prepare_dp_attn_flag:
+            elif prepare_mlp_sync_flag:
                 batch, _ = self._prepare_idle_batch_and_run(None)
 
             if batch is None and (
@@ -699,10 +697,7 @@ class SchedulerDisaggregationDecodeMixin:
             self.cur_batch = batch
             last_batch_in_queue = False
 
-            prepare_dp_attn_flag = (
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            )
+            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
 
             if batch:
                 # Generate fake extend output.
@@ -711,7 +706,7 @@ class SchedulerDisaggregationDecodeMixin:
                     self.stream_output(
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
-                    if prepare_dp_attn_flag:
+                    if prepare_mlp_sync_flag:
                         batch_, result = self._prepare_idle_batch_and_run(
                             None, delay_process=True
                         )
@@ -719,8 +714,8 @@ class SchedulerDisaggregationDecodeMixin:
                         result_queue.append((batch_.copy(), result))
                         last_batch_in_queue = True
                 else:
-                    if prepare_dp_attn_flag:
-                        self.prepare_dp_attn_batch(batch)
+                    if prepare_mlp_sync_flag:
+                        self.prepare_mlp_sync_batch(batch)
                     result = self.run_batch(batch)
                     result_queue.append((batch.copy(), result))
 
@@ -735,7 +730,7 @@ class SchedulerDisaggregationDecodeMixin:
                     self.set_next_batch_sampling_info_done(tmp_batch)
                 last_batch_in_queue = True
 
-            elif prepare_dp_attn_flag:
+            elif prepare_mlp_sync_flag:
                 batch, result = self._prepare_idle_batch_and_run(
                     None, delay_process=True
                 )
@@ -765,13 +760,13 @@ class SchedulerDisaggregationDecodeMixin:
         self.last_batch = batch
         self.last_batch_in_queue = last_batch_in_queue
 
-    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
-        batch, _ = self.prepare_dp_attn_batch(batch)
+    def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
+        batch, _ = self.prepare_mlp_sync_batch(batch)
         result = None
         if batch:
             result = self.run_batch(batch)
             if not delay_process:
                 self.process_batch_result(batch, result)
         return batch, result
 
     def get_next_disagg_decode_batch_to_run(
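
After these hunks, the idle-batch helper reads as below (reconstructed from the diff; indentation assumed). Only the `prepare_dp_attn_batch` call is renamed; the post-run call remains `process_batch_result`, since that method is unrelated to the DP-attention renaming:

```python
# Reconstructed from the hunks above (indentation assumed).
def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
    batch, _ = self.prepare_mlp_sync_batch(batch)
    result = None
    if batch:
        result = self.run_batch(batch)
        if not delay_process:
            # Still the result-processing call; not part of the rename.
            self.process_batch_result(batch, result)
    return batch, result
```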
@@ -45,6 +45,7 @@ from sglang.srt.disaggregation.utils import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import require_mlp_sync
 
 if TYPE_CHECKING:
     from torch.distributed import ProcessGroup
@@ -274,12 +275,8 @@ class SchedulerDisaggregationPrefillMixin:
                 self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
 
-            # Handle DP attention
-            if (
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            ):
-                batch, _ = self.prepare_dp_attn_batch(batch)
+            if require_mlp_sync(self.server_args):
+                batch, _ = self.prepare_mlp_sync_batch(batch)
             self.cur_batch = batch
 
             if batch:
@@ -312,12 +309,8 @@ class SchedulerDisaggregationPrefillMixin:
                 self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
 
-            # Handle DP attention
-            if (
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            ):
-                batch, _ = self.prepare_dp_attn_batch(batch)
+            if require_mlp_sync(self.server_args):
+                batch, _ = self.prepare_mlp_sync_batch(batch)
             self.cur_batch = batch
             if batch:
                 result = self.run_batch(batch)
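
Both prefill event loops now share the same gate; the old three-flag condition (`enable_dp_attention or enable_sp_layernorm`) collapses into a single predicate. A reconstructed fragment (indentation assumed):

```python
# Reconstructed from the hunks above.
batch = self.get_new_batch_prefill()
if require_mlp_sync(self.server_args):
    batch, _ = self.prepare_mlp_sync_batch(batch)
self.cur_batch = batch
```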
@@ -28,9 +28,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
+    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -229,7 +229,7 @@ class CommunicateContext:
     process_group_sizes: Dict[ScatterMode, int]
     attn_tp_rank: int
     attn_tp_size: int
-    local_attn_dp_size: int
+    attn_dp_size: int
     tp_size: int
 
     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
@@ -239,7 +239,7 @@ class CommunicateContext:
     def init_new(cls):
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
-        local_attn_dp_size = get_local_attention_dp_size()
+        attn_dp_size = get_attention_dp_size()
         tp_size = get_tensor_model_parallel_world_size()
         process_group_sizes = {
             ScatterMode.SCATTERED: 1,
@@ -251,7 +251,7 @@ class CommunicateContext:
             process_group_sizes=process_group_sizes,
             attn_tp_rank=attn_tp_rank,
             attn_tp_size=attn_tp_size,
-            local_attn_dp_size=local_attn_dp_size,
+            attn_dp_size=attn_dp_size,
             tp_size=tp_size,
         )
 
@@ -385,7 +385,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             attn_tp_all_gather(
                 list(residual.tensor_split(context.attn_tp_size)), local_residual
             )
-        if context.local_attn_dp_size != 1:
+        if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
             hidden_states, local_hidden_states = (
@@ -165,7 +165,8 @@ def disable_dp_size():
 
 
 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank = get_local_attention_dp_rank()
+    # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
+    dp_rank = get_attention_dp_rank()
 
     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
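
The switch to the global DP rank matters because `get_dp_local_info` slices a globally gathered token tensor. A minimal standalone sketch of the cumsum-based slicing that this kind of code performs (tensor values and the `local_slice` helper are hypothetical, not the sglang implementation):

```python
import torch

# Hypothetical per-rank token counts across dp_size = 4 attention-DP ranks.
global_num_tokens = torch.tensor([5, 3, 7, 2], dtype=torch.int32)
cumtokens = torch.cumsum(global_num_tokens, dim=0)  # [5, 8, 15, 17]

def local_slice(dp_rank: int):
    # Rank 0 starts at 0; rank r starts at the cumulative count of ranks < r.
    start = 0 if dp_rank == 0 else int(cumtokens[dp_rank - 1])
    length = int(global_num_tokens[dp_rank])
    return start, length

print(local_slice(2))  # (8, 7): rank 2 owns tokens 8..14 of the gathered buffer
```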
@@ -30,9 +30,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
+    get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
-    get_local_attention_dp_rank,
     get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -171,7 +171,7 @@ class LogitsMetadata:
             return
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank = get_local_attention_dp_rank()
+        dp_rank = get_attention_dp_rank()
         if dp_rank == 0:
             dp_local_start_pos = torch.zeros_like(
                 self.global_num_tokens_for_logprob_gpu[0]
@@ -149,6 +149,8 @@ from sglang.srt.utils import (
     kill_itself_when_parent_died,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
+    require_mlp_sync,
+    require_mlp_tp_gather,
     set_gpu_proc_affinity,
     set_random_seed,
     suppress_other_loggers,
@@ -1471,9 +1473,8 @@ class Scheduler(
         else:
             ret = None
 
-        # Handle DP attention
-        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
-            ret, _ = self.prepare_dp_attn_batch(ret)
+        if require_mlp_sync(self.server_args):
+            ret, _ = self.prepare_mlp_sync_batch(ret)
 
         return ret
 
@@ -1775,12 +1776,11 @@ class Scheduler(
                 self.return_health_check_ct -= 1
                 self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
 
-    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
-        return self.prepare_dp_attn_batch_raw(
+    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
+        return self.prepare_mlp_sync_batch_raw(
             local_batch,
             dp_size=self.server_args.dp_size,
             attn_tp_size=self.attn_tp_size,
-            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             tp_cpu_group=self.tp_cpu_group,
             get_idle_batch=self.get_idle_batch,
             disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1789,14 +1789,14 @@ class Scheduler(
             enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
             enable_deepep_moe=self.server_args.enable_deepep_moe,
             deepep_mode=DeepEPMode[self.server_args.deepep_mode],
+            require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
         )
 
     @staticmethod
-    def prepare_dp_attn_batch_raw(
+    def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
         dp_size,
         attn_tp_size: int,
-        moe_dense_tp_size: Optional[int],
         tp_cpu_group,
         get_idle_batch,
         disable_cuda_graph: bool,
@@ -1805,6 +1805,7 @@ class Scheduler(
         enable_two_batch_overlap: bool,
         enable_deepep_moe: bool,
         deepep_mode: DeepEPMode,
+        require_mlp_tp_gather: bool,
     ):
         # Check if other DP workers have running batches
         if local_batch is None:
@@ -1879,7 +1880,7 @@ class Scheduler(
 
         if local_batch is not None:
             # TODO: handle the case when moe_dense_tp_size != 1
-            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+            if not require_mlp_tp_gather:
                 local_batch.global_num_tokens = [num_tokens]
                 local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
             else:
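
Putting the scheduler hunks together, the renamed wrapper now forwards a precomputed `require_mlp_tp_gather` flag instead of the raw `moe_dense_tp_size`. Reconstructed from the diff (keyword arguments between `disable_cuda_graph` and `enable_two_batch_overlap` are elided in the hunks and omitted here):

```python
# Reconstructed from the hunks above; "..." marks arguments not shown in the diff.
def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
    return self.prepare_mlp_sync_batch_raw(
        local_batch,
        dp_size=self.server_args.dp_size,
        attn_tp_size=self.attn_tp_size,
        tp_cpu_group=self.tp_cpu_group,
        get_idle_batch=self.get_idle_batch,
        disable_cuda_graph=self.server_args.disable_cuda_graph,
        # ... arguments elided in the diff ...
        enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
        enable_deepep_moe=self.server_args.enable_deepep_moe,
        deepep_mode=DeepEPMode[self.server_args.deepep_mode],
        require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
    )
```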
@@ -46,6 +46,9 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     rank0_log,
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_tp_gather,
 )
 
 logger = logging.getLogger(__name__)
@@ -207,8 +210,9 @@ class CudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
-        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
-        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.enable_two_batch_overlap = (
             model_runner.server_args.enable_two_batch_overlap
         )
@@ -299,18 +303,28 @@ class CudaGraphRunner:
         else:
             self.encoder_lens = None
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
+        if self.require_gathered_buffer:
+            if self.require_mlp_tp_gather:
                 self.gathered_buffer = torch.zeros(
                     (
                         self.max_bs * self.dp_size * self.num_tokens_per_bs,
                         self.model_runner.model_config.hidden_size,
                     ),
                     dtype=self.model_runner.dtype,
                 )
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
+            else:
+                assert self.require_attn_tp_gather
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_bs * self.num_tokens_per_bs,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
 
         # Capture
         try:
@@ -322,7 +336,7 @@ class CudaGraphRunner:
         )
 
     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             total_batch_size = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
@@ -459,7 +473,7 @@ class CudaGraphRunner:
                 {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
             )
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
@@ -472,6 +486,16 @@ class CudaGraphRunner:
             )
             global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
             global_num_tokens = None
             gathered_buffer = None
@@ -607,7 +631,7 @@ class CudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             total_batch_size = (
                 sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
@@ -642,7 +666,7 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_gathered_buffer:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
         if enable_num_token_non_padded(self.model_runner.server_args):
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
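
The CUDA-graph hunks size the gathered buffer and the token-count tensor differently for the two gather modes. A distilled, standalone sketch of just that sizing decision (`max_bs`, `dp_size`, `num_tokens_per_bs`, and `hidden_size` are illustrative stand-ins for the runner's fields):

```python
import torch

max_bs, dp_size, num_tokens_per_bs, hidden_size = 8, 4, 1, 4096

require_mlp_tp_gather = True  # set False to exercise the attn-TP-gather branch

if require_mlp_tp_gather:
    # MLP input is all-gathered across all attention-DP ranks, so the buffer
    # must hold every rank's tokens and the count tensor has one slot per rank.
    gathered_buffer = torch.zeros(
        (max_bs * dp_size * num_tokens_per_bs, hidden_size)
    )
    global_num_tokens_gpu = torch.zeros((dp_size,), dtype=torch.int32)
else:
    # Only the local rank's tokens are gathered, so a single-rank buffer and a
    # length-1 token count suffice.
    gathered_buffer = torch.zeros((max_bs * num_tokens_per_bs, hidden_size))
    global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)

print(gathered_buffer.shape, global_num_tokens_gpu.shape)
```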
@@ -1621,8 +1621,6 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.dp_size = get_local_attention_dp_size()
-
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
 
@@ -1706,7 +1704,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_local_attention_dp_size()
 
         self._routed_experts_weights_of_layer = LazyValue(
             lambda: {
@@ -387,7 +387,6 @@ class ServerArgs:
             ), "Please enable dp attention when setting enable_dp_lm_head. "
 
         # DeepEP MoE
-        self.enable_sp_layernorm = False
         if self.enable_deepep_moe:
             if self.deepep_mode == "auto":
                 assert (
@@ -397,9 +396,6 @@ class ServerArgs:
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
             self.ep_size = self.tp_size
-            self.enable_sp_layernorm = (
-                self.dp_size < self.tp_size if self.enable_dp_attention else True
-            )
             logger.warning(
                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
@@ -20,6 +20,11 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.speculative.eagle_utils import EagleDraftInput
+from sglang.srt.utils import (
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_tp_gather,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -39,8 +44,9 @@ class EAGLEDraftCudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
-        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
-        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.dp_size = self.model_runner.dp_size
         self.tp_size = self.model_runner.tp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
@@ -88,8 +94,7 @@ class EAGLEDraftCudaGraphRunner:
             dtype=self.model_runner.dtype,
         )
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
+        if self.require_gathered_buffer:
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_num_token,
@@ -97,12 +102,19 @@ class EAGLEDraftCudaGraphRunner:
                 ),
                 dtype=self.model_runner.dtype,
             )
-            self.global_num_tokens_gpu = torch.zeros(
-                (self.dp_size,), dtype=torch.int32
-            )
-            self.global_num_tokens_for_logprob_gpu = torch.zeros(
-                (self.dp_size,), dtype=torch.int32
-            )
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (1,), dtype=torch.int32
+                )
 
         # Capture
         try:
@@ -114,8 +126,7 @@ class EAGLEDraftCudaGraphRunner:
         )
 
     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention:
-            # TODO(ch-wan): check --moe-dense-tp-size and --enable-dp-lm-head
+        if self.require_mlp_tp_gather:
             if not forward_batch.can_run_dp_cuda_graph:
                 return False
             total_batch_size = (
@@ -153,7 +164,7 @@ class EAGLEDraftCudaGraphRunner:
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
@@ -177,6 +188,24 @@ class EAGLEDraftCudaGraphRunner:
             global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         else:
             global_num_tokens = None
             gathered_buffer = None
@@ -259,7 +288,7 @@ class EAGLEDraftCudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             total_batch_size = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
@@ -286,7 +315,7 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_gathered_buffer:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
             self.global_num_tokens_for_logprob_gpu.copy_(
                 forward_batch.global_num_tokens_for_logprob_gpu
@@ -21,6 +21,11 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
+from sglang.srt.utils import (
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_tp_gather,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -35,8 +40,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
-        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
-        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.tp_size = self.model_runner.tp_size
         self.dp_size = model_runner.server_args.dp_size
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -92,7 +98,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
         )
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_gathered_buffer:
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_num_token,
@@ -100,13 +106,19 @@ class EAGLEDraftExtendCudaGraphRunner:
                 ),
                 dtype=self.model_runner.dtype,
             )
-            self.global_num_tokens_gpu = torch.zeros(
-                (self.dp_size,), dtype=torch.int32
-            )
-            self.global_num_tokens_for_logprob_gpu = torch.zeros(
-                (self.dp_size,), dtype=torch.int32
-            )
-
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (1,), dtype=torch.int32
+                )
         # Capture
         try:
             with model_capture_mode():
@@ -117,7 +129,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         )
 
     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             if not forward_batch.can_run_dp_cuda_graph:
                 return False
             total_batch_size = (
@@ -160,7 +172,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         positions = self.positions[:num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
@@ -184,6 +196,24 @@ class EAGLEDraftExtendCudaGraphRunner:
             global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         else:
             global_num_tokens = None
             gathered_buffer = None
@@ -270,7 +300,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         # in the batch, which will not be counted as num_seqs
         raw_bs = forward_batch.batch_size
         num_tokens = forward_batch.input_ids.shape[0]
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             total_batch_size = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
@@ -299,7 +329,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_gathered_buffer:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
             self.global_num_tokens_for_logprob_gpu.copy_(
                 forward_batch.global_num_tokens_for_logprob_gpu
@@ -2303,6 +2303,51 @@ class Withable(Generic[T]):
         self._value = None
 
 
+def require_mlp_tp_gather(server_args):
+    """
+    Check if the input of MLP is obtained by all-gather rather than all-reduce. This only happens when each MLP TP group contains multiple attention DP groups.
+    """
+    if server_args.enable_dp_attention:
+        assert server_args.dp_size > 1, "dp_size must be greater than 1"
+        if (
+            server_args.moe_dense_tp_size is None
+        ):  # TODO(ch-wan): some MoE models do not have dense layers
+            return True
+        elif not server_args.enable_dp_lm_head:
+            return True
+        elif not server_args.enable_deepep_moe:
+            return True
+        else:
+            return (
+                server_args.moe_dense_tp_size
+                > server_args.tp_size // server_args.dp_size
+            )
+    else:
+        return False
+
+
+def require_attn_tp_gather(server_args):
+    """
+    Check if the input of attention is scattered.
+    """
+    assert server_args.moe_dense_tp_size in [1, None]
+    if server_args.enable_deepep_moe or server_args.moe_dense_tp_size == 1:
+        if server_args.enable_dp_attention:
+            return server_args.dp_size < server_args.tp_size
+        else:
+            return True
+    else:
+        return False
+
+
+def require_gathered_buffer(server_args):
+    return require_mlp_tp_gather(server_args) or require_attn_tp_gather(server_args)
+
+
+def require_mlp_sync(server_args):
+    return server_args.enable_dp_attention or require_gathered_buffer(server_args)
+
+
 def merge_bias_tensor(
     lhs: Optional[torch.Tensor],
     rhs: Optional[torch.Tensor],
test/srt/test_hybrid_dp_ep_tp_mtp.py: new file, 3468 lines (diff suppressed because it is too large).
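
The four predicates added to `sglang.srt.utils` compose as follows. A runnable sketch that exercises them, with a `SimpleNamespace` standing in for `ServerArgs` (the field values below are hypothetical; the function bodies are copied from the utils.py hunk above):

```python
from types import SimpleNamespace

# Function bodies as added in the utils.py hunk above.
def require_mlp_tp_gather(server_args):
    if server_args.enable_dp_attention:
        assert server_args.dp_size > 1, "dp_size must be greater than 1"
        if server_args.moe_dense_tp_size is None:
            return True
        elif not server_args.enable_dp_lm_head:
            return True
        elif not server_args.enable_deepep_moe:
            return True
        else:
            return (
                server_args.moe_dense_tp_size
                > server_args.tp_size // server_args.dp_size
            )
    else:
        return False

def require_attn_tp_gather(server_args):
    assert server_args.moe_dense_tp_size in [1, None]
    if server_args.enable_deepep_moe or server_args.moe_dense_tp_size == 1:
        if server_args.enable_dp_attention:
            return server_args.dp_size < server_args.tp_size
        else:
            return True
    else:
        return False

def require_gathered_buffer(server_args):
    return require_mlp_tp_gather(server_args) or require_attn_tp_gather(server_args)

def require_mlp_sync(server_args):
    return server_args.enable_dp_attention or require_gathered_buffer(server_args)

# Hypothetical hybrid config: TP=8 split into 4 attention-DP groups, DeepEP on.
args = SimpleNamespace(
    enable_dp_attention=True,
    dp_size=4,
    tp_size=8,
    moe_dense_tp_size=1,
    enable_dp_lm_head=True,
    enable_deepep_moe=True,
)
print(require_mlp_tp_gather(args))   # False: moe_dense_tp_size <= tp_size // dp_size
print(require_attn_tp_gather(args))  # True: dp_size < tp_size, attention input stays scattered
print(require_mlp_sync(args))        # True: DP attention always needs batch-size sync
```

This mirrors how the CUDA-graph runners in this diff choose between the per-rank (`dp_size`-slot) and single-slot `global_num_tokens` tensors.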