[Feature] Comprehensive Hybrid Parallelism Support (#6389)
This commit is contained in:
@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.srt.utils import require_mlp_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -649,10 +650,7 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
batch = self.get_next_disagg_decode_batch_to_run()
|
||||
self.cur_batch = batch
|
||||
|
||||
prepare_dp_attn_flag = (
|
||||
self.server_args.enable_dp_attention
|
||||
or self.server_args.enable_sp_layernorm
|
||||
)
|
||||
prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
|
||||
|
||||
if batch:
|
||||
# Generate fake extend output.
|
||||
@@ -661,14 +659,14 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.stream_output(
|
||||
batch.reqs, any(req.return_logprob for req in batch.reqs)
|
||||
)
|
||||
if prepare_dp_attn_flag:
|
||||
if prepare_mlp_sync_flag:
|
||||
self._prepare_idle_batch_and_run(None)
|
||||
else:
|
||||
if prepare_dp_attn_flag:
|
||||
self.prepare_dp_attn_batch(batch)
|
||||
if prepare_mlp_sync_flag:
|
||||
self.prepare_mlp_sync_batch(batch)
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
elif prepare_dp_attn_flag:
|
||||
elif prepare_mlp_sync_flag:
|
||||
batch, _ = self._prepare_idle_batch_and_run(None)
|
||||
|
||||
if batch is None and (
|
||||
@@ -699,10 +697,7 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.cur_batch = batch
|
||||
last_batch_in_queue = False
|
||||
|
||||
prepare_dp_attn_flag = (
|
||||
self.server_args.enable_dp_attention
|
||||
or self.server_args.enable_sp_layernorm
|
||||
)
|
||||
prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
|
||||
|
||||
if batch:
|
||||
# Generate fake extend output.
|
||||
@@ -711,7 +706,7 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.stream_output(
|
||||
batch.reqs, any(req.return_logprob for req in batch.reqs)
|
||||
)
|
||||
if prepare_dp_attn_flag:
|
||||
if prepare_mlp_sync_flag:
|
||||
batch_, result = self._prepare_idle_batch_and_run(
|
||||
None, delay_process=True
|
||||
)
|
||||
@@ -719,8 +714,8 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
result_queue.append((batch_.copy(), result))
|
||||
last_batch_in_queue = True
|
||||
else:
|
||||
if prepare_dp_attn_flag:
|
||||
self.prepare_dp_attn_batch(batch)
|
||||
if prepare_mlp_sync_flag:
|
||||
self.prepare_mlp_sync_batch(batch)
|
||||
result = self.run_batch(batch)
|
||||
result_queue.append((batch.copy(), result))
|
||||
|
||||
@@ -735,7 +730,7 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.set_next_batch_sampling_info_done(tmp_batch)
|
||||
last_batch_in_queue = True
|
||||
|
||||
elif prepare_dp_attn_flag:
|
||||
elif prepare_mlp_sync_flag:
|
||||
batch, result = self._prepare_idle_batch_and_run(
|
||||
None, delay_process=True
|
||||
)
|
||||
@@ -765,13 +760,13 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.last_batch = batch
|
||||
self.last_batch_in_queue = last_batch_in_queue
|
||||
|
||||
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
|
||||
batch, _ = self.prepare_dp_attn_batch(batch)
|
||||
def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
|
||||
batch, _ = self.prepare_mlp_sync_batch(batch)
|
||||
result = None
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
if not delay_process:
|
||||
self.process_batch_result(batch, result)
|
||||
self.prepare_mlp_sync_batch(batch, result)
|
||||
return batch, result
|
||||
|
||||
def get_next_disagg_decode_batch_to_run(
|
||||
|
||||
@@ -45,6 +45,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.utils import require_mlp_sync
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed import ProcessGroup
|
||||
@@ -274,12 +275,8 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
self.process_prefill_chunk()
|
||||
batch = self.get_new_batch_prefill()
|
||||
|
||||
# Handle DP attention
|
||||
if (
|
||||
self.server_args.enable_dp_attention
|
||||
or self.server_args.enable_sp_layernorm
|
||||
):
|
||||
batch, _ = self.prepare_dp_attn_batch(batch)
|
||||
if require_mlp_sync(self.server_args):
|
||||
batch, _ = self.prepare_mlp_sync_batch(batch)
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
@@ -312,12 +309,8 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
self.process_prefill_chunk()
|
||||
batch = self.get_new_batch_prefill()
|
||||
|
||||
# Handle DP attention
|
||||
if (
|
||||
self.server_args.enable_dp_attention
|
||||
or self.server_args.enable_sp_layernorm
|
||||
):
|
||||
batch, _ = self.prepare_dp_attn_batch(batch)
|
||||
if require_mlp_sync(self.server_args):
|
||||
batch, _ = self.prepare_mlp_sync_batch(batch)
|
||||
self.cur_batch = batch
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
|
||||
Reference in New Issue
Block a user