[PP] Add pipeline parallelism (#5724)
This commit is contained in:
@@ -154,6 +154,8 @@ def load_model(server_args, port_args, tp_rank):
|
|||||||
gpu_id=tp_rank,
|
gpu_id=tp_rank,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
tp_size=server_args.tp_size,
|
tp_size=server_args.tp_size,
|
||||||
|
pp_rank=0,
|
||||||
|
pp_size=1,
|
||||||
nccl_port=port_args.nccl_port,
|
nccl_port=port_args.nccl_port,
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -126,7 +126,6 @@ class Engine(EngineBase):
|
|||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
port_args=port_args,
|
port_args=port_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
self.tokenizer_manager = tokenizer_manager
|
self.tokenizer_manager = tokenizer_manager
|
||||||
self.scheduler_info = scheduler_info
|
self.scheduler_info = scheduler_info
|
||||||
@@ -301,7 +300,6 @@ class Engine(EngineBase):
|
|||||||
internal_states = loop.run_until_complete(
|
internal_states = loop.run_until_complete(
|
||||||
self.tokenizer_manager.get_internal_state()
|
self.tokenizer_manager.get_internal_state()
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
**dataclasses.asdict(self.tokenizer_manager.server_args),
|
**dataclasses.asdict(self.tokenizer_manager.server_args),
|
||||||
**self.scheduler_info,
|
**self.scheduler_info,
|
||||||
@@ -520,20 +518,39 @@ def _launch_subprocesses(
|
|||||||
)
|
)
|
||||||
|
|
||||||
scheduler_pipe_readers = []
|
scheduler_pipe_readers = []
|
||||||
tp_size_per_node = server_args.tp_size // server_args.nnodes
|
|
||||||
|
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
|
||||||
|
tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
|
||||||
tp_rank_range = range(
|
tp_rank_range = range(
|
||||||
tp_size_per_node * server_args.node_rank,
|
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
|
||||||
tp_size_per_node * (server_args.node_rank + 1),
|
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
|
||||||
|
pp_rank_range = range(
|
||||||
|
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
|
||||||
|
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
for pp_rank in pp_rank_range:
|
||||||
for tp_rank in tp_rank_range:
|
for tp_rank in tp_rank_range:
|
||||||
reader, writer = mp.Pipe(duplex=False)
|
reader, writer = mp.Pipe(duplex=False)
|
||||||
gpu_id = (
|
gpu_id = (
|
||||||
server_args.base_gpu_id
|
server_args.base_gpu_id
|
||||||
|
+ ((pp_rank % pp_size_per_node) * tp_size_per_node)
|
||||||
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
|
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
|
||||||
)
|
)
|
||||||
proc = mp.Process(
|
proc = mp.Process(
|
||||||
target=run_scheduler_process,
|
target=run_scheduler_process,
|
||||||
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
|
args=(
|
||||||
|
server_args,
|
||||||
|
port_args,
|
||||||
|
gpu_id,
|
||||||
|
tp_rank,
|
||||||
|
pp_rank,
|
||||||
|
None,
|
||||||
|
writer,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
with memory_saver_adapter.configure_subprocess():
|
with memory_saver_adapter.configure_subprocess():
|
||||||
proc.start()
|
proc.start()
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ def initialize_dp_attention(
|
|||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
tp_size: int,
|
tp_size: int,
|
||||||
dp_size: int,
|
dp_size: int,
|
||||||
|
pp_size: int,
|
||||||
):
|
):
|
||||||
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
|
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
|
||||||
|
|
||||||
@@ -53,17 +54,19 @@ def initialize_dp_attention(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if enable_dp_attention:
|
if enable_dp_attention:
|
||||||
|
local_rank = tp_rank % (tp_size // dp_size)
|
||||||
_DP_SIZE = dp_size
|
_DP_SIZE = dp_size
|
||||||
else:
|
else:
|
||||||
|
local_rank = tp_rank
|
||||||
_DP_SIZE = 1
|
_DP_SIZE = 1
|
||||||
|
|
||||||
tp_group = get_tp_group()
|
tp_group = get_tp_group()
|
||||||
_ATTN_TP_GROUP = GroupCoordinator(
|
_ATTN_TP_GROUP = GroupCoordinator(
|
||||||
[
|
[
|
||||||
list(range(head, head + _ATTN_TP_SIZE))
|
list(range(head, head + _ATTN_TP_SIZE))
|
||||||
for head in range(0, tp_size, _ATTN_TP_SIZE)
|
for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
|
||||||
],
|
],
|
||||||
tp_group.local_rank,
|
local_rank,
|
||||||
torch.distributed.get_backend(tp_group.device_group),
|
torch.distributed.get_backend(tp_group.device_group),
|
||||||
SYNC_TOKEN_IDS_ACROSS_TP,
|
SYNC_TOKEN_IDS_ACROSS_TP,
|
||||||
False,
|
False,
|
||||||
|
|||||||
35
python/sglang/srt/layers/utils.py
Normal file
35
python/sglang/srt/layers/utils.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_layer_id(weight_name):
|
||||||
|
# example weight name: model.layers.10.self_attn.qkv_proj.weight
|
||||||
|
match = re.search(r"layers\.(\d+)\.", weight_name)
|
||||||
|
if match:
|
||||||
|
return int(match.group(1))
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class PPMissingLayer(torch.nn.Identity):
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
|
||||||
|
"""
|
||||||
|
A placeholder layer for missing layers in a pipeline parallel model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.return_tuple = kwargs.get("return_tuple", False)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Return the first arg from args or the first value from kwargs.
|
||||||
|
|
||||||
|
Wraps the input in a tuple if `self.return_tuple` is True.
|
||||||
|
"""
|
||||||
|
input = args[0] if args else next(iter(kwargs.values()))
|
||||||
|
return (input,) if self.return_tuple else input
|
||||||
@@ -181,13 +181,22 @@ class DataParallelController:
|
|||||||
enable=server_args.enable_memory_saver
|
enable=server_args.enable_memory_saver
|
||||||
)
|
)
|
||||||
|
|
||||||
# Launch tensor parallel scheduler processes
|
|
||||||
scheduler_pipe_readers = []
|
scheduler_pipe_readers = []
|
||||||
tp_size_per_node = server_args.tp_size // server_args.nnodes
|
|
||||||
|
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
|
||||||
|
tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
|
||||||
tp_rank_range = range(
|
tp_rank_range = range(
|
||||||
tp_size_per_node * server_args.node_rank,
|
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
|
||||||
tp_size_per_node * (server_args.node_rank + 1),
|
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
|
||||||
|
pp_rank_range = range(
|
||||||
|
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
|
||||||
|
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
for pp_rank in pp_rank_range:
|
||||||
for tp_rank in tp_rank_range:
|
for tp_rank in tp_rank_range:
|
||||||
rank_port_args = port_args
|
rank_port_args = port_args
|
||||||
|
|
||||||
@@ -209,11 +218,20 @@ class DataParallelController:
|
|||||||
gpu_id = (
|
gpu_id = (
|
||||||
server_args.base_gpu_id
|
server_args.base_gpu_id
|
||||||
+ base_gpu_id
|
+ base_gpu_id
|
||||||
|
+ ((pp_rank % pp_size_per_node) * tp_size_per_node)
|
||||||
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
|
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
|
||||||
)
|
)
|
||||||
proc = mp.Process(
|
proc = mp.Process(
|
||||||
target=run_scheduler_process,
|
target=run_scheduler_process,
|
||||||
args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
|
args=(
|
||||||
|
server_args,
|
||||||
|
rank_port_args,
|
||||||
|
gpu_id,
|
||||||
|
tp_rank,
|
||||||
|
pp_rank,
|
||||||
|
dp_rank,
|
||||||
|
writer,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
with memory_saver_adapter.configure_subprocess():
|
with memory_saver_adapter.configure_subprocess():
|
||||||
proc.start()
|
proc.start()
|
||||||
|
|||||||
@@ -66,23 +66,24 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
|||||||
# Put some global args for easy access
|
# Put some global args for easy access
|
||||||
global_server_args_dict = {
|
global_server_args_dict = {
|
||||||
"attention_backend": ServerArgs.attention_backend,
|
"attention_backend": ServerArgs.attention_backend,
|
||||||
"sampling_backend": ServerArgs.sampling_backend,
|
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
|
||||||
"triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
|
|
||||||
"torchao_config": ServerArgs.torchao_config,
|
|
||||||
"enable_nan_detection": ServerArgs.enable_nan_detection,
|
|
||||||
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
|
||||||
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
|
||||||
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
|
|
||||||
"deepep_mode": ServerArgs.deepep_mode,
|
"deepep_mode": ServerArgs.deepep_mode,
|
||||||
"device": ServerArgs.device,
|
"device": ServerArgs.device,
|
||||||
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
|
|
||||||
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
|
|
||||||
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
|
||||||
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
|
||||||
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
|
|
||||||
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
|
|
||||||
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
|
|
||||||
"disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
|
"disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
|
||||||
|
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
||||||
|
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
|
||||||
|
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
||||||
|
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
||||||
|
"enable_nan_detection": ServerArgs.enable_nan_detection,
|
||||||
|
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
||||||
|
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
|
||||||
|
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
|
||||||
|
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
|
||||||
|
"sampling_backend": ServerArgs.sampling_backend,
|
||||||
|
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
|
||||||
|
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
|
||||||
|
"torchao_config": ServerArgs.torchao_config,
|
||||||
|
"triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -728,6 +729,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
# Events
|
# Events
|
||||||
launch_done: Optional[threading.Event] = None
|
launch_done: Optional[threading.Event] = None
|
||||||
|
|
||||||
|
# For chunked prefill in PP
|
||||||
|
chunked_req: Optional[Req] = None
|
||||||
|
|
||||||
# Sampling info
|
# Sampling info
|
||||||
sampling_info: SamplingBatchInfo = None
|
sampling_info: SamplingBatchInfo = None
|
||||||
next_batch_sampling_info: SamplingBatchInfo = None
|
next_batch_sampling_info: SamplingBatchInfo = None
|
||||||
@@ -761,7 +765,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
# For extend and mixed chunekd prefill
|
# For extend and mixed chunekd prefill
|
||||||
prefix_lens: List[int] = None
|
prefix_lens: List[int] = None
|
||||||
extend_lens: List[int] = None
|
extend_lens: List[int] = None
|
||||||
extend_num_tokens: int = None
|
extend_num_tokens: Optional[int] = None
|
||||||
decoding_reqs: List[Req] = None
|
decoding_reqs: List[Req] = None
|
||||||
extend_logprob_start_lens: List[int] = None
|
extend_logprob_start_lens: List[int] = None
|
||||||
# It comes empty list if logprob is not required.
|
# It comes empty list if logprob is not required.
|
||||||
@@ -803,6 +807,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
enable_overlap: bool,
|
enable_overlap: bool,
|
||||||
spec_algorithm: SpeculativeAlgorithm,
|
spec_algorithm: SpeculativeAlgorithm,
|
||||||
enable_custom_logit_processor: bool,
|
enable_custom_logit_processor: bool,
|
||||||
|
chunked_req: Optional[Req] = None,
|
||||||
):
|
):
|
||||||
return_logprob = any(req.return_logprob for req in reqs)
|
return_logprob = any(req.return_logprob for req in reqs)
|
||||||
|
|
||||||
@@ -820,6 +825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
spec_algorithm=spec_algorithm,
|
spec_algorithm=spec_algorithm,
|
||||||
enable_custom_logit_processor=enable_custom_logit_processor,
|
enable_custom_logit_processor=enable_custom_logit_processor,
|
||||||
return_hidden_states=any(req.return_hidden_states for req in reqs),
|
return_hidden_states=any(req.return_hidden_states for req in reqs),
|
||||||
|
chunked_req=chunked_req,
|
||||||
)
|
)
|
||||||
|
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
@@ -1236,7 +1242,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
|
|
||||||
def retract_decode(self, server_args: ServerArgs):
|
def retract_decode(self, server_args: ServerArgs):
|
||||||
"""Retract the decoding requests when there is not enough memory."""
|
"""Retract the decoding requests when there is not enough memory."""
|
||||||
sorted_indices = [i for i in range(len(self.reqs))]
|
sorted_indices = list(range(len(self.reqs)))
|
||||||
|
|
||||||
# TODO(lsyin): improve retraction policy for radix cache
|
# TODO(lsyin): improve retraction policy for radix cache
|
||||||
# For spec decoding, filter_batch API can only filter
|
# For spec decoding, filter_batch API can only filter
|
||||||
@@ -1413,15 +1419,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
|
|
||||||
def filter_batch(
|
def filter_batch(
|
||||||
self,
|
self,
|
||||||
chunked_req_to_exclude: Optional[Req] = None,
|
chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
|
||||||
keep_indices: Optional[List[int]] = None,
|
keep_indices: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
if keep_indices is None:
|
if keep_indices is None:
|
||||||
|
if isinstance(chunked_req_to_exclude, Req):
|
||||||
|
chunked_req_to_exclude = [chunked_req_to_exclude]
|
||||||
|
elif chunked_req_to_exclude is None:
|
||||||
|
chunked_req_to_exclude = []
|
||||||
keep_indices = [
|
keep_indices = [
|
||||||
i
|
i
|
||||||
for i in range(len(self.reqs))
|
for i in range(len(self.reqs))
|
||||||
if not self.reqs[i].finished()
|
if not self.reqs[i].finished()
|
||||||
and self.reqs[i] is not chunked_req_to_exclude
|
and not self.reqs[i] in chunked_req_to_exclude
|
||||||
]
|
]
|
||||||
|
|
||||||
if keep_indices is None or len(keep_indices) == 0:
|
if keep_indices is None or len(keep_indices) == 0:
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
ReqToMetadataIdxAllocator,
|
ReqToMetadataIdxAllocator,
|
||||||
TransferBackend,
|
TransferBackend,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.distributed import get_pp_group, get_world_group
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
@@ -114,7 +115,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
|||||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
PPProxyTensors,
|
||||||
|
)
|
||||||
from sglang.srt.reasoning_parser import ReasoningParser
|
from sglang.srt.reasoning_parser import ReasoningParser
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
@@ -127,6 +132,7 @@ from sglang.srt.utils import (
|
|||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
kill_itself_when_parent_died,
|
kill_itself_when_parent_died,
|
||||||
|
point_to_point_pyobj,
|
||||||
pyspy_dump_schedulers,
|
pyspy_dump_schedulers,
|
||||||
set_gpu_proc_affinity,
|
set_gpu_proc_affinity,
|
||||||
set_random_seed,
|
set_random_seed,
|
||||||
@@ -145,8 +151,9 @@ RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenerationBatchResult:
|
class GenerationBatchResult:
|
||||||
logits_output: LogitsProcessorOutput
|
logits_output: Optional[LogitsProcessorOutput]
|
||||||
next_token_ids: List[int]
|
pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
|
||||||
|
next_token_ids: Optional[List[int]]
|
||||||
extend_input_len_per_req: List[int]
|
extend_input_len_per_req: List[int]
|
||||||
extend_logprob_start_len_per_req: List[int]
|
extend_logprob_start_len_per_req: List[int]
|
||||||
bid: int
|
bid: int
|
||||||
@@ -171,12 +178,16 @@ class Scheduler(
|
|||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
gpu_id: int,
|
gpu_id: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
|
pp_rank: int,
|
||||||
dp_rank: Optional[int],
|
dp_rank: Optional[int],
|
||||||
):
|
):
|
||||||
# Parse args
|
# Parse args
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
|
self.pp_rank = pp_rank
|
||||||
self.tp_size = server_args.tp_size
|
self.tp_size = server_args.tp_size
|
||||||
|
self.pp_size = server_args.pp_size
|
||||||
|
self.dp_size = server_args.dp_size
|
||||||
self.schedule_policy = server_args.schedule_policy
|
self.schedule_policy = server_args.schedule_policy
|
||||||
self.lora_paths = server_args.lora_paths
|
self.lora_paths = server_args.lora_paths
|
||||||
self.max_loras_per_batch = server_args.max_loras_per_batch
|
self.max_loras_per_batch = server_args.max_loras_per_batch
|
||||||
@@ -192,7 +203,6 @@ class Scheduler(
|
|||||||
self.page_size = server_args.page_size
|
self.page_size = server_args.page_size
|
||||||
|
|
||||||
# Distributed rank info
|
# Distributed rank info
|
||||||
self.dp_size = server_args.dp_size
|
|
||||||
self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
|
self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
|
||||||
compute_dp_attention_world_info(
|
compute_dp_attention_world_info(
|
||||||
server_args.enable_dp_attention,
|
server_args.enable_dp_attention,
|
||||||
@@ -204,7 +214,7 @@ class Scheduler(
|
|||||||
|
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.Context(2)
|
context = zmq.Context(2)
|
||||||
if self.attn_tp_rank == 0:
|
if self.pp_rank == 0 and self.attn_tp_rank == 0:
|
||||||
self.recv_from_tokenizer = get_zmq_socket(
|
self.recv_from_tokenizer = get_zmq_socket(
|
||||||
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
|
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
|
||||||
)
|
)
|
||||||
@@ -259,6 +269,7 @@ class Scheduler(
|
|||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
|
pp_rank=pp_rank,
|
||||||
dp_rank=dp_rank,
|
dp_rank=dp_rank,
|
||||||
nccl_port=port_args.nccl_port,
|
nccl_port=port_args.nccl_port,
|
||||||
)
|
)
|
||||||
@@ -292,8 +303,18 @@ class Scheduler(
|
|||||||
_,
|
_,
|
||||||
_,
|
_,
|
||||||
) = self.tp_worker.get_worker_info()
|
) = self.tp_worker.get_worker_info()
|
||||||
self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
|
if global_server_args_dict["max_micro_batch_size"] is None:
|
||||||
|
global_server_args_dict["max_micro_batch_size"] = max(
|
||||||
|
self.max_running_requests // server_args.pp_size, 1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.tp_group = self.tp_worker.get_tp_group()
|
||||||
|
self.tp_cpu_group = self.tp_group.cpu_group
|
||||||
|
self.attn_tp_group = self.tp_worker.get_attention_tp_group()
|
||||||
self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
|
self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
|
||||||
|
self.pp_group = get_pp_group()
|
||||||
|
self.world_group = get_world_group()
|
||||||
|
|
||||||
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
|
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
|
||||||
global_server_args_dict.update(worker_global_server_args_dict)
|
global_server_args_dict.update(worker_global_server_args_dict)
|
||||||
set_random_seed(self.random_seed)
|
set_random_seed(self.random_seed)
|
||||||
@@ -673,8 +694,111 @@ class Scheduler(
|
|||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
|
@DynamicGradMode()
|
||||||
|
def event_loop_pp(self):
|
||||||
|
"""A non-overlap scheduler loop for pipeline parallelism."""
|
||||||
|
mbs = [None] * self.pp_size
|
||||||
|
last_mbs = [None] * self.pp_size
|
||||||
|
self.running_mbs = [
|
||||||
|
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
|
||||||
|
]
|
||||||
|
bids = [None] * self.pp_size
|
||||||
|
pp_outputs: Optional[PPProxyTensors] = None
|
||||||
|
while True:
|
||||||
|
server_is_idle = True
|
||||||
|
for mb_id in range(self.pp_size):
|
||||||
|
self.running_batch = self.running_mbs[mb_id]
|
||||||
|
self.last_batch = last_mbs[mb_id]
|
||||||
|
|
||||||
|
recv_reqs = self.recv_requests()
|
||||||
|
self.process_input_requests(recv_reqs)
|
||||||
|
mbs[mb_id] = self.get_next_batch_to_run()
|
||||||
|
self.running_mbs[mb_id] = self.running_batch
|
||||||
|
|
||||||
|
self.cur_batch = mbs[mb_id]
|
||||||
|
if self.cur_batch:
|
||||||
|
server_is_idle = False
|
||||||
|
result = self.run_batch(self.cur_batch)
|
||||||
|
|
||||||
|
# send the outputs to the next step
|
||||||
|
if self.pp_group.is_last_rank:
|
||||||
|
if self.cur_batch:
|
||||||
|
next_token_ids, bids[mb_id] = (
|
||||||
|
result.next_token_ids,
|
||||||
|
result.bid,
|
||||||
|
)
|
||||||
|
pp_outputs = PPProxyTensors(
|
||||||
|
{
|
||||||
|
"next_token_ids": next_token_ids,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# send the output from the last round to let the next stage worker run post processing
|
||||||
|
self.pp_group.send_tensor_dict(
|
||||||
|
pp_outputs.tensors,
|
||||||
|
all_gather_group=self.attn_tp_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
# receive outputs and post-process (filter finished reqs) the coming microbatch
|
||||||
|
next_mb_id = (mb_id + 1) % self.pp_size
|
||||||
|
next_pp_outputs = None
|
||||||
|
if mbs[next_mb_id] is not None:
|
||||||
|
next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
|
||||||
|
self.pp_group.recv_tensor_dict(
|
||||||
|
all_gather_group=self.attn_tp_group
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
|
||||||
|
output_result = GenerationBatchResult(
|
||||||
|
logits_output=None,
|
||||||
|
pp_hidden_states_proxy_tensors=None,
|
||||||
|
next_token_ids=next_pp_outputs["next_token_ids"],
|
||||||
|
extend_input_len_per_req=None,
|
||||||
|
extend_logprob_start_len_per_req=None,
|
||||||
|
bid=bids[next_mb_id],
|
||||||
|
)
|
||||||
|
self.process_batch_result(mbs[next_mb_id], output_result)
|
||||||
|
last_mbs[next_mb_id] = mbs[next_mb_id]
|
||||||
|
|
||||||
|
# carry the outputs to the next stage
|
||||||
|
if not self.pp_group.is_last_rank:
|
||||||
|
if self.cur_batch:
|
||||||
|
bids[mb_id] = result.bid
|
||||||
|
if pp_outputs:
|
||||||
|
# send the outputs from the last round to let the next stage worker run post processing
|
||||||
|
self.pp_group.send_tensor_dict(
|
||||||
|
pp_outputs.tensors,
|
||||||
|
all_gather_group=self.attn_tp_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.pp_group.is_last_rank:
|
||||||
|
# send out reqs to the next stage
|
||||||
|
dp_offset = self.dp_rank * self.attn_tp_size
|
||||||
|
if self.attn_tp_rank == 0:
|
||||||
|
point_to_point_pyobj(
|
||||||
|
recv_reqs,
|
||||||
|
self.pp_rank * self.tp_size + dp_offset,
|
||||||
|
self.world_group.cpu_group,
|
||||||
|
self.pp_rank * self.tp_size + dp_offset,
|
||||||
|
(self.pp_rank + 1) * self.tp_size + dp_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# send out proxy tensors to the next stage
|
||||||
|
if self.cur_batch:
|
||||||
|
self.pp_group.send_tensor_dict(
|
||||||
|
result.pp_hidden_states_proxy_tensors,
|
||||||
|
all_gather_group=self.attn_tp_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
pp_outputs = next_pp_outputs
|
||||||
|
|
||||||
|
# When the server is idle, self-check and re-init some states
|
||||||
|
if server_is_idle:
|
||||||
|
self.check_memory()
|
||||||
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
def recv_requests(self) -> List[Req]:
|
def recv_requests(self) -> List[Req]:
|
||||||
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
||||||
|
if self.pp_rank == 0:
|
||||||
if self.attn_tp_rank == 0:
|
if self.attn_tp_rank == 0:
|
||||||
recv_reqs = []
|
recv_reqs = []
|
||||||
|
|
||||||
@@ -693,6 +817,18 @@ class Scheduler(
|
|||||||
recv_reqs.append(recv_rpc)
|
recv_reqs.append(recv_rpc)
|
||||||
else:
|
else:
|
||||||
recv_reqs = None
|
recv_reqs = None
|
||||||
|
else:
|
||||||
|
if self.attn_tp_rank == 0:
|
||||||
|
dp_offset = self.dp_rank * self.attn_tp_size
|
||||||
|
recv_reqs = point_to_point_pyobj(
|
||||||
|
[],
|
||||||
|
self.pp_rank * self.tp_size + dp_offset,
|
||||||
|
self.world_group.cpu_group,
|
||||||
|
(self.pp_rank - 1) * self.tp_size + dp_offset,
|
||||||
|
self.pp_rank * self.tp_size + dp_offset,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
recv_reqs = None
|
||||||
|
|
||||||
if self.server_args.enable_dp_attention:
|
if self.server_args.enable_dp_attention:
|
||||||
if self.attn_tp_rank == 0:
|
if self.attn_tp_rank == 0:
|
||||||
@@ -715,20 +851,27 @@ class Scheduler(
|
|||||||
control_reqs = None
|
control_reqs = None
|
||||||
|
|
||||||
if self.attn_tp_size != 1:
|
if self.attn_tp_size != 1:
|
||||||
attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
|
|
||||||
work_reqs = broadcast_pyobj(
|
work_reqs = broadcast_pyobj(
|
||||||
work_reqs,
|
work_reqs,
|
||||||
self.attn_tp_rank,
|
self.attn_tp_group.rank,
|
||||||
self.attn_tp_cpu_group,
|
self.attn_tp_cpu_group,
|
||||||
src=attn_tp_rank_0,
|
src=self.attn_tp_group.ranks[0],
|
||||||
)
|
)
|
||||||
if self.tp_size != 1:
|
if self.tp_size != 1:
|
||||||
control_reqs = broadcast_pyobj(
|
control_reqs = broadcast_pyobj(
|
||||||
control_reqs, self.tp_rank, self.tp_cpu_group
|
control_reqs,
|
||||||
|
self.tp_group.rank,
|
||||||
|
self.tp_cpu_group,
|
||||||
|
src=self.tp_group.ranks[0],
|
||||||
)
|
)
|
||||||
recv_reqs = work_reqs + control_reqs
|
recv_reqs = work_reqs + control_reqs
|
||||||
elif self.tp_size != 1:
|
elif self.tp_size != 1:
|
||||||
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
|
recv_reqs = broadcast_pyobj(
|
||||||
|
recv_reqs,
|
||||||
|
self.tp_group.rank,
|
||||||
|
self.tp_cpu_group,
|
||||||
|
src=self.tp_group.ranks[0],
|
||||||
|
)
|
||||||
return recv_reqs
|
return recv_reqs
|
||||||
|
|
||||||
def process_input_requests(self, recv_reqs: List):
|
def process_input_requests(self, recv_reqs: List):
|
||||||
@@ -1026,12 +1169,14 @@ class Scheduler(
|
|||||||
|
|
||||||
self.metrics_collector.log_stats(self.stats)
|
self.metrics_collector.log_stats(self.stats)
|
||||||
|
|
||||||
def log_decode_stats(self):
|
def log_decode_stats(self, running_batch=None):
|
||||||
|
batch = running_batch or self.running_batch
|
||||||
|
|
||||||
gap_latency = time.time() - self.last_decode_stats_tic
|
gap_latency = time.time() - self.last_decode_stats_tic
|
||||||
self.last_decode_stats_tic = time.time()
|
self.last_decode_stats_tic = time.time()
|
||||||
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
||||||
self.num_generated_tokens = 0
|
self.num_generated_tokens = 0
|
||||||
num_running_reqs = len(self.running_batch.reqs)
|
num_running_reqs = len(batch.reqs)
|
||||||
num_used = self.max_total_num_tokens - (
|
num_used = self.max_total_num_tokens - (
|
||||||
self.token_to_kv_pool_allocator.available_size()
|
self.token_to_kv_pool_allocator.available_size()
|
||||||
+ self.tree_cache.evictable_size()
|
+ self.tree_cache.evictable_size()
|
||||||
@@ -1131,19 +1276,25 @@ class Scheduler(
|
|||||||
|
|
||||||
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
|
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
|
||||||
# Merge the prefill batch into the running batch
|
# Merge the prefill batch into the running batch
|
||||||
if self.last_batch and self.last_batch.forward_mode.is_extend():
|
chunked_req_to_exclude = set()
|
||||||
if self.chunked_req:
|
if self.chunked_req:
|
||||||
# Move the chunked request out of the batch so that we can merge
|
# Move the chunked request out of the batch so that we can merge
|
||||||
# only finished requests to running_batch.
|
# only finished requests to running_batch.
|
||||||
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
|
chunked_req_to_exclude.add(self.chunked_req)
|
||||||
self.tree_cache.cache_unfinished_req(self.chunked_req)
|
self.tree_cache.cache_unfinished_req(self.chunked_req)
|
||||||
# chunked request keeps its rid but will get a new req_pool_idx
|
# chunked request keeps its rid but will get a new req_pool_idx
|
||||||
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
||||||
self.running_batch.batch_is_full = False
|
if self.last_batch and self.last_batch.forward_mode.is_extend():
|
||||||
|
if self.last_batch.chunked_req is not None:
|
||||||
|
# In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
|
||||||
|
# We need to discard it.
|
||||||
|
chunked_req_to_exclude.add(self.last_batch.chunked_req)
|
||||||
|
|
||||||
# Filter batch
|
# Filter batch
|
||||||
last_bs = self.last_batch.batch_size()
|
last_bs = self.last_batch.batch_size()
|
||||||
self.last_batch.filter_batch()
|
self.last_batch.filter_batch(
|
||||||
|
chunked_req_to_exclude=list(chunked_req_to_exclude)
|
||||||
|
)
|
||||||
if self.last_batch.batch_size() < last_bs:
|
if self.last_batch.batch_size() < last_bs:
|
||||||
self.running_batch.batch_is_full = False
|
self.running_batch.batch_is_full = False
|
||||||
|
|
||||||
@@ -1173,6 +1324,12 @@ class Scheduler(
|
|||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
def get_num_allocatable_reqs(self, running_bs):
|
||||||
|
res = global_server_args_dict["max_micro_batch_size"] - running_bs
|
||||||
|
if self.pp_size > 1:
|
||||||
|
res = min(res, self.req_to_token_pool.available_size())
|
||||||
|
return res
|
||||||
|
|
||||||
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
||||||
# Check if the grammar is ready in the grammar queue
|
# Check if the grammar is ready in the grammar queue
|
||||||
if self.grammar_queue:
|
if self.grammar_queue:
|
||||||
@@ -1185,7 +1342,12 @@ class Scheduler(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
running_bs = len(self.running_batch.reqs)
|
running_bs = len(self.running_batch.reqs)
|
||||||
if running_bs >= self.max_running_requests:
|
# Igore the check if self.chunked_req is not None.
|
||||||
|
# In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
|
||||||
|
# as the space for the chunked request has just been released.
|
||||||
|
# In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
|
||||||
|
# Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
|
||||||
|
if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
|
||||||
self.running_batch.batch_is_full = True
|
self.running_batch.batch_is_full = True
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -1229,7 +1391,7 @@ class Scheduler(
|
|||||||
self.running_batch.batch_is_full = True
|
self.running_batch.batch_is_full = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if running_bs + len(adder.can_run_list) >= self.max_running_requests:
|
if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
|
||||||
self.running_batch.batch_is_full = True
|
self.running_batch.batch_is_full = True
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -1241,16 +1403,14 @@ class Scheduler(
|
|||||||
res = adder.add_one_req(
|
res = adder.add_one_req(
|
||||||
req, self.chunked_req, self.enable_hierarchical_cache
|
req, self.chunked_req, self.enable_hierarchical_cache
|
||||||
)
|
)
|
||||||
|
|
||||||
if res != AddReqResult.CONTINUE:
|
if res != AddReqResult.CONTINUE:
|
||||||
if res == AddReqResult.NO_TOKEN:
|
if res == AddReqResult.NO_TOKEN:
|
||||||
if self.enable_hierarchical_cache:
|
if self.enable_hierarchical_cache:
|
||||||
# Set batch_is_full after making sure there are requests that can be served
|
# Set batch_is_full after making sure there are requests that can be served
|
||||||
self.running_batch.batch_is_full = len(
|
self.running_batch.batch_is_full = len(
|
||||||
adder.can_run_list
|
adder.can_run_list
|
||||||
) > 0 or (
|
) > 0 or (not self.running_batch.is_empty())
|
||||||
self.running_batch is not None
|
|
||||||
and not self.running_batch.is_empty()
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.running_batch.batch_is_full = True
|
self.running_batch.batch_is_full = True
|
||||||
break
|
break
|
||||||
@@ -1293,6 +1453,7 @@ class Scheduler(
|
|||||||
self.enable_overlap,
|
self.enable_overlap,
|
||||||
self.spec_algorithm,
|
self.spec_algorithm,
|
||||||
self.server_args.enable_custom_logit_processor,
|
self.server_args.enable_custom_logit_processor,
|
||||||
|
chunked_req=self.chunked_req,
|
||||||
)
|
)
|
||||||
new_batch.prepare_for_extend()
|
new_batch.prepare_for_extend()
|
||||||
|
|
||||||
@@ -1370,8 +1531,13 @@ class Scheduler(
|
|||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
if self.spec_algorithm.is_none():
|
if self.spec_algorithm.is_none():
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
if self.pp_group.is_last_rank:
|
||||||
model_worker_batch
|
logits_output, next_token_ids = (
|
||||||
|
self.tp_worker.forward_batch_generation(model_worker_batch)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pp_hidden_states_proxy_tensors, _ = (
|
||||||
|
self.tp_worker.forward_batch_generation(model_worker_batch)
|
||||||
)
|
)
|
||||||
bid = model_worker_batch.bid
|
bid = model_worker_batch.bid
|
||||||
else:
|
else:
|
||||||
@@ -1386,6 +1552,8 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
self.spec_num_total_forward_ct += batch.batch_size()
|
self.spec_num_total_forward_ct += batch.batch_size()
|
||||||
self.num_generated_tokens += num_accepted_tokens
|
self.num_generated_tokens += num_accepted_tokens
|
||||||
|
|
||||||
|
if self.pp_group.is_last_rank:
|
||||||
batch.output_ids = next_token_ids
|
batch.output_ids = next_token_ids
|
||||||
|
|
||||||
# These 2 values are needed for processing the output, but the values can be
|
# These 2 values are needed for processing the output, but the values can be
|
||||||
@@ -1401,8 +1569,13 @@ class Scheduler(
|
|||||||
extend_logprob_start_len_per_req = None
|
extend_logprob_start_len_per_req = None
|
||||||
|
|
||||||
ret = GenerationBatchResult(
|
ret = GenerationBatchResult(
|
||||||
logits_output=logits_output,
|
logits_output=logits_output if self.pp_group.is_last_rank else None,
|
||||||
next_token_ids=next_token_ids,
|
pp_hidden_states_proxy_tensors=(
|
||||||
|
pp_hidden_states_proxy_tensors
|
||||||
|
if not self.pp_group.is_last_rank
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
|
||||||
extend_input_len_per_req=extend_input_len_per_req,
|
extend_input_len_per_req=extend_input_len_per_req,
|
||||||
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
||||||
bid=bid,
|
bid=bid,
|
||||||
@@ -1553,6 +1726,7 @@ class Scheduler(
|
|||||||
|
|
||||||
def move_ready_grammar_requests(self):
|
def move_ready_grammar_requests(self):
|
||||||
"""Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
|
"""Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
|
||||||
|
|
||||||
num_ready_reqs = 0
|
num_ready_reqs = 0
|
||||||
for req in self.grammar_queue:
|
for req in self.grammar_queue:
|
||||||
try:
|
try:
|
||||||
@@ -1619,7 +1793,11 @@ class Scheduler(
|
|||||||
|
|
||||||
def flush_cache(self):
|
def flush_cache(self):
|
||||||
"""Flush the memory pool and cache."""
|
"""Flush the memory pool and cache."""
|
||||||
if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
|
if (
|
||||||
|
len(self.waiting_queue) == 0
|
||||||
|
and self.running_batch.is_empty()
|
||||||
|
and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
|
||||||
|
):
|
||||||
self.cur_batch = None
|
self.cur_batch = None
|
||||||
self.last_batch = None
|
self.last_batch = None
|
||||||
self.tree_cache.reset()
|
self.tree_cache.reset()
|
||||||
@@ -1657,7 +1835,6 @@ class Scheduler(
|
|||||||
ret["avg_spec_accept_length"] = (
|
ret["avg_spec_accept_length"] = (
|
||||||
self.cum_spec_accept_length / self.cum_spec_accept_count
|
self.cum_spec_accept_length / self.cum_spec_accept_count
|
||||||
)
|
)
|
||||||
|
|
||||||
if RECORD_STEP_TIME:
|
if RECORD_STEP_TIME:
|
||||||
ret["step_time_dict"] = self.step_time_dict
|
ret["step_time_dict"] = self.step_time_dict
|
||||||
return GetInternalStateReqOutput(
|
return GetInternalStateReqOutput(
|
||||||
@@ -1668,6 +1845,7 @@ class Scheduler(
|
|||||||
server_args_dict = recv_req.server_args
|
server_args_dict = recv_req.server_args
|
||||||
args_allow_update = set(
|
args_allow_update = set(
|
||||||
[
|
[
|
||||||
|
"max_micro_batch_size",
|
||||||
"speculative_accept_threshold_single",
|
"speculative_accept_threshold_single",
|
||||||
"speculative_accept_threshold_acc",
|
"speculative_accept_threshold_acc",
|
||||||
]
|
]
|
||||||
@@ -1678,6 +1856,14 @@ class Scheduler(
|
|||||||
logging.warning(f"Updating {k} is not supported.")
|
logging.warning(f"Updating {k} is not supported.")
|
||||||
if_success = False
|
if_success = False
|
||||||
break
|
break
|
||||||
|
elif k == "max_micro_batch_size" and (
|
||||||
|
v > self.max_running_requests // self.pp_size or v < 1
|
||||||
|
):
|
||||||
|
logging.warning(
|
||||||
|
f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
|
||||||
|
)
|
||||||
|
if_success = False
|
||||||
|
break
|
||||||
if if_success:
|
if if_success:
|
||||||
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
|
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
|
||||||
avg_spec_accept_length = (
|
avg_spec_accept_length = (
|
||||||
@@ -1959,6 +2145,16 @@ class Scheduler(
|
|||||||
else:
|
else:
|
||||||
del self.sessions[session_id]
|
del self.sessions[session_id]
|
||||||
|
|
||||||
|
def get_print_prefix(self):
|
||||||
|
prefix = ""
|
||||||
|
if self.dp_rank is not None:
|
||||||
|
prefix += f" DP{self.dp_rank}"
|
||||||
|
if self.server_args.tp_size > 1:
|
||||||
|
prefix += f" TP{self.tp_rank}"
|
||||||
|
if self.pp_size > 1:
|
||||||
|
prefix += f" PP{self.pp_rank}"
|
||||||
|
return prefix
|
||||||
|
|
||||||
|
|
||||||
def is_health_check_generate_req(recv_req):
|
def is_health_check_generate_req(recv_req):
|
||||||
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
|
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
|
||||||
@@ -1983,14 +2179,18 @@ def run_scheduler_process(
|
|||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
gpu_id: int,
|
gpu_id: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
|
pp_rank: int,
|
||||||
dp_rank: Optional[int],
|
dp_rank: Optional[int],
|
||||||
pipe_writer,
|
pipe_writer,
|
||||||
):
|
):
|
||||||
# Generate the prefix
|
# Generate the prefix
|
||||||
if dp_rank is None:
|
prefix = ""
|
||||||
prefix = f" TP{tp_rank}"
|
if dp_rank is not None:
|
||||||
else:
|
prefix += f" DP{dp_rank}"
|
||||||
prefix = f" DP{dp_rank} TP{tp_rank}"
|
if server_args.tp_size > 1:
|
||||||
|
prefix += f" TP{tp_rank}"
|
||||||
|
if server_args.pp_size > 1:
|
||||||
|
prefix += f" PP{pp_rank}"
|
||||||
|
|
||||||
# Config the process
|
# Config the process
|
||||||
kill_itself_when_parent_died()
|
kill_itself_when_parent_died()
|
||||||
@@ -2012,7 +2212,7 @@ def run_scheduler_process(
|
|||||||
|
|
||||||
# Create a scheduler and run the event loop
|
# Create a scheduler and run the event loop
|
||||||
try:
|
try:
|
||||||
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
|
||||||
pipe_writer.send(
|
pipe_writer.send(
|
||||||
{
|
{
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
@@ -2023,7 +2223,9 @@ def run_scheduler_process(
|
|||||||
disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
|
disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
|
||||||
|
|
||||||
if disaggregation_mode == DisaggregationMode.NULL:
|
if disaggregation_mode == DisaggregationMode.NULL:
|
||||||
if scheduler.enable_overlap:
|
if server_args.pp_size > 1:
|
||||||
|
scheduler.event_loop_pp()
|
||||||
|
elif scheduler.enable_overlap:
|
||||||
scheduler.event_loop_overlap()
|
scheduler.event_loop_overlap()
|
||||||
else:
|
else:
|
||||||
scheduler.event_loop_normal()
|
scheduler.event_loop_normal()
|
||||||
@@ -2032,6 +2234,7 @@ def run_scheduler_process(
|
|||||||
scheduler.event_loop_overlap_disagg_prefill()
|
scheduler.event_loop_overlap_disagg_prefill()
|
||||||
else:
|
else:
|
||||||
scheduler.event_loop_normal_disagg_prefill()
|
scheduler.event_loop_normal_disagg_prefill()
|
||||||
|
|
||||||
elif disaggregation_mode == DisaggregationMode.DECODE:
|
elif disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
if scheduler.enable_overlap:
|
if scheduler.enable_overlap:
|
||||||
scheduler.event_loop_overlap_disagg_decode()
|
scheduler.event_loop_overlap_disagg_decode()
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
self.attn_tp_rank == 0
|
self.attn_tp_rank == 0
|
||||||
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
|
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
|
||||||
):
|
):
|
||||||
self.log_decode_stats()
|
self.log_decode_stats(running_batch=batch)
|
||||||
|
|
||||||
def add_input_logprob_return_values(
|
def add_input_logprob_return_values(
|
||||||
self: Scheduler,
|
self: Scheduler,
|
||||||
|
|||||||
@@ -15,11 +15,12 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
@@ -31,7 +32,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
|
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
|
||||||
@@ -47,6 +48,7 @@ class TpModelWorker:
|
|||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
gpu_id: int,
|
gpu_id: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
|
pp_rank: int,
|
||||||
dp_rank: Optional[int],
|
dp_rank: Optional[int],
|
||||||
nccl_port: int,
|
nccl_port: int,
|
||||||
is_draft_worker: bool = False,
|
is_draft_worker: bool = False,
|
||||||
@@ -54,7 +56,9 @@ class TpModelWorker:
|
|||||||
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
|
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
|
||||||
):
|
):
|
||||||
# Parse args
|
# Parse args
|
||||||
|
self.tp_size = server_args.tp_size
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
|
self.pp_rank = pp_rank
|
||||||
|
|
||||||
# Init model and tokenizer
|
# Init model and tokenizer
|
||||||
self.model_config = ModelConfig(
|
self.model_config = ModelConfig(
|
||||||
@@ -73,12 +77,15 @@ class TpModelWorker:
|
|||||||
quantization=server_args.quantization,
|
quantization=server_args.quantization,
|
||||||
is_draft_model=is_draft_worker,
|
is_draft_model=is_draft_worker,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_runner = ModelRunner(
|
self.model_runner = ModelRunner(
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
mem_fraction_static=server_args.mem_fraction_static,
|
mem_fraction_static=server_args.mem_fraction_static,
|
||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
tp_size=server_args.tp_size,
|
tp_size=server_args.tp_size,
|
||||||
|
pp_rank=pp_rank,
|
||||||
|
pp_size=server_args.pp_size,
|
||||||
nccl_port=nccl_port,
|
nccl_port=nccl_port,
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
is_draft_worker=is_draft_worker,
|
is_draft_worker=is_draft_worker,
|
||||||
@@ -105,6 +112,10 @@ class TpModelWorker:
|
|||||||
)
|
)
|
||||||
self.device = self.model_runner.device
|
self.device = self.model_runner.device
|
||||||
|
|
||||||
|
# Init nccl groups
|
||||||
|
self.pp_group = get_pp_group()
|
||||||
|
self.world_group = get_world_group()
|
||||||
|
|
||||||
# Profile number of tokens
|
# Profile number of tokens
|
||||||
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
|
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
|
||||||
self.max_prefill_tokens = server_args.max_prefill_tokens
|
self.max_prefill_tokens = server_args.max_prefill_tokens
|
||||||
@@ -130,8 +141,9 @@ class TpModelWorker:
|
|||||||
# Sync random seed across TP workers
|
# Sync random seed across TP workers
|
||||||
self.random_seed = broadcast_pyobj(
|
self.random_seed = broadcast_pyobj(
|
||||||
[server_args.random_seed],
|
[server_args.random_seed],
|
||||||
self.tp_rank,
|
self.tp_size * self.pp_rank + tp_rank,
|
||||||
self.model_runner.tp_group.cpu_group,
|
self.world_group.cpu_group,
|
||||||
|
src=self.world_group.ranks[0],
|
||||||
)[0]
|
)[0]
|
||||||
set_random_seed(self.random_seed)
|
set_random_seed(self.random_seed)
|
||||||
|
|
||||||
@@ -156,11 +168,14 @@ class TpModelWorker:
|
|||||||
def get_pad_input_ids_func(self):
|
def get_pad_input_ids_func(self):
|
||||||
return getattr(self.model_runner.model, "pad_input_ids", None)
|
return getattr(self.model_runner.model, "pad_input_ids", None)
|
||||||
|
|
||||||
def get_tp_cpu_group(self):
|
def get_tp_group(self):
|
||||||
return self.model_runner.tp_group.cpu_group
|
return self.model_runner.tp_group
|
||||||
|
|
||||||
|
def get_attention_tp_group(self):
|
||||||
|
return self.model_runner.attention_tp_group
|
||||||
|
|
||||||
def get_attention_tp_cpu_group(self):
|
def get_attention_tp_cpu_group(self):
|
||||||
return self.model_runner.attention_tp_group.cpu_group
|
return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
|
||||||
|
|
||||||
def get_memory_pool(self):
|
def get_memory_pool(self):
|
||||||
return (
|
return (
|
||||||
@@ -172,19 +187,38 @@ class TpModelWorker:
|
|||||||
self,
|
self,
|
||||||
model_worker_batch: ModelWorkerBatch,
|
model_worker_batch: ModelWorkerBatch,
|
||||||
skip_sample: bool = False,
|
skip_sample: bool = False,
|
||||||
) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
|
) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
|
||||||
|
|
||||||
|
pp_proxy_tensors = None
|
||||||
|
if not self.pp_group.is_first_rank:
|
||||||
|
pp_proxy_tensors = PPProxyTensors(
|
||||||
|
self.pp_group.recv_tensor_dict(
|
||||||
|
all_gather_group=self.get_attention_tp_group()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.pp_group.is_last_rank:
|
||||||
|
logits_output = self.model_runner.forward(
|
||||||
|
forward_batch, pp_proxy_tensors=pp_proxy_tensors
|
||||||
|
)
|
||||||
if model_worker_batch.launch_done is not None:
|
if model_worker_batch.launch_done is not None:
|
||||||
model_worker_batch.launch_done.set()
|
model_worker_batch.launch_done.set()
|
||||||
|
|
||||||
if skip_sample:
|
if skip_sample:
|
||||||
next_token_ids = None
|
next_token_ids = None
|
||||||
else:
|
else:
|
||||||
next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
|
next_token_ids = self.model_runner.sample(
|
||||||
|
logits_output, model_worker_batch
|
||||||
|
)
|
||||||
|
|
||||||
return logits_output, next_token_ids
|
return logits_output, next_token_ids
|
||||||
|
else:
|
||||||
|
pp_proxy_tensors = self.model_runner.forward(
|
||||||
|
forward_batch,
|
||||||
|
pp_proxy_tensors=pp_proxy_tensors,
|
||||||
|
)
|
||||||
|
return pp_proxy_tensors.tensors, None
|
||||||
|
|
||||||
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
|
|||||||
@@ -56,11 +56,14 @@ class TpModelWorkerClient:
|
|||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
gpu_id: int,
|
gpu_id: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
|
pp_rank: int,
|
||||||
dp_rank: Optional[int],
|
dp_rank: Optional[int],
|
||||||
nccl_port: int,
|
nccl_port: int,
|
||||||
):
|
):
|
||||||
# Load the model
|
# Load the model
|
||||||
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
|
self.worker = TpModelWorker(
|
||||||
|
server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
|
||||||
|
)
|
||||||
self.max_running_requests = self.worker.max_running_requests
|
self.max_running_requests = self.worker.max_running_requests
|
||||||
self.device = self.worker.device
|
self.device = self.worker.device
|
||||||
self.gpu_id = gpu_id
|
self.gpu_id = gpu_id
|
||||||
@@ -91,8 +94,11 @@ class TpModelWorkerClient:
|
|||||||
def get_pad_input_ids_func(self):
|
def get_pad_input_ids_func(self):
|
||||||
return self.worker.get_pad_input_ids_func()
|
return self.worker.get_pad_input_ids_func()
|
||||||
|
|
||||||
def get_tp_cpu_group(self):
|
def get_tp_group(self):
|
||||||
return self.worker.get_tp_cpu_group()
|
return self.worker.get_tp_group()
|
||||||
|
|
||||||
|
def get_attention_tp_group(self):
|
||||||
|
return self.worker.get_attention_tp_group()
|
||||||
|
|
||||||
def get_attention_tp_cpu_group(self):
|
def get_attention_tp_cpu_group(self):
|
||||||
return self.worker.get_attention_tp_cpu_group()
|
return self.worker.get_attention_tp_cpu_group()
|
||||||
|
|||||||
@@ -214,6 +214,8 @@ class MHATokenToKVPool(KVCache):
|
|||||||
layer_num: int,
|
layer_num: int,
|
||||||
device: str,
|
device: str,
|
||||||
enable_memory_saver: bool,
|
enable_memory_saver: bool,
|
||||||
|
start_layer: Optional[int] = None,
|
||||||
|
end_layer: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
@@ -232,6 +234,8 @@ class MHATokenToKVPool(KVCache):
|
|||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.layer_num = layer_num
|
self.layer_num = layer_num
|
||||||
self._create_buffers()
|
self._create_buffers()
|
||||||
|
self.start_layer = start_layer or 0
|
||||||
|
self.end_layer = end_layer or layer_num - 1
|
||||||
|
|
||||||
self.layer_transfer_counter = None
|
self.layer_transfer_counter = None
|
||||||
self.capture_mode = False
|
self.capture_mode = False
|
||||||
@@ -281,6 +285,8 @@ class MHATokenToKVPool(KVCache):
|
|||||||
|
|
||||||
# for disagg
|
# for disagg
|
||||||
def get_contiguous_buf_infos(self):
|
def get_contiguous_buf_infos(self):
|
||||||
|
# layer_num x [seq_len, head_num, head_dim]
|
||||||
|
# layer_num x [page_num, page_size, head_num, head_dim]
|
||||||
kv_data_ptrs = [
|
kv_data_ptrs = [
|
||||||
self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
|
self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
|
||||||
] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
|
] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
|
||||||
@@ -320,24 +326,24 @@ class MHATokenToKVPool(KVCache):
|
|||||||
# transfer prepared data from host to device
|
# transfer prepared data from host to device
|
||||||
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
||||||
k_data, v_data = flat_data[0], flat_data[1]
|
k_data, v_data = flat_data[0], flat_data[1]
|
||||||
self.k_buffer[layer_id][indices] = k_data
|
self.k_buffer[layer_id - self.start_layer][indices] = k_data
|
||||||
self.v_buffer[layer_id][indices] = v_data
|
self.v_buffer[layer_id - self.start_layer][indices] = v_data
|
||||||
|
|
||||||
def get_key_buffer(self, layer_id: int):
|
def get_key_buffer(self, layer_id: int):
|
||||||
if self.layer_transfer_counter is not None:
|
if self.layer_transfer_counter is not None:
|
||||||
self.layer_transfer_counter.wait_until(layer_id)
|
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||||
|
|
||||||
if self.store_dtype != self.dtype:
|
if self.store_dtype != self.dtype:
|
||||||
return self.k_buffer[layer_id].view(self.dtype)
|
return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
|
||||||
return self.k_buffer[layer_id]
|
return self.k_buffer[layer_id - self.start_layer]
|
||||||
|
|
||||||
def get_value_buffer(self, layer_id: int):
|
def get_value_buffer(self, layer_id: int):
|
||||||
if self.layer_transfer_counter is not None:
|
if self.layer_transfer_counter is not None:
|
||||||
self.layer_transfer_counter.wait_until(layer_id)
|
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||||
|
|
||||||
if self.store_dtype != self.dtype:
|
if self.store_dtype != self.dtype:
|
||||||
return self.v_buffer[layer_id].view(self.dtype)
|
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
|
||||||
return self.v_buffer[layer_id]
|
return self.v_buffer[layer_id - self.start_layer]
|
||||||
|
|
||||||
def get_kv_buffer(self, layer_id: int):
|
def get_kv_buffer(self, layer_id: int):
|
||||||
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
||||||
@@ -369,12 +375,12 @@ class MHATokenToKVPool(KVCache):
|
|||||||
current_stream = self.device_module.current_stream()
|
current_stream = self.device_module.current_stream()
|
||||||
self.alt_stream.wait_stream(current_stream)
|
self.alt_stream.wait_stream(current_stream)
|
||||||
with self.device_module.stream(self.alt_stream):
|
with self.device_module.stream(self.alt_stream):
|
||||||
self.k_buffer[layer_id][loc] = cache_k
|
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
||||||
self.v_buffer[layer_id][loc] = cache_v
|
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
||||||
current_stream.wait_stream(self.alt_stream)
|
current_stream.wait_stream(self.alt_stream)
|
||||||
else:
|
else:
|
||||||
self.k_buffer[layer_id][loc] = cache_k
|
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
||||||
self.v_buffer[layer_id][loc] = cache_v
|
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
||||||
|
|
||||||
|
|
||||||
@torch.compile
|
@torch.compile
|
||||||
@@ -484,6 +490,8 @@ class MLATokenToKVPool(KVCache):
|
|||||||
layer_num: int,
|
layer_num: int,
|
||||||
device: str,
|
device: str,
|
||||||
enable_memory_saver: bool,
|
enable_memory_saver: bool,
|
||||||
|
start_layer: Optional[int] = None,
|
||||||
|
end_layer: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
@@ -497,6 +505,8 @@ class MLATokenToKVPool(KVCache):
|
|||||||
self.kv_lora_rank = kv_lora_rank
|
self.kv_lora_rank = kv_lora_rank
|
||||||
self.qk_rope_head_dim = qk_rope_head_dim
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
self.layer_num = layer_num
|
self.layer_num = layer_num
|
||||||
|
self.start_layer = start_layer or 0
|
||||||
|
self.end_layer = end_layer or layer_num - 1
|
||||||
|
|
||||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||||
enable=enable_memory_saver
|
enable=enable_memory_saver
|
||||||
@@ -540,19 +550,21 @@ class MLATokenToKVPool(KVCache):
|
|||||||
|
|
||||||
def get_key_buffer(self, layer_id: int):
|
def get_key_buffer(self, layer_id: int):
|
||||||
if self.layer_transfer_counter is not None:
|
if self.layer_transfer_counter is not None:
|
||||||
self.layer_transfer_counter.wait_until(layer_id)
|
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||||
|
|
||||||
if self.store_dtype != self.dtype:
|
if self.store_dtype != self.dtype:
|
||||||
return self.kv_buffer[layer_id].view(self.dtype)
|
return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
|
||||||
return self.kv_buffer[layer_id]
|
return self.kv_buffer[layer_id - self.start_layer]
|
||||||
|
|
||||||
def get_value_buffer(self, layer_id: int):
|
def get_value_buffer(self, layer_id: int):
|
||||||
if self.layer_transfer_counter is not None:
|
if self.layer_transfer_counter is not None:
|
||||||
self.layer_transfer_counter.wait_until(layer_id)
|
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||||
|
|
||||||
if self.store_dtype != self.dtype:
|
if self.store_dtype != self.dtype:
|
||||||
return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
|
return self.kv_buffer[layer_id - self.start_layer][
|
||||||
return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
|
..., : self.kv_lora_rank
|
||||||
|
].view(self.dtype)
|
||||||
|
return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank]
|
||||||
|
|
||||||
def get_kv_buffer(self, layer_id: int):
|
def get_kv_buffer(self, layer_id: int):
|
||||||
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
||||||
@@ -568,9 +580,11 @@ class MLATokenToKVPool(KVCache):
|
|||||||
if cache_k.dtype != self.dtype:
|
if cache_k.dtype != self.dtype:
|
||||||
cache_k = cache_k.to(self.dtype)
|
cache_k = cache_k.to(self.dtype)
|
||||||
if self.store_dtype != self.dtype:
|
if self.store_dtype != self.dtype:
|
||||||
self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
|
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
|
||||||
|
self.store_dtype
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.kv_buffer[layer_id][loc] = cache_k
|
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
|
||||||
|
|
||||||
def set_mla_kv_buffer(
|
def set_mla_kv_buffer(
|
||||||
self,
|
self,
|
||||||
@@ -605,7 +619,7 @@ class MLATokenToKVPool(KVCache):
|
|||||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
def transfer_per_layer(self, indices, flat_data, layer_id):
|
||||||
# transfer prepared data from host to device
|
# transfer prepared data from host to device
|
||||||
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
||||||
self.kv_buffer[layer_id][indices] = flat_data
|
self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
|
||||||
|
|
||||||
|
|
||||||
class DoubleSparseTokenToKVPool(KVCache):
|
class DoubleSparseTokenToKVPool(KVCache):
|
||||||
@@ -620,6 +634,8 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|||||||
device: str,
|
device: str,
|
||||||
heavy_channel_num: int,
|
heavy_channel_num: int,
|
||||||
enable_memory_saver: bool,
|
enable_memory_saver: bool,
|
||||||
|
start_layer: Optional[int] = None,
|
||||||
|
end_layer: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
@@ -657,17 +673,23 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|||||||
for _ in range(layer_num)
|
for _ in range(layer_num)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
self.start_layer = start_layer or 0
|
||||||
|
self.end_layer = end_layer or layer_num - 1
|
||||||
|
|
||||||
def get_key_buffer(self, layer_id: int):
|
def get_key_buffer(self, layer_id: int):
|
||||||
return self.k_buffer[layer_id]
|
return self.k_buffer[layer_id - self.start_layer]
|
||||||
|
|
||||||
def get_value_buffer(self, layer_id: int):
|
def get_value_buffer(self, layer_id: int):
|
||||||
return self.v_buffer[layer_id]
|
return self.v_buffer[layer_id - self.start_layer]
|
||||||
|
|
||||||
def get_label_buffer(self, layer_id: int):
|
def get_label_buffer(self, layer_id: int):
|
||||||
return self.label_buffer[layer_id]
|
return self.label_buffer[layer_id - self.start_layer]
|
||||||
|
|
||||||
def get_kv_buffer(self, layer_id: int):
|
def get_kv_buffer(self, layer_id: int):
|
||||||
return self.k_buffer[layer_id], self.v_buffer[layer_id]
|
return (
|
||||||
|
self.k_buffer[layer_id - self.start_layer],
|
||||||
|
self.v_buffer[layer_id - self.start_layer],
|
||||||
|
)
|
||||||
|
|
||||||
def set_kv_buffer(
|
def set_kv_buffer(
|
||||||
self,
|
self,
|
||||||
@@ -679,9 +701,9 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|||||||
):
|
):
|
||||||
# NOTE(Andy): ignore the dtype check
|
# NOTE(Andy): ignore the dtype check
|
||||||
layer_id = layer.layer_id
|
layer_id = layer.layer_id
|
||||||
self.k_buffer[layer_id][loc] = cache_k
|
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
||||||
self.v_buffer[layer_id][loc] = cache_v
|
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
||||||
self.label_buffer[layer_id][loc] = cache_label
|
self.label_buffer[layer_id - self.start_layer][loc] = cache_label
|
||||||
|
|
||||||
def get_flat_data(self, indices):
|
def get_flat_data(self, indices):
|
||||||
pass
|
pass
|
||||||
@@ -930,7 +952,7 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
return self.kv_buffer[:, :, indices]
|
return self.kv_buffer[:, :, indices]
|
||||||
|
|
||||||
def get_flat_data_by_layer(self, indices, layer_id):
|
def get_flat_data_by_layer(self, indices, layer_id):
|
||||||
return self.kv_buffer[:, layer_id, indices]
|
return self.kv_buffer[:, layer_id - self.start_layer, indices]
|
||||||
|
|
||||||
def assign_flat_data(self, indices, flat_data):
|
def assign_flat_data(self, indices, flat_data):
|
||||||
self.kv_buffer[:, :, indices] = flat_data
|
self.kv_buffer[:, :, indices] = flat_data
|
||||||
@@ -955,12 +977,20 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
for i in range(len(device_indices_cpu)):
|
for i in range(len(device_indices_cpu)):
|
||||||
h_index = host_indices[i * self.page_size]
|
h_index = host_indices[i * self.page_size]
|
||||||
d_index = device_indices_cpu[i]
|
d_index = device_indices_cpu[i]
|
||||||
device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
|
device_pool.k_buffer[layer_id - self.start_layer][
|
||||||
self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
|
d_index : d_index + self.page_size
|
||||||
|
].copy_(
|
||||||
|
self.kv_buffer[
|
||||||
|
0, layer_id - self.start_layer, h_index : h_index + self.page_size
|
||||||
|
],
|
||||||
non_blocking=True,
|
non_blocking=True,
|
||||||
)
|
)
|
||||||
device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
|
device_pool.v_buffer[layer_id - self.start_layer][
|
||||||
self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
|
d_index : d_index + self.page_size
|
||||||
|
].copy_(
|
||||||
|
self.kv_buffer[
|
||||||
|
1, layer_id - self.start_layer, h_index : h_index + self.page_size
|
||||||
|
],
|
||||||
non_blocking=True,
|
non_blocking=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1015,7 +1045,7 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
return self.kv_buffer[:, indices]
|
return self.kv_buffer[:, indices]
|
||||||
|
|
||||||
def get_flat_data_by_layer(self, indices, layer_id):
|
def get_flat_data_by_layer(self, indices, layer_id):
|
||||||
return self.kv_buffer[layer_id, indices]
|
return self.kv_buffer[layer_id - self.start_layer, indices]
|
||||||
|
|
||||||
def assign_flat_data(self, indices, flat_data):
|
def assign_flat_data(self, indices, flat_data):
|
||||||
self.kv_buffer[:, indices] = flat_data
|
self.kv_buffer[:, indices] = flat_data
|
||||||
@@ -1036,7 +1066,11 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
for i in range(len(device_indices_cpu)):
|
for i in range(len(device_indices_cpu)):
|
||||||
h_index = host_indices[i * self.page_size]
|
h_index = host_indices[i * self.page_size]
|
||||||
d_index = device_indices_cpu[i]
|
d_index = device_indices_cpu[i]
|
||||||
device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
|
device_pool.kv_buffer[layer_id - self.start_layer][
|
||||||
self.kv_buffer[layer_id, h_index : h_index + self.page_size],
|
d_index : d_index + self.page_size
|
||||||
|
].copy_(
|
||||||
|
self.kv_buffer[
|
||||||
|
layer_id - self.start_layer, h_index : h_index + self.page_size
|
||||||
|
],
|
||||||
non_blocking=True,
|
non_blocking=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
|
import inspect
|
||||||
import os
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, Callable
|
from typing import TYPE_CHECKING, Callable
|
||||||
@@ -33,12 +34,14 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
CaptureHiddenMode,
|
CaptureHiddenMode,
|
||||||
ForwardBatch,
|
ForwardBatch,
|
||||||
ForwardMode,
|
ForwardMode,
|
||||||
|
PPProxyTensors,
|
||||||
)
|
)
|
||||||
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
get_device_memory_capacity,
|
get_device_memory_capacity,
|
||||||
is_hip,
|
is_hip,
|
||||||
|
rank0_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -188,10 +191,11 @@ class CudaGraphRunner:
|
|||||||
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
|
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
|
||||||
self.tp_size = model_runner.server_args.tp_size
|
self.tp_size = model_runner.server_args.tp_size
|
||||||
self.dp_size = model_runner.server_args.dp_size
|
self.dp_size = model_runner.server_args.dp_size
|
||||||
|
self.pp_size = model_runner.server_args.pp_size
|
||||||
|
|
||||||
# Batch sizes to capture
|
# Batch sizes to capture
|
||||||
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
||||||
|
rank0_log(f"Capture cuda graph bs {self.capture_bs}")
|
||||||
self.capture_forward_mode = ForwardMode.DECODE
|
self.capture_forward_mode = ForwardMode.DECODE
|
||||||
self.capture_hidden_mode = CaptureHiddenMode.NULL
|
self.capture_hidden_mode = CaptureHiddenMode.NULL
|
||||||
self.num_tokens_per_bs = 1
|
self.num_tokens_per_bs = 1
|
||||||
@@ -234,6 +238,19 @@ class CudaGraphRunner:
|
|||||||
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||||
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
|
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
|
||||||
|
|
||||||
|
# pipeline parallelism
|
||||||
|
if self.pp_size > 1:
|
||||||
|
self.pp_proxy_tensors = {
|
||||||
|
"hidden_states": torch.zeros(
|
||||||
|
(self.max_bs, self.model_runner.model_config.hidden_size),
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
),
|
||||||
|
"residual": torch.zeros(
|
||||||
|
(self.max_bs, self.model_runner.model_config.hidden_size),
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
# Speculative_inference
|
# Speculative_inference
|
||||||
if (
|
if (
|
||||||
model_runner.spec_algorithm.is_eagle3()
|
model_runner.spec_algorithm.is_eagle3()
|
||||||
@@ -384,6 +401,12 @@ class CudaGraphRunner:
|
|||||||
encoder_lens = None
|
encoder_lens = None
|
||||||
mrope_positions = self.mrope_positions[:, :bs]
|
mrope_positions = self.mrope_positions[:, :bs]
|
||||||
|
|
||||||
|
# pipeline parallelism
|
||||||
|
if self.pp_size > 1:
|
||||||
|
pp_proxy_tensors = PPProxyTensors(
|
||||||
|
{k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
|
||||||
|
)
|
||||||
|
|
||||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||||
self.global_num_tokens_gpu.copy_(
|
self.global_num_tokens_gpu.copy_(
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
@@ -456,8 +479,20 @@ class CudaGraphRunner:
|
|||||||
# Clean intermediate result cache for DP attention
|
# Clean intermediate result cache for DP attention
|
||||||
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||||
|
|
||||||
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
|
kwargs = {}
|
||||||
return logits_output.next_token_logits, logits_output.hidden_states
|
if (
|
||||||
|
self.pp_size > 1
|
||||||
|
and "pp_proxy_tensors" in inspect.signature(forward).parameters
|
||||||
|
):
|
||||||
|
kwargs["pp_proxy_tensors"] = pp_proxy_tensors
|
||||||
|
|
||||||
|
logits_output_or_pp_proxy_tensors = forward(
|
||||||
|
input_ids,
|
||||||
|
forward_batch.positions,
|
||||||
|
forward_batch,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return logits_output_or_pp_proxy_tensors
|
||||||
|
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@@ -490,7 +525,11 @@ class CudaGraphRunner:
|
|||||||
self.capture_hidden_mode = hidden_mode_from_spec_info
|
self.capture_hidden_mode = hidden_mode_from_spec_info
|
||||||
self.capture()
|
self.capture()
|
||||||
|
|
||||||
def replay_prepare(self, forward_batch: ForwardBatch):
|
def replay_prepare(
|
||||||
|
self,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||||
|
):
|
||||||
self.recapture_if_needed(forward_batch)
|
self.recapture_if_needed(forward_batch)
|
||||||
|
|
||||||
raw_bs = forward_batch.batch_size
|
raw_bs = forward_batch.batch_size
|
||||||
@@ -519,6 +558,11 @@ class CudaGraphRunner:
|
|||||||
self.seq_lens_cpu.fill_(1)
|
self.seq_lens_cpu.fill_(1)
|
||||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
||||||
|
|
||||||
|
if pp_proxy_tensors:
|
||||||
|
for key in self.pp_proxy_tensors.keys():
|
||||||
|
dim = pp_proxy_tensors[key].shape[0]
|
||||||
|
self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key])
|
||||||
|
|
||||||
if self.is_encoder_decoder:
|
if self.is_encoder_decoder:
|
||||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||||
if forward_batch.mrope_positions is not None:
|
if forward_batch.mrope_positions is not None:
|
||||||
@@ -547,10 +591,13 @@ class CudaGraphRunner:
|
|||||||
self.bs = bs
|
self.bs = bs
|
||||||
|
|
||||||
def replay(
|
def replay(
|
||||||
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
|
self,
|
||||||
) -> LogitsProcessorOutput:
|
forward_batch: ForwardBatch,
|
||||||
|
skip_attn_backend_init: bool = False,
|
||||||
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||||
|
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
|
||||||
if not skip_attn_backend_init:
|
if not skip_attn_backend_init:
|
||||||
self.replay_prepare(forward_batch)
|
self.replay_prepare(forward_batch, pp_proxy_tensors)
|
||||||
else:
|
else:
|
||||||
# In speculative decoding, these two fields are still needed.
|
# In speculative decoding, these two fields are still needed.
|
||||||
self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
|
self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
|
||||||
@@ -558,17 +605,19 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
# Replay
|
# Replay
|
||||||
self.graphs[self.bs].replay()
|
self.graphs[self.bs].replay()
|
||||||
next_token_logits, hidden_states = self.output_buffers[self.bs]
|
output = self.output_buffers[self.bs]
|
||||||
|
if isinstance(output, LogitsProcessorOutput):
|
||||||
logits_output = LogitsProcessorOutput(
|
return LogitsProcessorOutput(
|
||||||
next_token_logits=next_token_logits[: self.raw_num_token],
|
next_token_logits=output.next_token_logits[: self.raw_num_token],
|
||||||
hidden_states=(
|
hidden_states=(
|
||||||
hidden_states[: self.raw_num_token]
|
output.hidden_states[: self.raw_num_token]
|
||||||
if hidden_states is not None
|
if output.hidden_states is not None
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return logits_output
|
else:
|
||||||
|
assert isinstance(output, PPProxyTensors)
|
||||||
|
return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
|
||||||
|
|
||||||
def get_spec_info(self, num_tokens: int):
|
def get_spec_info(self, num_tokens: int):
|
||||||
spec_info = None
|
spec_info = None
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import TYPE_CHECKING, List, Optional, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
@@ -585,6 +585,36 @@ class ForwardBatch:
|
|||||||
self.prepare_chunked_kv_indices(device)
|
self.prepare_chunked_kv_indices(device)
|
||||||
|
|
||||||
|
|
||||||
|
class PPProxyTensors:
|
||||||
|
# adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
|
||||||
|
tensors: Dict[str, torch.Tensor]
|
||||||
|
|
||||||
|
def __init__(self, tensors):
|
||||||
|
# manually define this function, so that
|
||||||
|
# Dynamo knows `IntermediateTensors()` comes from this file.
|
||||||
|
# Otherwise, dataclass will generate this function by evaluating
|
||||||
|
# a string, and we will lose the information about the source file.
|
||||||
|
self.tensors = tensors
|
||||||
|
|
||||||
|
def __getitem__(self, key: Union[str, slice]):
|
||||||
|
if isinstance(key, str):
|
||||||
|
return self.tensors[key]
|
||||||
|
elif isinstance(key, slice):
|
||||||
|
return self.__class__({k: v[key] for k, v in self.tensors.items()})
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value: torch.Tensor):
|
||||||
|
self.tensors[key] = value
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.tensors)
|
||||||
|
|
||||||
|
def __eq__(self, other: object):
|
||||||
|
return isinstance(other, self.__class__) and self
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"PPProxyTensors(tensors={self.tensors})"
|
||||||
|
|
||||||
|
|
||||||
def compute_position_triton(
|
def compute_position_triton(
|
||||||
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
|
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -13,8 +13,10 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""ModelRunner runs the forward passes of the models."""
|
"""ModelRunner runs the forward passes of the models."""
|
||||||
|
|
||||||
|
import collections
|
||||||
import datetime
|
import datetime
|
||||||
import gc
|
import gc
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -59,7 +61,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
|
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_loader import get_model
|
from sglang.srt.model_loader import get_model
|
||||||
from sglang.srt.model_loader.loader import (
|
from sglang.srt.model_loader.loader import (
|
||||||
DefaultModelLoader,
|
DefaultModelLoader,
|
||||||
@@ -111,6 +113,8 @@ class ModelRunner:
|
|||||||
gpu_id: int,
|
gpu_id: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
tp_size: int,
|
tp_size: int,
|
||||||
|
pp_rank: int,
|
||||||
|
pp_size: int,
|
||||||
nccl_port: int,
|
nccl_port: int,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
is_draft_worker: bool = False,
|
is_draft_worker: bool = False,
|
||||||
@@ -124,6 +128,8 @@ class ModelRunner:
|
|||||||
self.gpu_id = gpu_id
|
self.gpu_id = gpu_id
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
|
self.pp_rank = pp_rank
|
||||||
|
self.pp_size = pp_size
|
||||||
self.dist_port = nccl_port
|
self.dist_port = nccl_port
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
self.is_draft_worker = is_draft_worker
|
self.is_draft_worker = is_draft_worker
|
||||||
@@ -149,24 +155,24 @@ class ModelRunner:
|
|||||||
global_server_args_dict.update(
|
global_server_args_dict.update(
|
||||||
{
|
{
|
||||||
"attention_backend": server_args.attention_backend,
|
"attention_backend": server_args.attention_backend,
|
||||||
"sampling_backend": server_args.sampling_backend,
|
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
|
||||||
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
|
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
|
||||||
"torchao_config": server_args.torchao_config,
|
"deepep_mode": server_args.deepep_mode,
|
||||||
|
"device": server_args.device,
|
||||||
|
"disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
|
||||||
|
"disable_radix_cache": server_args.disable_radix_cache,
|
||||||
"enable_nan_detection": server_args.enable_nan_detection,
|
"enable_nan_detection": server_args.enable_nan_detection,
|
||||||
"enable_dp_attention": server_args.enable_dp_attention,
|
"enable_dp_attention": server_args.enable_dp_attention,
|
||||||
"enable_ep_moe": server_args.enable_ep_moe,
|
"enable_ep_moe": server_args.enable_ep_moe,
|
||||||
"enable_deepep_moe": server_args.enable_deepep_moe,
|
"enable_deepep_moe": server_args.enable_deepep_moe,
|
||||||
"deepep_mode": server_args.deepep_mode,
|
|
||||||
"device": server_args.device,
|
|
||||||
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
|
|
||||||
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
|
|
||||||
"disable_radix_cache": server_args.disable_radix_cache,
|
|
||||||
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
||||||
"moe_dense_tp_size": server_args.moe_dense_tp_size,
|
"moe_dense_tp_size": server_args.moe_dense_tp_size,
|
||||||
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
|
|
||||||
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
|
|
||||||
"n_share_experts_fusion": server_args.n_share_experts_fusion,
|
"n_share_experts_fusion": server_args.n_share_experts_fusion,
|
||||||
"disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
|
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
|
||||||
|
"torchao_config": server_args.torchao_config,
|
||||||
|
"sampling_backend": server_args.sampling_backend,
|
||||||
|
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
|
||||||
|
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
|
||||||
"use_mla_backend": self.use_mla_backend,
|
"use_mla_backend": self.use_mla_backend,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -184,6 +190,11 @@ class ModelRunner:
|
|||||||
# If it is a draft model, tp_group can be different
|
# If it is a draft model, tp_group can be different
|
||||||
self.initialize(min_per_gpu_memory)
|
self.initialize(min_per_gpu_memory)
|
||||||
|
|
||||||
|
# temporary cached values
|
||||||
|
self.support_pp = (
|
||||||
|
"pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
|
||||||
|
)
|
||||||
|
|
||||||
def initialize(self, min_per_gpu_memory: float):
|
def initialize(self, min_per_gpu_memory: float):
|
||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||||
@@ -194,6 +205,12 @@ class ModelRunner:
|
|||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
self.load_model()
|
self.load_model()
|
||||||
|
|
||||||
|
self.start_layer = getattr(self.model, "start_layer", 0)
|
||||||
|
self.end_layer = getattr(
|
||||||
|
self.model, "end_layer", self.model_config.num_hidden_layers
|
||||||
|
)
|
||||||
|
self.num_effective_layers = self.end_layer - self.start_layer
|
||||||
|
|
||||||
# Apply torchao quantization
|
# Apply torchao quantization
|
||||||
torchao_applied = getattr(self.model, "torchao_applied", False)
|
torchao_applied = getattr(self.model, "torchao_applied", False)
|
||||||
# In layered loading, torchao may have been applied
|
# In layered loading, torchao may have been applied
|
||||||
@@ -360,18 +377,22 @@ class ModelRunner:
|
|||||||
# Only initialize the distributed environment on the target model worker.
|
# Only initialize the distributed environment on the target model worker.
|
||||||
init_distributed_environment(
|
init_distributed_environment(
|
||||||
backend=backend,
|
backend=backend,
|
||||||
world_size=self.tp_size,
|
world_size=self.tp_size * self.pp_size,
|
||||||
rank=self.tp_rank,
|
rank=self.tp_size * self.pp_rank + self.tp_rank,
|
||||||
local_rank=self.gpu_id,
|
local_rank=self.gpu_id,
|
||||||
distributed_init_method=dist_init_method,
|
distributed_init_method=dist_init_method,
|
||||||
timeout=self.server_args.dist_timeout,
|
timeout=self.server_args.dist_timeout,
|
||||||
)
|
)
|
||||||
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
initialize_model_parallel(
|
||||||
|
tensor_model_parallel_size=self.tp_size,
|
||||||
|
pipeline_model_parallel_size=self.pp_size,
|
||||||
|
)
|
||||||
initialize_dp_attention(
|
initialize_dp_attention(
|
||||||
enable_dp_attention=self.server_args.enable_dp_attention,
|
enable_dp_attention=self.server_args.enable_dp_attention,
|
||||||
tp_rank=self.tp_rank,
|
tp_rank=self.tp_rank,
|
||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
dp_size=self.server_args.dp_size,
|
dp_size=self.server_args.dp_size,
|
||||||
|
pp_size=self.server_args.pp_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
min_per_gpu_memory = get_available_gpu_memory(
|
min_per_gpu_memory = get_available_gpu_memory(
|
||||||
@@ -698,6 +719,8 @@ class ModelRunner:
|
|||||||
if not self.is_draft_worker
|
if not self.is_draft_worker
|
||||||
else self.model_config.hf_config.num_nextn_predict_layers
|
else self.model_config.hf_config.num_nextn_predict_layers
|
||||||
)
|
)
|
||||||
|
# FIXME: pipeline parallelism is not compatible with mla backend
|
||||||
|
assert self.pp_size == 1
|
||||||
cell_size = (
|
cell_size = (
|
||||||
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
|
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
|
||||||
* num_layers
|
* num_layers
|
||||||
@@ -707,7 +730,7 @@ class ModelRunner:
|
|||||||
cell_size = (
|
cell_size = (
|
||||||
self.model_config.get_num_kv_heads(get_attention_tp_size())
|
self.model_config.get_num_kv_heads(get_attention_tp_size())
|
||||||
* self.model_config.head_dim
|
* self.model_config.head_dim
|
||||||
* self.model_config.num_hidden_layers
|
* self.num_effective_layers
|
||||||
* 2
|
* 2
|
||||||
* torch._utils._element_size(self.kv_cache_dtype)
|
* torch._utils._element_size(self.kv_cache_dtype)
|
||||||
)
|
)
|
||||||
@@ -819,9 +842,11 @@ class ModelRunner:
|
|||||||
self.model_config.num_hidden_layers
|
self.model_config.num_hidden_layers
|
||||||
if not self.is_draft_worker
|
if not self.is_draft_worker
|
||||||
else self.model_config.hf_config.num_nextn_predict_layers
|
else self.model_config.hf_config.num_nextn_predict_layers
|
||||||
),
|
), # PP is not compatible with mla backend
|
||||||
device=self.device,
|
device=self.device,
|
||||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||||
|
start_layer=self.start_layer,
|
||||||
|
end_layer=self.end_layer,
|
||||||
)
|
)
|
||||||
elif self.server_args.enable_double_sparsity:
|
elif self.server_args.enable_double_sparsity:
|
||||||
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
|
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
|
||||||
@@ -830,10 +855,12 @@ class ModelRunner:
|
|||||||
dtype=self.kv_cache_dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
|
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
|
||||||
head_dim=self.model_config.head_dim,
|
head_dim=self.model_config.head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=self.num_effective_layers,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
heavy_channel_num=self.server_args.ds_heavy_channel_num,
|
heavy_channel_num=self.server_args.ds_heavy_channel_num,
|
||||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||||
|
start_layer=self.start_layer,
|
||||||
|
end_layer=self.end_layer,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.token_to_kv_pool = MHATokenToKVPool(
|
self.token_to_kv_pool = MHATokenToKVPool(
|
||||||
@@ -842,9 +869,11 @@ class ModelRunner:
|
|||||||
dtype=self.kv_cache_dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
|
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
|
||||||
head_dim=self.model_config.head_dim,
|
head_dim=self.model_config.head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=self.num_effective_layers,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||||
|
start_layer=self.start_layer,
|
||||||
|
end_layer=self.end_layer,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.token_to_kv_pool_allocator is None:
|
if self.token_to_kv_pool_allocator is None:
|
||||||
@@ -957,7 +986,7 @@ class ModelRunner:
|
|||||||
with open(self.server_args.ds_channel_config_path, "r") as f:
|
with open(self.server_args.ds_channel_config_path, "r") as f:
|
||||||
channel_config = json.load(f)
|
channel_config = json.load(f)
|
||||||
|
|
||||||
for i in range(self.model_config.num_hidden_layers):
|
for i in range(self.start_layer, self.end_layer):
|
||||||
key = "model.layers." + str(i) + ".self_attn" + selected_channel
|
key = "model.layers." + str(i) + ".self_attn" + selected_channel
|
||||||
self.sorted_channels.append(
|
self.sorted_channels.append(
|
||||||
torch.tensor(channel_config[key])[
|
torch.tensor(channel_config[key])[
|
||||||
@@ -997,64 +1026,82 @@ class ModelRunner:
|
|||||||
device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
|
device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
|
||||||
tensor_parallel(self.model, device_mesh)
|
tensor_parallel(self.model, device_mesh)
|
||||||
|
|
||||||
def forward_decode(self, forward_batch: ForwardBatch):
|
def forward_decode(
|
||||||
|
self, forward_batch: ForwardBatch, pp_proxy_tensors=None
|
||||||
|
) -> LogitsProcessorOutput:
|
||||||
self.attn_backend.init_forward_metadata(forward_batch)
|
self.attn_backend.init_forward_metadata(forward_batch)
|
||||||
|
# FIXME: add pp_proxy_tensors arg to all models
|
||||||
|
kwargs = {}
|
||||||
|
if self.support_pp:
|
||||||
|
kwargs["pp_proxy_tensors"] = pp_proxy_tensors
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
|
self,
|
||||||
):
|
forward_batch: ForwardBatch,
|
||||||
|
skip_attn_backend_init: bool = False,
|
||||||
|
pp_proxy_tensors=None,
|
||||||
|
) -> LogitsProcessorOutput:
|
||||||
if not skip_attn_backend_init:
|
if not skip_attn_backend_init:
|
||||||
self.attn_backend.init_forward_metadata(forward_batch)
|
self.attn_backend.init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
if self.is_generation:
|
kwargs = {}
|
||||||
if forward_batch.input_embeds is None:
|
if self.support_pp:
|
||||||
return self.model.forward(
|
kwargs["pp_proxy_tensors"] = pp_proxy_tensors
|
||||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
if forward_batch.input_embeds is not None:
|
||||||
)
|
kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
|
||||||
else:
|
if not self.is_generation:
|
||||||
|
kwargs["get_embedding"] = True
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
forward_batch.input_ids,
|
forward_batch.input_ids,
|
||||||
forward_batch.positions,
|
forward_batch.positions,
|
||||||
forward_batch,
|
forward_batch,
|
||||||
input_embeds=forward_batch.input_embeds.bfloat16(),
|
**kwargs,
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Only embedding models have get_embedding parameter
|
|
||||||
return self.model.forward(
|
|
||||||
forward_batch.input_ids,
|
|
||||||
forward_batch.positions,
|
|
||||||
forward_batch,
|
|
||||||
get_embedding=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_idle(self, forward_batch: ForwardBatch):
|
def forward_idle(
|
||||||
|
self, forward_batch: ForwardBatch, pp_proxy_tensors=None
|
||||||
|
) -> LogitsProcessorOutput:
|
||||||
|
kwargs = {}
|
||||||
|
if self.support_pp:
|
||||||
|
kwargs["pp_proxy_tensors"] = pp_proxy_tensors
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
forward_batch.input_ids,
|
||||||
|
forward_batch.positions,
|
||||||
|
forward_batch,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
|
self,
|
||||||
) -> LogitsProcessorOutput:
|
forward_batch: ForwardBatch,
|
||||||
if (
|
skip_attn_backend_init: bool = False,
|
||||||
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||||
|
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
|
||||||
|
can_run_cuda_graph = bool(
|
||||||
forward_batch.forward_mode.is_cuda_graph()
|
forward_batch.forward_mode.is_cuda_graph()
|
||||||
and self.cuda_graph_runner
|
and self.cuda_graph_runner
|
||||||
and self.cuda_graph_runner.can_run(forward_batch)
|
and self.cuda_graph_runner.can_run(forward_batch)
|
||||||
):
|
)
|
||||||
|
if can_run_cuda_graph:
|
||||||
return self.cuda_graph_runner.replay(
|
return self.cuda_graph_runner.replay(
|
||||||
forward_batch, skip_attn_backend_init=skip_attn_backend_init
|
forward_batch,
|
||||||
|
skip_attn_backend_init=skip_attn_backend_init,
|
||||||
|
pp_proxy_tensors=pp_proxy_tensors,
|
||||||
)
|
)
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_decode():
|
if forward_batch.forward_mode.is_decode():
|
||||||
return self.forward_decode(forward_batch)
|
return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
|
||||||
elif forward_batch.forward_mode.is_extend():
|
elif forward_batch.forward_mode.is_extend():
|
||||||
return self.forward_extend(
|
return self.forward_extend(
|
||||||
forward_batch, skip_attn_backend_init=skip_attn_backend_init
|
forward_batch,
|
||||||
|
skip_attn_backend_init=skip_attn_backend_init,
|
||||||
|
pp_proxy_tensors=pp_proxy_tensors,
|
||||||
)
|
)
|
||||||
elif forward_batch.forward_mode.is_idle():
|
elif forward_batch.forward_mode.is_idle():
|
||||||
return self.forward_idle(forward_batch)
|
return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
|
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
|
||||||
|
|
||||||
|
|||||||
@@ -17,13 +17,14 @@
|
|||||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
|
get_pp_group,
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
@@ -39,11 +40,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
|
|||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.rotary_embedding import get_rope
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
|
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_loader.weight_utils import (
|
from sglang.srt.model_loader.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
kv_cache_scales_loader,
|
kv_cache_scales_loader,
|
||||||
@@ -275,21 +277,31 @@ class LlamaModel(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
|
self.pp_group = get_pp_group()
|
||||||
|
if self.pp_group.is_first_rank:
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("embed_tokens", prefix),
|
prefix=add_prefix("embed_tokens", prefix),
|
||||||
)
|
)
|
||||||
self.layers = make_layers(
|
else:
|
||||||
|
self.embed_tokens = PPMissingLayer()
|
||||||
|
|
||||||
|
self.layers, self.start_layer, self.end_layer = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda idx, prefix: LlamaDecoderLayer(
|
lambda idx, prefix: LlamaDecoderLayer(
|
||||||
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
|
config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
|
||||||
),
|
),
|
||||||
|
pp_rank=self.pp_group.rank_in_group,
|
||||||
|
pp_size=self.pp_group.world_size,
|
||||||
prefix="model.layers",
|
prefix="model.layers",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.pp_group.is_last_rank:
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
else:
|
||||||
|
self.norm = PPMissingLayer(return_tuple=True)
|
||||||
self.layers_to_capture = []
|
self.layers_to_capture = []
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -298,14 +310,23 @@ class LlamaModel(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
|
||||||
|
if self.pp_group.is_first_rank:
|
||||||
if input_embeds is None:
|
if input_embeds is None:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
else:
|
else:
|
||||||
hidden_states = input_embeds
|
hidden_states = input_embeds
|
||||||
residual = None
|
residual = None
|
||||||
|
else:
|
||||||
|
assert pp_proxy_tensors is not None
|
||||||
|
# FIXME(@ying): reduce the number of proxy tensors by not fusing layer norms
|
||||||
|
hidden_states = pp_proxy_tensors["hidden_states"]
|
||||||
|
residual = pp_proxy_tensors["residual"]
|
||||||
|
deferred_norm = None
|
||||||
|
|
||||||
aux_hidden_states = []
|
aux_hidden_states = []
|
||||||
for i in range(len(self.layers)):
|
for i in range(self.start_layer, self.end_layer):
|
||||||
if i in self.layers_to_capture:
|
if i in self.layers_to_capture:
|
||||||
aux_hidden_states.append(hidden_states + residual)
|
aux_hidden_states.append(hidden_states + residual)
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
@@ -315,6 +336,15 @@ class LlamaModel(nn.Module):
|
|||||||
forward_batch,
|
forward_batch,
|
||||||
residual,
|
residual,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self.pp_group.is_last_rank:
|
||||||
|
return PPProxyTensors(
|
||||||
|
{
|
||||||
|
"hidden_states": hidden_states,
|
||||||
|
"residual": residual,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
if len(aux_hidden_states) == 0:
|
if len(aux_hidden_states) == 0:
|
||||||
@@ -376,6 +406,7 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.pp_group = get_pp_group()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
|
self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
|
||||||
@@ -419,23 +450,41 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
get_embedding: bool = False,
|
get_embedding: bool = False,
|
||||||
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||||
) -> LogitsProcessorOutput:
|
) -> LogitsProcessorOutput:
|
||||||
aux_hidden_states = None
|
|
||||||
if self.capture_aux_hidden_states:
|
|
||||||
hidden_states, aux_hidden_states = self.model(
|
|
||||||
input_ids, positions, forward_batch, input_embeds
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids, positions, forward_batch, input_embeds
|
input_ids,
|
||||||
|
positions,
|
||||||
|
forward_batch,
|
||||||
|
input_embeds,
|
||||||
|
pp_proxy_tensors=pp_proxy_tensors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
aux_hidden_states = None
|
||||||
|
if self.capture_aux_hidden_states:
|
||||||
|
hidden_states, aux_hidden_states = hidden_states
|
||||||
|
|
||||||
|
if self.pp_group.is_last_rank:
|
||||||
if not get_embedding:
|
if not get_embedding:
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
|
input_ids,
|
||||||
|
hidden_states,
|
||||||
|
self.lm_head,
|
||||||
|
forward_batch,
|
||||||
|
aux_hidden_states,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.pooler(hidden_states, forward_batch)
|
return self.pooler(hidden_states, forward_batch)
|
||||||
|
else:
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_layer(self):
|
||||||
|
return self.model.start_layer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def end_layer(self):
|
||||||
|
return self.model.end_layer
|
||||||
|
|
||||||
def get_input_embeddings(self) -> nn.Embedding:
|
def get_input_embeddings(self) -> nn.Embedding:
|
||||||
return self.model.embed_tokens
|
return self.model.embed_tokens
|
||||||
@@ -491,6 +540,16 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
|
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
|
layer_id = get_layer_id(name)
|
||||||
|
if (
|
||||||
|
layer_id is not None
|
||||||
|
and hasattr(self.model, "start_layer")
|
||||||
|
and (
|
||||||
|
layer_id < self.model.start_layer
|
||||||
|
or layer_id >= self.model.end_layer
|
||||||
|
)
|
||||||
|
):
|
||||||
|
continue
|
||||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||||
continue
|
continue
|
||||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||||
@@ -637,6 +696,9 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
self.model.load_kv_cache_scales(quantization_param_path)
|
self.model.load_kv_cache_scales(quantization_param_path)
|
||||||
|
|
||||||
def set_eagle3_layers_to_capture(self):
|
def set_eagle3_layers_to_capture(self):
|
||||||
|
if not self.pp_group.is_last_rank:
|
||||||
|
return
|
||||||
|
|
||||||
self.capture_aux_hidden_states = True
|
self.capture_aux_hidden_states = True
|
||||||
num_layers = self.config.num_hidden_layers
|
num_layers = self.config.num_hidden_layers
|
||||||
self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
|
self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
|
|||||||
from sglang.srt.layers.rotary_embedding import get_rope
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
|
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
|
||||||
from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
|
from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
|
||||||
|
|
||||||
@@ -431,6 +431,7 @@ class Llama4Model(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
if input_embeds is None:
|
if input_embeds is None:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|||||||
@@ -25,13 +25,14 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
|
from sglang.srt.distributed import get_pp_group
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM
|
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM
|
||||||
|
|
||||||
|
|
||||||
@@ -86,6 +87,7 @@ class LlamaModel(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if input_embeds is None:
|
if input_embeds is None:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
@@ -118,6 +120,7 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
|
|||||||
nn.Module.__init__(self)
|
nn.Module.__init__(self)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
self.pp_group = get_pp_group()
|
||||||
self.model = LlamaModel(
|
self.model = LlamaModel(
|
||||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
|
from sglang.srt.distributed import get_pp_group
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
|
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
|
from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
|
||||||
|
|
||||||
|
|
||||||
@@ -118,6 +119,7 @@ class LlamaModel(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if input_embeds is None:
|
if input_embeds is None:
|
||||||
embeds = self.embed_tokens(input_ids)
|
embeds = self.embed_tokens(input_ids)
|
||||||
@@ -155,6 +157,7 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
|
|||||||
nn.Module.__init__(self)
|
nn.Module.__init__(self)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
self.pp_group = get_pp_group()
|
||||||
|
|
||||||
if self.config.num_hidden_layers != 1:
|
if self.config.num_hidden_layers != 1:
|
||||||
raise ValueError("EAGLE3 currently only supports 1 layer")
|
raise ValueError("EAGLE3 currently only supports 1 layer")
|
||||||
|
|||||||
@@ -78,6 +78,8 @@ class ServerArgs:
|
|||||||
|
|
||||||
# Other runtime options
|
# Other runtime options
|
||||||
tp_size: int = 1
|
tp_size: int = 1
|
||||||
|
pp_size: int = 1
|
||||||
|
max_micro_batch_size: Optional[int] = None
|
||||||
stream_interval: int = 1
|
stream_interval: int = 1
|
||||||
stream_output: bool = False
|
stream_output: bool = False
|
||||||
random_seed: Optional[int] = None
|
random_seed: Optional[int] = None
|
||||||
@@ -222,16 +224,20 @@ class ServerArgs:
|
|||||||
|
|
||||||
# Set mem fraction static, which depends on the tensor parallelism size
|
# Set mem fraction static, which depends on the tensor parallelism size
|
||||||
if self.mem_fraction_static is None:
|
if self.mem_fraction_static is None:
|
||||||
if self.tp_size >= 16:
|
parallel_size = self.tp_size * self.pp_size
|
||||||
|
if gpu_mem <= 81920:
|
||||||
|
if parallel_size >= 16:
|
||||||
self.mem_fraction_static = 0.79
|
self.mem_fraction_static = 0.79
|
||||||
elif self.tp_size >= 8:
|
elif parallel_size >= 8:
|
||||||
self.mem_fraction_static = 0.81
|
self.mem_fraction_static = 0.81
|
||||||
elif self.tp_size >= 4:
|
elif parallel_size >= 4:
|
||||||
self.mem_fraction_static = 0.85
|
self.mem_fraction_static = 0.85
|
||||||
elif self.tp_size >= 2:
|
elif parallel_size >= 2:
|
||||||
self.mem_fraction_static = 0.87
|
self.mem_fraction_static = 0.87
|
||||||
else:
|
else:
|
||||||
self.mem_fraction_static = 0.88
|
self.mem_fraction_static = 0.88
|
||||||
|
else:
|
||||||
|
self.mem_fraction_static = 0.88
|
||||||
if gpu_mem > 96 * 1024:
|
if gpu_mem > 96 * 1024:
|
||||||
mem_fraction = self.mem_fraction_static
|
mem_fraction = self.mem_fraction_static
|
||||||
self.mem_fraction_static = min(
|
self.mem_fraction_static = min(
|
||||||
@@ -244,6 +250,8 @@ class ServerArgs:
|
|||||||
if self.chunked_prefill_size is None:
|
if self.chunked_prefill_size is None:
|
||||||
if gpu_mem is not None and gpu_mem < 25_000:
|
if gpu_mem is not None and gpu_mem < 25_000:
|
||||||
self.chunked_prefill_size = 2048
|
self.chunked_prefill_size = 2048
|
||||||
|
elif self.disaggregation_mode != "null":
|
||||||
|
self.chunked_prefill_size = 16384
|
||||||
else:
|
else:
|
||||||
self.chunked_prefill_size = 8192
|
self.chunked_prefill_size = 8192
|
||||||
assert self.chunked_prefill_size % self.page_size == 0
|
assert self.chunked_prefill_size % self.page_size == 0
|
||||||
@@ -643,6 +651,19 @@ class ServerArgs:
|
|||||||
default=ServerArgs.tp_size,
|
default=ServerArgs.tp_size,
|
||||||
help="The tensor parallelism size.",
|
help="The tensor parallelism size.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pipeline-parallel-size",
|
||||||
|
"--pp-size",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.pp_size,
|
||||||
|
help="The pipeline parallelism size.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-micro-batch-size",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.max_micro_batch_size,
|
||||||
|
help="The maximum micro batch size in pipeline parallelism.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--stream-interval",
|
"--stream-interval",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -1232,6 +1253,7 @@ class ServerArgs:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
args.tp_size = args.tensor_parallel_size
|
args.tp_size = args.tensor_parallel_size
|
||||||
|
args.pp_size = args.pipeline_parallel_size
|
||||||
args.dp_size = args.data_parallel_size
|
args.dp_size = args.data_parallel_size
|
||||||
args.ep_size = args.expert_parallel_size
|
args.ep_size = args.expert_parallel_size
|
||||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||||
@@ -1245,8 +1267,19 @@ class ServerArgs:
|
|||||||
|
|
||||||
def check_server_args(self):
|
def check_server_args(self):
|
||||||
assert (
|
assert (
|
||||||
self.tp_size % self.nnodes == 0
|
self.tp_size * self.pp_size
|
||||||
), "tp_size must be divisible by number of nodes"
|
) % self.nnodes == 0, "tp_size must be divisible by number of nodes"
|
||||||
|
|
||||||
|
# FIXME pp constraints
|
||||||
|
if self.pp_size > 1:
|
||||||
|
logger.warning(f"Turn off overlap scheule for pipeline parallelism.")
|
||||||
|
self.disable_overlap_schedule = True
|
||||||
|
assert (
|
||||||
|
self.disable_overlap_schedule
|
||||||
|
and self.speculative_algorithm is None
|
||||||
|
and not self.enable_mixed_chunk
|
||||||
|
), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."
|
||||||
|
|
||||||
assert not (
|
assert not (
|
||||||
self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
|
self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
|
||||||
), "multi-node data parallel is not supported unless dp attention!"
|
), "multi-node data parallel is not supported unless dp attention!"
|
||||||
|
|||||||
@@ -106,11 +106,12 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
# Init draft worker
|
# Init draft worker
|
||||||
with empty_context():
|
with empty_context():
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
server_args=server_args,
|
||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
server_args=server_args,
|
pp_rank=0, # FIXME
|
||||||
nccl_port=nccl_port,
|
|
||||||
dp_rank=dp_rank,
|
dp_rank=dp_rank,
|
||||||
|
nccl_port=nccl_port,
|
||||||
is_draft_worker=True,
|
is_draft_worker=True,
|
||||||
req_to_token_pool=self.req_to_token_pool,
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Common utilities."""
|
"""Common utilities."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import builtins
|
import builtins
|
||||||
import ctypes
|
import ctypes
|
||||||
@@ -414,16 +415,40 @@ class LayerFn(Protocol):
|
|||||||
def make_layers(
|
def make_layers(
|
||||||
num_hidden_layers: int,
|
num_hidden_layers: int,
|
||||||
layer_fn: LayerFn,
|
layer_fn: LayerFn,
|
||||||
|
pp_rank: Optional[int] = None,
|
||||||
|
pp_size: Optional[int] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
return_tuple: bool = False,
|
||||||
) -> Tuple[int, int, torch.nn.ModuleList]:
|
) -> Tuple[int, int, torch.nn.ModuleList]:
|
||||||
"""Make a list of layers with the given layer function"""
|
"""Make a list of layers with the given layer function"""
|
||||||
|
# circula imports
|
||||||
|
from sglang.srt.distributed import get_pp_indices
|
||||||
|
from sglang.srt.layers.utils import PPMissingLayer
|
||||||
|
|
||||||
|
assert not pp_size or num_hidden_layers >= pp_size
|
||||||
|
start_layer, end_layer = (
|
||||||
|
get_pp_indices(
|
||||||
|
num_hidden_layers,
|
||||||
|
pp_rank,
|
||||||
|
pp_size,
|
||||||
|
)
|
||||||
|
if pp_rank is not None and pp_size is not None
|
||||||
|
else (0, num_hidden_layers)
|
||||||
|
)
|
||||||
modules = torch.nn.ModuleList(
|
modules = torch.nn.ModuleList(
|
||||||
[
|
[PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)]
|
||||||
|
+ [
|
||||||
maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
|
maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
|
||||||
for idx in range(num_hidden_layers)
|
for idx in range(start_layer, end_layer)
|
||||||
|
]
|
||||||
|
+ [
|
||||||
|
PPMissingLayer(return_tuple=return_tuple)
|
||||||
|
for _ in range(end_layer, num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
if pp_rank is None or pp_size is None:
|
||||||
return modules
|
return modules
|
||||||
|
return modules, start_layer, end_layer
|
||||||
|
|
||||||
|
|
||||||
def set_random_seed(seed: int) -> None:
|
def set_random_seed(seed: int) -> None:
|
||||||
@@ -877,7 +902,7 @@ def broadcast_pyobj(
|
|||||||
"cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
|
"cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
if rank == 0:
|
if rank == src:
|
||||||
if len(data) == 0:
|
if len(data) == 0:
|
||||||
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
|
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
|
||||||
dist.broadcast(tensor_size, src=src, group=dist_group)
|
dist.broadcast(tensor_size, src=src, group=dist_group)
|
||||||
@@ -909,6 +934,50 @@ def broadcast_pyobj(
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def point_to_point_pyobj(
|
||||||
|
data: List[Any],
|
||||||
|
rank: int,
|
||||||
|
group: Optional[torch.distributed.ProcessGroup] = None,
|
||||||
|
src: int = 0,
|
||||||
|
dst: int = 1,
|
||||||
|
):
|
||||||
|
"""Send data from src to dst in group."""
|
||||||
|
|
||||||
|
if rank == src:
|
||||||
|
if len(data) == 0:
|
||||||
|
tensor_size = torch.tensor([0], dtype=torch.long)
|
||||||
|
dist.send(tensor_size, dst=dst, group=group)
|
||||||
|
else:
|
||||||
|
serialized_data = pickle.dumps(data)
|
||||||
|
size = len(serialized_data)
|
||||||
|
tensor_data = torch.ByteTensor(
|
||||||
|
np.frombuffer(serialized_data, dtype=np.uint8)
|
||||||
|
)
|
||||||
|
tensor_size = torch.tensor([size], dtype=torch.long)
|
||||||
|
|
||||||
|
dist.send(tensor_size, dst=dst, group=group)
|
||||||
|
dist.send(tensor_data, dst=dst, group=group)
|
||||||
|
return data
|
||||||
|
|
||||||
|
elif rank == dst:
|
||||||
|
tensor_size = torch.tensor([0], dtype=torch.long)
|
||||||
|
dist.recv(tensor_size, src=src, group=group)
|
||||||
|
size = tensor_size.item()
|
||||||
|
|
||||||
|
if size == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
tensor_data = torch.empty(size, dtype=torch.uint8)
|
||||||
|
dist.recv(tensor_data, src=src, group=group)
|
||||||
|
|
||||||
|
serialized_data = bytes(tensor_data.cpu().numpy())
|
||||||
|
data = pickle.loads(serialized_data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
# Other ranks in pp_group do nothing
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
step_counter = 0
|
step_counter = 0
|
||||||
|
|
||||||
|
|
||||||
@@ -1732,6 +1801,13 @@ def configure_ipv6(dist_init_addr):
|
|||||||
return port, host
|
return port, host
|
||||||
|
|
||||||
|
|
||||||
|
def rank0_log(msg: str):
|
||||||
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
|
|
||||||
|
if get_tensor_model_parallel_rank() == 0:
|
||||||
|
logger.info(msg)
|
||||||
|
|
||||||
|
|
||||||
def rank0_print(msg: str):
|
def rank0_print(msg: str):
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
|
|
||||||
|
|||||||
@@ -770,6 +770,34 @@ def run_bench_offline_throughput(model, other_args):
|
|||||||
return output_throughput
|
return output_throughput
|
||||||
|
|
||||||
|
|
||||||
|
def run_bench_one_batch_server(
|
||||||
|
model,
|
||||||
|
base_url,
|
||||||
|
server_args,
|
||||||
|
bench_args,
|
||||||
|
other_server_args,
|
||||||
|
simulate_spec_acc_lens=None,
|
||||||
|
):
|
||||||
|
from sglang.bench_one_batch_server import run_benchmark
|
||||||
|
|
||||||
|
if simulate_spec_acc_lens is not None:
|
||||||
|
env = {**os.environ, "SIMULATE_ACC_LEN": str(simulate_spec_acc_lens)}
|
||||||
|
else:
|
||||||
|
env = None
|
||||||
|
|
||||||
|
process = popen_launch_server(
|
||||||
|
model,
|
||||||
|
base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=other_server_args,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
run_benchmark(server_args=server_args, bench_args=bench_args)
|
||||||
|
finally:
|
||||||
|
kill_process_tree(process.pid)
|
||||||
|
|
||||||
|
|
||||||
def lcs(X, Y):
|
def lcs(X, Y):
|
||||||
m = len(X)
|
m = len(X)
|
||||||
n = len(Y)
|
n = len(Y)
|
||||||
|
|||||||
@@ -96,6 +96,8 @@ suites = {
|
|||||||
"per-commit-8-gpu": [
|
"per-commit-8-gpu": [
|
||||||
TestFile("test_local_attn.py", 250),
|
TestFile("test_local_attn.py", 250),
|
||||||
TestFile("test_full_deepseek_v3.py", 250),
|
TestFile("test_full_deepseek_v3.py", 250),
|
||||||
|
TestFile("test_fa3.py", 30),
|
||||||
|
TestFile("test_pp_single_node.py", 150),
|
||||||
],
|
],
|
||||||
"nightly": [
|
"nightly": [
|
||||||
TestFile("test_nightly_gsm8k_eval.py"),
|
TestFile("test_nightly_gsm8k_eval.py"),
|
||||||
|
|||||||
143
test/srt/test_pp_single_node.py
Normal file
143
test/srt/test_pp_single_node.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
python3 -m unittest test_pp_single_node.TestPPAccuracy.test_gsm8k
|
||||||
|
python3 -m unittest test_pp_single_node.TestFixedBugs.test_chunked_prefill_with_small_bs
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval
|
||||||
|
from sglang.test.runners import DEFAULT_PROMPTS
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
popen_launch_server,
|
||||||
|
run_bench_one_batch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPPAccuracy(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
# These config helps find a leak.
|
||||||
|
os.environ["SGLANG_IS_IN_CI"] = "1"
|
||||||
|
cls.base_url = "http://127.0.0.1:23333"
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--pp-size",
|
||||||
|
4,
|
||||||
|
"--disable-overlap-schedule",
|
||||||
|
"--chunked-prefill-size",
|
||||||
|
256,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
metrics = run_eval(args)
|
||||||
|
print(f"{metrics=}")
|
||||||
|
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.75)
|
||||||
|
# Wait a little bit so that the memory check happens.
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
|
||||||
|
# class TestPPAccuracyFlashInfer(unittest.TestCase):
|
||||||
|
# @classmethod
|
||||||
|
# def setUpClass(cls):
|
||||||
|
# # These config helps find a leak.
|
||||||
|
# os.environ["SGLANG_IS_IN_CI"] = "1"
|
||||||
|
# cls.base_url = "http://127.0.0.1:23333"
|
||||||
|
# cls.process = popen_launch_server(
|
||||||
|
# DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
# cls.base_url,
|
||||||
|
# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
# other_args=[
|
||||||
|
# "--pp-size",
|
||||||
|
# 4,
|
||||||
|
# "--disable-overlap-schedule",
|
||||||
|
# "--attention-backend",
|
||||||
|
# "flashinfer",
|
||||||
|
# "--chunked-prefill-size",
|
||||||
|
# 256,
|
||||||
|
# ],
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# @classmethod
|
||||||
|
# def tearDownClass(cls):
|
||||||
|
# kill_process_tree(cls.process.pid)
|
||||||
|
#
|
||||||
|
# def test_gsm8k(self):
|
||||||
|
# args = SimpleNamespace(
|
||||||
|
# num_shots=5,
|
||||||
|
# data_path=None,
|
||||||
|
# num_questions=200,
|
||||||
|
# max_new_tokens=512,
|
||||||
|
# parallel=128,
|
||||||
|
# host="http://127.0.0.1",
|
||||||
|
# port=int(self.base_url.split(":")[-1]),
|
||||||
|
# )
|
||||||
|
# metrics = run_eval(args)
|
||||||
|
# print(f"{metrics=}")
|
||||||
|
#
|
||||||
|
# self.assertGreater(metrics["accuracy"], 0.75)
|
||||||
|
# # Wait a little bit so that the memory check happens.
|
||||||
|
# time.sleep(5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFixedBugs(unittest.TestCase):
|
||||||
|
def test_chunked_prefill_with_small_bs(self):
|
||||||
|
model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
server_args = ServerArgs(model_path=model)
|
||||||
|
bench_args = OneBatchBenchArgs(
|
||||||
|
batch_size=(1,),
|
||||||
|
input_len=(1,),
|
||||||
|
output_len=(1,),
|
||||||
|
base_url=DEFAULT_URL_FOR_TEST,
|
||||||
|
)
|
||||||
|
other_server_args = [
|
||||||
|
"--tp-size",
|
||||||
|
2,
|
||||||
|
"--pp-size",
|
||||||
|
2,
|
||||||
|
"--disable-overlap-schedule",
|
||||||
|
"--chunked-prefill",
|
||||||
|
256,
|
||||||
|
"--max-running-requests",
|
||||||
|
2,
|
||||||
|
]
|
||||||
|
run_bench_one_batch_server(
|
||||||
|
model,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
server_args,
|
||||||
|
bench_args,
|
||||||
|
other_server_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -147,6 +147,8 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
|
|||||||
gpu_id=0,
|
gpu_id=0,
|
||||||
tp_rank=0,
|
tp_rank=0,
|
||||||
tp_size=1,
|
tp_size=1,
|
||||||
|
pp_rank=0,
|
||||||
|
pp_size=1,
|
||||||
nccl_port=12435,
|
nccl_port=12435,
|
||||||
server_args=ServerArgs(
|
server_args=ServerArgs(
|
||||||
model_path=self.model_path,
|
model_path=self.model_path,
|
||||||
|
|||||||
Reference in New Issue
Block a user