Unify forward output datastructure (#11124)
This commit is contained in:
@@ -22,6 +22,7 @@ from typing import List, Optional, Set, Union
|
|||||||
import torch
|
import torch
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from sglang.srt.environ import envs
|
||||||
from sglang.srt.hf_transformers_utils import (
|
from sglang.srt.hf_transformers_utils import (
|
||||||
get_config,
|
get_config,
|
||||||
get_context_length,
|
get_context_length,
|
||||||
@@ -31,7 +32,7 @@ from sglang.srt.hf_transformers_utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_bool_env_var, is_hip, retry
|
from sglang.srt.utils import is_hip, retry
|
||||||
from sglang.utils import is_in_ci
|
from sglang.utils import is_in_ci
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -237,7 +238,7 @@ class ModelConfig:
|
|||||||
f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
|
f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
|
envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get()
|
||||||
or is_in_ci() # FIXME: fix this special case
|
or is_in_ci() # FIXME: fix this special case
|
||||||
):
|
):
|
||||||
logger.warning(msg)
|
logger.warning(msg)
|
||||||
|
|||||||
@@ -689,7 +689,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
self.running_mbs = [
|
self.running_mbs = [
|
||||||
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
|
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
|
||||||
]
|
]
|
||||||
bids = [None] * self.pp_size
|
|
||||||
pp_outputs: Optional[PPProxyTensors] = None
|
pp_outputs: Optional[PPProxyTensors] = None
|
||||||
|
|
||||||
# Either success or failed
|
# Either success or failed
|
||||||
@@ -761,10 +760,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
# send the outputs to the next step
|
# send the outputs to the next step
|
||||||
if self.pp_group.is_last_rank:
|
if self.pp_group.is_last_rank:
|
||||||
if self.cur_batch:
|
if self.cur_batch:
|
||||||
next_token_ids, bids[mb_id] = (
|
next_token_ids = result.next_token_ids
|
||||||
result.next_token_ids,
|
|
||||||
result.bid,
|
|
||||||
)
|
|
||||||
pp_outputs = PPProxyTensors(
|
pp_outputs = PPProxyTensors(
|
||||||
{
|
{
|
||||||
"next_token_ids": next_token_ids,
|
"next_token_ids": next_token_ids,
|
||||||
@@ -801,7 +797,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
next_token_ids=next_pp_outputs["next_token_ids"],
|
next_token_ids=next_pp_outputs["next_token_ids"],
|
||||||
extend_input_len_per_req=None,
|
extend_input_len_per_req=None,
|
||||||
extend_logprob_start_len_per_req=None,
|
extend_logprob_start_len_per_req=None,
|
||||||
bid=bids[next_mb_id],
|
|
||||||
can_run_cuda_graph=result.can_run_cuda_graph,
|
can_run_cuda_graph=result.can_run_cuda_graph,
|
||||||
)
|
)
|
||||||
self.process_batch_result_disagg_prefill(
|
self.process_batch_result_disagg_prefill(
|
||||||
@@ -818,8 +813,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
|
|
||||||
# carry the outputs to the next stage
|
# carry the outputs to the next stage
|
||||||
if not self.pp_group.is_last_rank:
|
if not self.pp_group.is_last_rank:
|
||||||
if self.cur_batch:
|
|
||||||
bids[mb_id] = result.bid
|
|
||||||
if pp_outputs:
|
if pp_outputs:
|
||||||
# send the outputs from the last round to let the next stage worker run post processing
|
# send the outputs from the last round to let the next stage worker run post processing
|
||||||
self.pp_group.send_tensor_dict(
|
self.pp_group.send_tensor_dict(
|
||||||
@@ -838,8 +831,10 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
|
|
||||||
# send out proxy tensors to the next stage
|
# send out proxy tensors to the next stage
|
||||||
if self.cur_batch:
|
if self.cur_batch:
|
||||||
|
# FIXME(lsyin): remove this assert
|
||||||
|
assert result.pp_hidden_states_proxy_tensors.tensors is not None
|
||||||
self.pp_group.send_tensor_dict(
|
self.pp_group.send_tensor_dict(
|
||||||
result.pp_hidden_states_proxy_tensors,
|
result.pp_hidden_states_proxy_tensors.tensors,
|
||||||
all_gather_group=self.attn_tp_group,
|
all_gather_group=self.attn_tp_group,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -860,10 +860,6 @@ class Req:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Batch id
|
|
||||||
bid = 0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||||
"""Store all information of a batch on the scheduler."""
|
"""Store all information of a batch on the scheduler."""
|
||||||
@@ -1829,10 +1825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
|
seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
|
||||||
)
|
)
|
||||||
|
|
||||||
global bid
|
|
||||||
bid += 1
|
|
||||||
return ModelWorkerBatch(
|
return ModelWorkerBatch(
|
||||||
bid=bid,
|
|
||||||
forward_mode=self.forward_mode,
|
forward_mode=self.forward_mode,
|
||||||
input_ids=self.input_ids,
|
input_ids=self.input_ids,
|
||||||
req_pool_indices=self.req_pool_indices,
|
req_pool_indices=self.req_pool_indices,
|
||||||
@@ -1952,8 +1945,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ModelWorkerBatch:
|
class ModelWorkerBatch:
|
||||||
# The batch id
|
|
||||||
bid: int
|
|
||||||
# The forward mode
|
# The forward mode
|
||||||
forward_mode: ForwardMode
|
forward_mode: ForwardMode
|
||||||
# The input ids
|
# The input ids
|
||||||
|
|||||||
@@ -150,7 +150,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
|||||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
ForwardBatchOutput,
|
||||||
|
ForwardMode,
|
||||||
|
PPProxyTensors,
|
||||||
|
)
|
||||||
from sglang.srt.parser.reasoning_parser import ReasoningParser
|
from sglang.srt.parser.reasoning_parser import ReasoningParser
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
@@ -175,7 +179,6 @@ from sglang.srt.utils import (
|
|||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_int_env_var,
|
get_int_env_var,
|
||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
is_cpu,
|
|
||||||
kill_itself_when_parent_died,
|
kill_itself_when_parent_died,
|
||||||
numa_bind_to_node,
|
numa_bind_to_node,
|
||||||
point_to_point_pyobj,
|
point_to_point_pyobj,
|
||||||
@@ -194,24 +197,59 @@ logger = logging.getLogger(__name__)
|
|||||||
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
|
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
|
||||||
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
|
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
|
||||||
|
|
||||||
_is_cpu = is_cpu()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenerationBatchResult:
|
class GenerationBatchResult:
|
||||||
logits_output: Optional[LogitsProcessorOutput]
|
logits_output: Optional[LogitsProcessorOutput]
|
||||||
pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
|
pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
|
||||||
next_token_ids: Optional[List[int]]
|
next_token_ids: Optional[List[int]]
|
||||||
|
can_run_cuda_graph: bool
|
||||||
|
|
||||||
|
# For output processing
|
||||||
extend_input_len_per_req: List[int]
|
extend_input_len_per_req: List[int]
|
||||||
extend_logprob_start_len_per_req: List[int]
|
extend_logprob_start_len_per_req: List[int]
|
||||||
bid: int
|
|
||||||
can_run_cuda_graph: bool
|
@classmethod
|
||||||
|
def from_forward_batch_output(
|
||||||
|
cls,
|
||||||
|
forward_batch_output: ForwardBatchOutput,
|
||||||
|
extend_input_len_per_req: List[int],
|
||||||
|
extend_logprob_start_len_per_req: List[int],
|
||||||
|
):
|
||||||
|
# TODO(lsyin): remove this workaround logic and try to unify output classes
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
logits_output=forward_batch_output.logits_output,
|
||||||
|
pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
|
||||||
|
next_token_ids=forward_batch_output.next_token_ids,
|
||||||
|
extend_input_len_per_req=extend_input_len_per_req,
|
||||||
|
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
||||||
|
can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pp_proxy(
|
||||||
|
cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
|
||||||
|
):
|
||||||
|
# TODO(lsyin): also simplify this logic
|
||||||
|
# Current PP implementation in scheduler is not compatible with ForwardBatchOutput
|
||||||
|
# Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
|
||||||
|
proxy_dict = next_pp_outputs.tensors
|
||||||
|
return cls(
|
||||||
|
logits_output=logits_output,
|
||||||
|
pp_hidden_states_proxy_tensors=None,
|
||||||
|
next_token_ids=next_pp_outputs["next_token_ids"],
|
||||||
|
extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
|
||||||
|
extend_logprob_start_len_per_req=proxy_dict.get(
|
||||||
|
"extend_logprob_start_len_per_req", None
|
||||||
|
),
|
||||||
|
can_run_cuda_graph=can_run_cuda_graph,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingBatchResult:
|
class EmbeddingBatchResult:
|
||||||
embeddings: torch.Tensor
|
embeddings: torch.Tensor
|
||||||
bid: int
|
|
||||||
|
|
||||||
|
|
||||||
class Scheduler(
|
class Scheduler(
|
||||||
@@ -403,6 +441,12 @@ class Scheduler(
|
|||||||
else:
|
else:
|
||||||
self.draft_worker = None
|
self.draft_worker = None
|
||||||
|
|
||||||
|
# Dispatch the model worker
|
||||||
|
if self.spec_algorithm.is_none():
|
||||||
|
self.model_worker = self.tp_worker
|
||||||
|
else:
|
||||||
|
self.model_worker = self.draft_worker
|
||||||
|
|
||||||
# Get token and memory info from the model worker
|
# Get token and memory info from the model worker
|
||||||
(
|
(
|
||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
@@ -959,7 +1003,6 @@ class Scheduler(
|
|||||||
self.running_mbs = [
|
self.running_mbs = [
|
||||||
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
|
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
|
||||||
]
|
]
|
||||||
bids = [None] * self.pp_size
|
|
||||||
pp_outputs: Optional[PPProxyTensors] = None
|
pp_outputs: Optional[PPProxyTensors] = None
|
||||||
while True:
|
while True:
|
||||||
server_is_idle = True
|
server_is_idle = True
|
||||||
@@ -980,10 +1023,7 @@ class Scheduler(
|
|||||||
# (last rank) send the outputs to the next step
|
# (last rank) send the outputs to the next step
|
||||||
if self.pp_group.is_last_rank:
|
if self.pp_group.is_last_rank:
|
||||||
if self.cur_batch:
|
if self.cur_batch:
|
||||||
next_token_ids, bids[mb_id] = (
|
next_token_ids = result.next_token_ids
|
||||||
result.next_token_ids,
|
|
||||||
result.bid,
|
|
||||||
)
|
|
||||||
if self.cur_batch.return_logprob:
|
if self.cur_batch.return_logprob:
|
||||||
pp_outputs = PPProxyTensors(
|
pp_outputs = PPProxyTensors(
|
||||||
{
|
{
|
||||||
@@ -1031,17 +1071,10 @@ class Scheduler(
|
|||||||
logits_output = LogitsProcessorOutput(**logits_output_args)
|
logits_output = LogitsProcessorOutput(**logits_output_args)
|
||||||
else:
|
else:
|
||||||
logits_output = None
|
logits_output = None
|
||||||
output_result = GenerationBatchResult(
|
|
||||||
|
output_result = GenerationBatchResult.from_pp_proxy(
|
||||||
logits_output=logits_output,
|
logits_output=logits_output,
|
||||||
pp_hidden_states_proxy_tensors=None,
|
next_pp_outputs=next_pp_outputs,
|
||||||
next_token_ids=next_pp_outputs["next_token_ids"],
|
|
||||||
extend_input_len_per_req=next_pp_outputs.tensors.get(
|
|
||||||
"extend_input_len_per_req", None
|
|
||||||
),
|
|
||||||
extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
|
|
||||||
"extend_logprob_start_len_per_req", None
|
|
||||||
),
|
|
||||||
bid=bids[next_mb_id],
|
|
||||||
can_run_cuda_graph=result.can_run_cuda_graph,
|
can_run_cuda_graph=result.can_run_cuda_graph,
|
||||||
)
|
)
|
||||||
self.process_batch_result(mbs[next_mb_id], output_result)
|
self.process_batch_result(mbs[next_mb_id], output_result)
|
||||||
@@ -1049,8 +1082,6 @@ class Scheduler(
|
|||||||
|
|
||||||
# (not last rank)
|
# (not last rank)
|
||||||
if not self.pp_group.is_last_rank:
|
if not self.pp_group.is_last_rank:
|
||||||
if self.cur_batch:
|
|
||||||
bids[mb_id] = result.bid
|
|
||||||
# carry the outputs to the next stage
|
# carry the outputs to the next stage
|
||||||
# send the outputs from the last round to let the next stage worker run post processing
|
# send the outputs from the last round to let the next stage worker run post processing
|
||||||
if pp_outputs:
|
if pp_outputs:
|
||||||
@@ -1072,8 +1103,10 @@ class Scheduler(
|
|||||||
|
|
||||||
# send out proxy tensors to the next stage
|
# send out proxy tensors to the next stage
|
||||||
if self.cur_batch:
|
if self.cur_batch:
|
||||||
|
# FIXME(lsyin): remove this assert
|
||||||
|
assert result.pp_hidden_states_proxy_tensors.tensors is not None
|
||||||
self.pp_group.send_tensor_dict(
|
self.pp_group.send_tensor_dict(
|
||||||
result.pp_hidden_states_proxy_tensors,
|
result.pp_hidden_states_proxy_tensors.tensors,
|
||||||
all_gather_group=self.attn_tp_group,
|
all_gather_group=self.attn_tp_group,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2016,33 +2049,25 @@ class Scheduler(
|
|||||||
|
|
||||||
# Run forward
|
# Run forward
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
|
|
||||||
|
batch_or_worker_batch = batch
|
||||||
|
|
||||||
if self.spec_algorithm.is_none():
|
if self.spec_algorithm.is_none():
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
# FIXME(lsyin): remove this if and finally unify the abstraction
|
||||||
|
batch_or_worker_batch = batch.get_model_worker_batch()
|
||||||
|
|
||||||
if self.pp_group.is_last_rank:
|
forward_batch_output = self.model_worker.forward_batch_generation(
|
||||||
logits_output, next_token_ids, can_run_cuda_graph = (
|
batch_or_worker_batch
|
||||||
self.tp_worker.forward_batch_generation(model_worker_batch)
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
|
|
||||||
self.tp_worker.forward_batch_generation(model_worker_batch)
|
|
||||||
)
|
|
||||||
bid = model_worker_batch.bid
|
|
||||||
else:
|
|
||||||
(
|
|
||||||
logits_output,
|
|
||||||
next_token_ids,
|
|
||||||
bid,
|
|
||||||
num_accepted_tokens,
|
|
||||||
can_run_cuda_graph,
|
|
||||||
) = self.draft_worker.forward_batch_speculative_generation(batch)
|
|
||||||
bs = batch.batch_size()
|
|
||||||
self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
|
|
||||||
self.spec_num_total_forward_ct += bs
|
|
||||||
self.num_generated_tokens += num_accepted_tokens
|
|
||||||
|
|
||||||
if self.pp_group.is_last_rank:
|
if not self.spec_algorithm.is_none():
|
||||||
batch.output_ids = next_token_ids
|
# TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
|
||||||
|
self.udpate_spec_metrics(
|
||||||
|
batch.batch_size(), forward_batch_output.num_accepted_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# update batch's output ids
|
||||||
|
batch.output_ids = forward_batch_output.next_token_ids
|
||||||
|
|
||||||
# These 2 values are needed for processing the output, but the values can be
|
# These 2 values are needed for processing the output, but the values can be
|
||||||
# modified by overlap schedule. So we have to copy them here so that
|
# modified by overlap schedule. So we have to copy them here so that
|
||||||
@@ -2051,6 +2076,7 @@ class Scheduler(
|
|||||||
extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
|
extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
|
||||||
else:
|
else:
|
||||||
extend_input_len_per_req = None
|
extend_input_len_per_req = None
|
||||||
|
|
||||||
if batch.return_logprob:
|
if batch.return_logprob:
|
||||||
extend_logprob_start_len_per_req = [
|
extend_logprob_start_len_per_req = [
|
||||||
req.extend_logprob_start_len for req in batch.reqs
|
req.extend_logprob_start_len for req in batch.reqs
|
||||||
@@ -2058,25 +2084,15 @@ class Scheduler(
|
|||||||
else:
|
else:
|
||||||
extend_logprob_start_len_per_req = None
|
extend_logprob_start_len_per_req = None
|
||||||
|
|
||||||
ret = GenerationBatchResult(
|
return GenerationBatchResult.from_forward_batch_output(
|
||||||
logits_output=logits_output if self.pp_group.is_last_rank else None,
|
forward_batch_output=forward_batch_output,
|
||||||
pp_hidden_states_proxy_tensors=(
|
|
||||||
pp_hidden_states_proxy_tensors
|
|
||||||
if not self.pp_group.is_last_rank
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
|
|
||||||
extend_input_len_per_req=extend_input_len_per_req,
|
extend_input_len_per_req=extend_input_len_per_req,
|
||||||
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
||||||
bid=bid,
|
|
||||||
can_run_cuda_graph=can_run_cuda_graph,
|
|
||||||
)
|
)
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
||||||
ret = EmbeddingBatchResult(
|
ret = EmbeddingBatchResult(embeddings=embeddings)
|
||||||
embeddings=embeddings, bid=model_worker_batch.bid
|
|
||||||
)
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def process_batch_result(
|
def process_batch_result(
|
||||||
|
|||||||
@@ -80,6 +80,11 @@ class SchedulerMetricsMixin:
|
|||||||
kv_events_config, self.attn_dp_rank
|
kv_events_config, self.attn_dp_rank
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int):
|
||||||
|
self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
|
||||||
|
self.spec_num_total_forward_ct += bs
|
||||||
|
self.num_generated_tokens += num_accepted_tokens
|
||||||
|
|
||||||
def log_prefill_stats(
|
def log_prefill_stats(
|
||||||
self: Scheduler,
|
self: Scheduler,
|
||||||
adder: PrefillAdder,
|
adder: PrefillAdder,
|
||||||
|
|||||||
@@ -173,8 +173,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
self.set_next_batch_sampling_info_done(batch)
|
self.set_next_batch_sampling_info_done(batch)
|
||||||
|
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
embeddings, bid = result.embeddings, result.bid
|
embeddings = result.embeddings.tolist()
|
||||||
embeddings = embeddings.tolist()
|
|
||||||
|
|
||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
|
|||||||
@@ -43,7 +43,11 @@ from sglang.srt.managers.io_struct import (
|
|||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
||||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardBatchOutput,
|
||||||
|
PPProxyTensors,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -234,9 +238,7 @@ class TpModelWorker:
|
|||||||
model_worker_batch: ModelWorkerBatch,
|
model_worker_batch: ModelWorkerBatch,
|
||||||
launch_done: Optional[threading.Event] = None,
|
launch_done: Optional[threading.Event] = None,
|
||||||
skip_sample: bool = False,
|
skip_sample: bool = False,
|
||||||
) -> Tuple[
|
) -> ForwardBatchOutput:
|
||||||
Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
|
|
||||||
]:
|
|
||||||
# update the consumer index of hicache to the running batch
|
# update the consumer index of hicache to the running batch
|
||||||
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
|
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
|
||||||
|
|
||||||
@@ -271,13 +273,20 @@ class TpModelWorker:
|
|||||||
else:
|
else:
|
||||||
next_token_ids = self.model_runner.sample(logits_output, forward_batch)
|
next_token_ids = self.model_runner.sample(logits_output, forward_batch)
|
||||||
|
|
||||||
return logits_output, next_token_ids, can_run_cuda_graph
|
return ForwardBatchOutput(
|
||||||
|
logits_output=logits_output,
|
||||||
|
next_token_ids=next_token_ids,
|
||||||
|
can_run_cuda_graph=can_run_cuda_graph,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
|
pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
|
||||||
forward_batch,
|
forward_batch,
|
||||||
pp_proxy_tensors=pp_proxy_tensors,
|
pp_proxy_tensors=pp_proxy_tensors,
|
||||||
)
|
)
|
||||||
return pp_proxy_tensors.tensors, None, can_run_cuda_graph
|
return ForwardBatchOutput(
|
||||||
|
pp_proxy_tensors=pp_proxy_tensors,
|
||||||
|
can_run_cuda_graph=can_run_cuda_graph,
|
||||||
|
)
|
||||||
|
|
||||||
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
from sglang.srt.managers.overlap_utils import FutureMap
|
from sglang.srt.managers.overlap_utils import FutureMap
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import DynamicGradMode
|
from sglang.srt.utils import DynamicGradMode
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
@@ -160,13 +161,17 @@ class TpModelWorkerClient:
|
|||||||
self.future_map.resolve_future(model_worker_batch)
|
self.future_map.resolve_future(model_worker_batch)
|
||||||
|
|
||||||
# Run forward
|
# Run forward
|
||||||
|
forward_batch_output = self.worker.forward_batch_generation(
|
||||||
|
model_worker_batch,
|
||||||
|
model_worker_batch.launch_done,
|
||||||
|
# Skip sampling for prefill-only requests
|
||||||
|
skip_sample=model_worker_batch.is_prefill_only,
|
||||||
|
)
|
||||||
|
|
||||||
logits_output, next_token_ids, can_run_cuda_graph = (
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
||||||
self.worker.forward_batch_generation(
|
forward_batch_output.logits_output,
|
||||||
model_worker_batch,
|
forward_batch_output.next_token_ids,
|
||||||
model_worker_batch.launch_done,
|
forward_batch_output.can_run_cuda_graph,
|
||||||
# Skip sampling for prefill-only requests
|
|
||||||
skip_sample=model_worker_batch.is_prefill_only,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the future token ids map
|
# Update the future token ids map
|
||||||
@@ -227,7 +232,7 @@ class TpModelWorkerClient:
|
|||||||
|
|
||||||
def forward_batch_generation(
|
def forward_batch_generation(
|
||||||
self, model_worker_batch: ModelWorkerBatch
|
self, model_worker_batch: ModelWorkerBatch
|
||||||
) -> Tuple[None, torch.Tensor, bool]:
|
) -> ForwardBatchOutput:
|
||||||
# Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
|
# Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
|
||||||
sampling_info = model_worker_batch.sampling_info
|
sampling_info = model_worker_batch.sampling_info
|
||||||
sampling_info.update_penalties()
|
sampling_info.update_penalties()
|
||||||
@@ -250,7 +255,10 @@ class TpModelWorkerClient:
|
|||||||
future_next_token_ids = self.future_map.update_next_future(
|
future_next_token_ids = self.future_map.update_next_future(
|
||||||
cur_future_map_ct, bs
|
cur_future_map_ct, bs
|
||||||
)
|
)
|
||||||
return None, future_next_token_ids, False
|
return ForwardBatchOutput(
|
||||||
|
next_token_ids=future_next_token_ids,
|
||||||
|
can_run_cuda_graph=False,
|
||||||
|
)
|
||||||
|
|
||||||
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
||||||
success, message = self.worker.update_weights_from_disk(recv_req)
|
success, message = self.worker.update_weights_from_disk(recv_req)
|
||||||
|
|||||||
@@ -2,11 +2,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
from http import HTTPStatus
|
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@@ -900,6 +900,17 @@ class ForwardBatch:
|
|||||||
return self.tbo_split_seq_index is not None
|
return self.tbo_split_seq_index is not None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ForwardBatchOutput:
|
||||||
|
# FIXME(lsyin): unify the forward batch output between different spec and parallelism
|
||||||
|
# need to be more organized
|
||||||
|
logits_output: Optional[torch.Tensor] = None
|
||||||
|
next_token_ids: Optional[torch.Tensor] = None
|
||||||
|
num_accepted_tokens: Optional[int] = None
|
||||||
|
pp_proxy_tensors: Optional[PPProxyTensors] = None
|
||||||
|
can_run_cuda_graph: bool = False
|
||||||
|
|
||||||
|
|
||||||
def enable_num_token_non_padded(server_args):
|
def enable_num_token_non_padded(server_args):
|
||||||
return get_moe_expert_parallel_world_size() > 1
|
return get_moe_expert_parallel_world_size() > 1
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from sglang.srt.distributed import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
||||||
from sglang.srt.managers.mm_utils import embed_mm_inputs
|
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import (
|
||||||
ScheduleBatch,
|
ScheduleBatch,
|
||||||
get_last_loc,
|
get_last_loc,
|
||||||
@@ -24,6 +23,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
|
|||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
CaptureHiddenMode,
|
CaptureHiddenMode,
|
||||||
ForwardBatch,
|
ForwardBatch,
|
||||||
|
ForwardBatchOutput,
|
||||||
ForwardMode,
|
ForwardMode,
|
||||||
)
|
)
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -422,9 +422,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
def draft_model_runner(self):
|
def draft_model_runner(self):
|
||||||
return self.model_runner
|
return self.model_runner
|
||||||
|
|
||||||
def forward_batch_speculative_generation(
|
def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
|
||||||
self, batch: ScheduleBatch
|
|
||||||
) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
|
|
||||||
"""Run speculative decoding forward.
|
"""Run speculative decoding forward.
|
||||||
|
|
||||||
NOTE: Many states of batch is modified as you go through. It is not guaranteed that
|
NOTE: Many states of batch is modified as you go through. It is not guaranteed that
|
||||||
@@ -437,14 +435,19 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
the batch id (used for overlap schedule), and number of accepted tokens.
|
the batch id (used for overlap schedule), and number of accepted tokens.
|
||||||
"""
|
"""
|
||||||
if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
|
if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
|
||||||
logits_output, next_token_ids, bid, seq_lens_cpu = (
|
logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
|
||||||
self.forward_target_extend(batch)
|
batch
|
||||||
)
|
)
|
||||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||||
self.forward_draft_extend(
|
self.forward_draft_extend(
|
||||||
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
|
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
|
||||||
)
|
)
|
||||||
return logits_output, next_token_ids, bid, 0, False
|
return ForwardBatchOutput(
|
||||||
|
logits_output=logits_output,
|
||||||
|
next_token_ids=next_token_ids,
|
||||||
|
num_accepted_tokens=0,
|
||||||
|
can_run_cuda_graph=False,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||||
spec_info = self.draft(batch)
|
spec_info = self.draft(batch)
|
||||||
@@ -462,12 +465,11 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
# decode is not finished
|
# decode is not finished
|
||||||
self.forward_draft_extend_after_decode(batch)
|
self.forward_draft_extend_after_decode(batch)
|
||||||
|
|
||||||
return (
|
return ForwardBatchOutput(
|
||||||
logits_output,
|
logits_output=logits_output,
|
||||||
verify_output.verified_id,
|
next_token_ids=verify_output.verified_id,
|
||||||
model_worker_batch.bid,
|
num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
|
||||||
sum(verify_output.accept_length_per_req_cpu),
|
can_run_cuda_graph=can_run_cuda_graph,
|
||||||
can_run_cuda_graph,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||||
@@ -499,19 +501,21 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
Returns:
|
Returns:
|
||||||
logits_output: The output of logits. It will contain the full hidden states.
|
logits_output: The output of logits. It will contain the full hidden states.
|
||||||
next_token_ids: Next token ids generated.
|
next_token_ids: Next token ids generated.
|
||||||
bid: The model batch ID. Used for overlap schedule.
|
|
||||||
"""
|
"""
|
||||||
# Forward with the target model and get hidden states.
|
# Forward with the target model and get hidden states.
|
||||||
# We need the full hidden states to prefill the KV cache of the draft model.
|
# We need the full hidden states to prefill the KV cache of the draft model.
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
|
forward_batch_output = self.target_worker.forward_batch_generation(
|
||||||
model_worker_batch
|
model_worker_batch
|
||||||
)
|
)
|
||||||
|
logits_output, next_token_ids = (
|
||||||
|
forward_batch_output.logits_output,
|
||||||
|
forward_batch_output.next_token_ids,
|
||||||
|
)
|
||||||
return (
|
return (
|
||||||
logits_output,
|
logits_output,
|
||||||
next_token_ids,
|
next_token_ids,
|
||||||
model_worker_batch.bid,
|
|
||||||
model_worker_batch.seq_lens_cpu,
|
model_worker_batch.seq_lens_cpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -811,10 +815,12 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
).cpu()
|
).cpu()
|
||||||
|
|
||||||
# Forward
|
# Forward
|
||||||
logits_output, _, can_run_cuda_graph = (
|
forward_batch_output = self.target_worker.forward_batch_generation(
|
||||||
self.target_worker.forward_batch_generation(
|
model_worker_batch, skip_sample=True
|
||||||
model_worker_batch, skip_sample=True
|
)
|
||||||
)
|
logits_output, can_run_cuda_graph = (
|
||||||
|
forward_batch_output.logits_output,
|
||||||
|
forward_batch_output.can_run_cuda_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
vocab_mask = None
|
vocab_mask = None
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
|
|||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, ForwardMode
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
|
from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
|
||||||
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
||||||
@@ -207,17 +207,18 @@ class NGRAMWorker:
|
|||||||
batch_tokens.append(put_ids)
|
batch_tokens.append(put_ids)
|
||||||
self.ngram_cache.batch_put(batch_tokens)
|
self.ngram_cache.batch_put(batch_tokens)
|
||||||
|
|
||||||
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
|
def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
|
||||||
self._prepare_for_speculative_decoding(batch)
|
self._prepare_for_speculative_decoding(batch)
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
bid = model_worker_batch.bid
|
|
||||||
num_accepted_tokens = 0
|
num_accepted_tokens = 0
|
||||||
|
|
||||||
if model_worker_batch.forward_mode.is_target_verify():
|
if model_worker_batch.forward_mode.is_target_verify():
|
||||||
logits_output, _, can_run_cuda_graph = (
|
forward_batch_output = self.target_worker.forward_batch_generation(
|
||||||
self.target_worker.forward_batch_generation(
|
model_worker_batch, skip_sample=True
|
||||||
model_worker_batch, skip_sample=True
|
)
|
||||||
)
|
logits_output, can_run_cuda_graph = (
|
||||||
|
forward_batch_output.logits_output,
|
||||||
|
forward_batch_output.can_run_cuda_graph,
|
||||||
)
|
)
|
||||||
verify_input = model_worker_batch.spec_info
|
verify_input = model_worker_batch.spec_info
|
||||||
logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
|
logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
|
||||||
@@ -227,14 +228,18 @@ class NGRAMWorker:
|
|||||||
batch.forward_mode = ForwardMode.DECODE
|
batch.forward_mode = ForwardMode.DECODE
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
forward_batch_output = self.target_worker.forward_batch_generation(
|
||||||
|
model_worker_batch
|
||||||
|
)
|
||||||
logits_output, next_token_ids, can_run_cuda_graph = (
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
||||||
self.target_worker.forward_batch_generation(model_worker_batch)
|
forward_batch_output.logits_output,
|
||||||
|
forward_batch_output.next_token_ids,
|
||||||
|
forward_batch_output.can_run_cuda_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return ForwardBatchOutput(
|
||||||
logits_output,
|
logits_output=logits_output,
|
||||||
next_token_ids,
|
next_token_ids=next_token_ids,
|
||||||
bid,
|
num_accepted_tokens=num_accepted_tokens,
|
||||||
num_accepted_tokens,
|
can_run_cuda_graph=can_run_cuda_graph,
|
||||||
can_run_cuda_graph,
|
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user