Unify forward output datastructure (#11124)
@@ -22,6 +22,7 @@ from typing import List, Optional, Set, Union

 import torch
 from transformers import PretrainedConfig

+from sglang.srt.environ import envs
 from sglang.srt.hf_transformers_utils import (
     get_config,
     get_context_length,
@@ -31,7 +32,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, is_hip, retry
+from sglang.srt.utils import is_hip, retry
 from sglang.utils import is_in_ci

 logger = logging.getLogger(__name__)
@@ -237,7 +238,7 @@ class ModelConfig:
                 f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
             )
             if (
-                get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
+                envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get()
                 or is_in_ci()  # FIXME: fix this special case
             ):
                 logger.warning(msg)
@@ -689,7 +689,6 @@ class SchedulerDisaggregationPrefillMixin:
         self.running_mbs = [
             ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
         ]
-        bids = [None] * self.pp_size
         pp_outputs: Optional[PPProxyTensors] = None

         # Either success or failed
@@ -761,10 +760,7 @@ class SchedulerDisaggregationPrefillMixin:
             # send the outputs to the next step
             if self.pp_group.is_last_rank:
                 if self.cur_batch:
-                    next_token_ids, bids[mb_id] = (
-                        result.next_token_ids,
-                        result.bid,
-                    )
+                    next_token_ids = result.next_token_ids
                     pp_outputs = PPProxyTensors(
                         {
                             "next_token_ids": next_token_ids,
@@ -801,7 +797,6 @@ class SchedulerDisaggregationPrefillMixin:
                     next_token_ids=next_pp_outputs["next_token_ids"],
                     extend_input_len_per_req=None,
                     extend_logprob_start_len_per_req=None,
-                    bid=bids[next_mb_id],
                     can_run_cuda_graph=result.can_run_cuda_graph,
                 )
                 self.process_batch_result_disagg_prefill(
@@ -818,8 +813,6 @@ class SchedulerDisaggregationPrefillMixin:

             # carry the outputs to the next stage
             if not self.pp_group.is_last_rank:
-                if self.cur_batch:
-                    bids[mb_id] = result.bid
                 if pp_outputs:
                     # send the outputs from the last round to let the next stage worker run post processing
                     self.pp_group.send_tensor_dict(
@@ -838,8 +831,10 @@ class SchedulerDisaggregationPrefillMixin:

                 # send out proxy tensors to the next stage
                 if self.cur_batch:
+                    # FIXME(lsyin): remove this assert
+                    assert result.pp_hidden_states_proxy_tensors.tensors is not None
                     self.pp_group.send_tensor_dict(
-                        result.pp_hidden_states_proxy_tensors,
+                        result.pp_hidden_states_proxy_tensors.tensors,
                         all_gather_group=self.attn_tp_group,
                     )

@@ -860,10 +860,6 @@ class Req:
         )


-# Batch id
-bid = 0
-
-
 @dataclasses.dataclass
 class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     """Store all information of a batch on the scheduler."""
@@ -1829,10 +1825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
         )

-        global bid
-        bid += 1
         return ModelWorkerBatch(
-            bid=bid,
             forward_mode=self.forward_mode,
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
@@ -1952,8 +1945,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

 @dataclasses.dataclass
 class ModelWorkerBatch:
-    # The batch id
-    bid: int
     # The forward mode
     forward_mode: ForwardMode
     # The input ids
@@ -150,7 +150,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatchOutput,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -175,7 +179,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_int_env_var,
     get_zmq_socket,
-    is_cpu,
     kill_itself_when_parent_died,
     numa_bind_to_node,
     point_to_point_pyobj,
@@ -194,24 +197,59 @@ logger = logging.getLogger(__name__)
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

-_is_cpu = is_cpu()
-

 @dataclass
 class GenerationBatchResult:
     logits_output: Optional[LogitsProcessorOutput]
-    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
     next_token_ids: Optional[List[int]]
+    can_run_cuda_graph: bool
+
+    # For output processing
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
-    bid: int
-    can_run_cuda_graph: bool
+
+    @classmethod
+    def from_forward_batch_output(
+        cls,
+        forward_batch_output: ForwardBatchOutput,
+        extend_input_len_per_req: List[int],
+        extend_logprob_start_len_per_req: List[int],
+    ):
+        # TODO(lsyin): remove this workaround logic and try to unify output classes
+
+        return cls(
+            logits_output=forward_batch_output.logits_output,
+            pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
+            next_token_ids=forward_batch_output.next_token_ids,
+            extend_input_len_per_req=extend_input_len_per_req,
+            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+            can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
+        )
+
+    @classmethod
+    def from_pp_proxy(
+        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
+    ):
+        # TODO(lsyin): also simplify this logic
+        # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
+        # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
+        proxy_dict = next_pp_outputs.tensors
+        return cls(
+            logits_output=logits_output,
+            pp_hidden_states_proxy_tensors=None,
+            next_token_ids=next_pp_outputs["next_token_ids"],
+            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
+            extend_logprob_start_len_per_req=proxy_dict.get(
+                "extend_logprob_start_len_per_req", None
+            ),
+            can_run_cuda_graph=can_run_cuda_graph,
+        )


 @dataclass
 class EmbeddingBatchResult:
     embeddings: torch.Tensor
-    bid: int


 class Scheduler(
@@ -403,6 +441,12 @@ class Scheduler(
         else:
             self.draft_worker = None

+        # Dispatch the model worker
+        if self.spec_algorithm.is_none():
+            self.model_worker = self.tp_worker
+        else:
+            self.model_worker = self.draft_worker
+
         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
@@ -959,7 +1003,6 @@ class Scheduler(
         self.running_mbs = [
             ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
         ]
-        bids = [None] * self.pp_size
         pp_outputs: Optional[PPProxyTensors] = None
         while True:
             server_is_idle = True
@@ -980,10 +1023,7 @@ class Scheduler(
             # (last rank) send the outputs to the next step
             if self.pp_group.is_last_rank:
                 if self.cur_batch:
-                    next_token_ids, bids[mb_id] = (
-                        result.next_token_ids,
-                        result.bid,
-                    )
+                    next_token_ids = result.next_token_ids
                     if self.cur_batch.return_logprob:
                         pp_outputs = PPProxyTensors(
                             {
@@ -1031,17 +1071,10 @@ class Scheduler(
                     logits_output = LogitsProcessorOutput(**logits_output_args)
                 else:
                     logits_output = None
-                output_result = GenerationBatchResult(
+
+                output_result = GenerationBatchResult.from_pp_proxy(
                     logits_output=logits_output,
-                    pp_hidden_states_proxy_tensors=None,
-                    next_token_ids=next_pp_outputs["next_token_ids"],
-                    extend_input_len_per_req=next_pp_outputs.tensors.get(
-                        "extend_input_len_per_req", None
-                    ),
-                    extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
-                        "extend_logprob_start_len_per_req", None
-                    ),
-                    bid=bids[next_mb_id],
+                    next_pp_outputs=next_pp_outputs,
                     can_run_cuda_graph=result.can_run_cuda_graph,
                 )
                 self.process_batch_result(mbs[next_mb_id], output_result)
@@ -1049,8 +1082,6 @@ class Scheduler(

             # (not last rank)
             if not self.pp_group.is_last_rank:
-                if self.cur_batch:
-                    bids[mb_id] = result.bid
                 # carry the outputs to the next stage
                 # send the outputs from the last round to let the next stage worker run post processing
                 if pp_outputs:
@@ -1072,8 +1103,10 @@ class Scheduler(

                 # send out proxy tensors to the next stage
                 if self.cur_batch:
+                    # FIXME(lsyin): remove this assert
+                    assert result.pp_hidden_states_proxy_tensors.tensors is not None
                     self.pp_group.send_tensor_dict(
-                        result.pp_hidden_states_proxy_tensors,
+                        result.pp_hidden_states_proxy_tensors.tensors,
                         all_gather_group=self.attn_tp_group,
                     )

@@ -2016,33 +2049,25 @@ class Scheduler(

         # Run forward
         if self.is_generation:
+            batch_or_worker_batch = batch
+
             if self.spec_algorithm.is_none():
-                model_worker_batch = batch.get_model_worker_batch()
+                # FIXME(lsyin): remove this if and finally unify the abstraction
+                batch_or_worker_batch = batch.get_model_worker_batch()

-                if self.pp_group.is_last_rank:
-                    logits_output, next_token_ids, can_run_cuda_graph = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                else:
-                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                bid = model_worker_batch.bid
-            else:
-                (
-                    logits_output,
-                    next_token_ids,
-                    bid,
-                    num_accepted_tokens,
-                    can_run_cuda_graph,
-                ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                bs = batch.batch_size()
-                self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
-                self.spec_num_total_forward_ct += bs
-                self.num_generated_tokens += num_accepted_tokens
+            forward_batch_output = self.model_worker.forward_batch_generation(
+                batch_or_worker_batch
+            )

             if self.pp_group.is_last_rank:
-                batch.output_ids = next_token_ids
+                if not self.spec_algorithm.is_none():
+                    # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
+                    self.udpate_spec_metrics(
+                        batch.batch_size(), forward_batch_output.num_accepted_tokens
+                    )
+
+                # update batch's output ids
+                batch.output_ids = forward_batch_output.next_token_ids

             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
@@ -2051,6 +2076,7 @@ class Scheduler(
                 extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
             else:
                 extend_input_len_per_req = None
+
             if batch.return_logprob:
                 extend_logprob_start_len_per_req = [
                     req.extend_logprob_start_len for req in batch.reqs
@@ -2058,25 +2084,15 @@ class Scheduler(
             else:
                 extend_logprob_start_len_per_req = None

-            ret = GenerationBatchResult(
-                logits_output=logits_output if self.pp_group.is_last_rank else None,
-                pp_hidden_states_proxy_tensors=(
-                    pp_hidden_states_proxy_tensors
-                    if not self.pp_group.is_last_rank
-                    else None
-                ),
-                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
+            return GenerationBatchResult.from_forward_batch_output(
+                forward_batch_output=forward_batch_output,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-                bid=bid,
-                can_run_cuda_graph=can_run_cuda_graph,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = EmbeddingBatchResult(
-                embeddings=embeddings, bid=model_worker_batch.bid
-            )
+            ret = EmbeddingBatchResult(embeddings=embeddings)
         return ret

     def process_batch_result(
@@ -80,6 +80,11 @@ class SchedulerMetricsMixin:
             kv_events_config, self.attn_dp_rank
         )

+    def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int):
+        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
+        self.spec_num_total_forward_ct += bs
+        self.num_generated_tokens += num_accepted_tokens
+
     def log_prefill_stats(
         self: Scheduler,
         adder: PrefillAdder,
@@ -173,8 +173,7 @@ class SchedulerOutputProcessorMixin:
                 self.set_next_batch_sampling_info_done(batch)

         else:  # embedding or reward model
-            embeddings, bid = result.embeddings, result.bid
-            embeddings = embeddings.tolist()
+            embeddings = result.embeddings.tolist()

             # Check finish conditions
             for i, req in enumerate(batch.reqs):
@@ -43,7 +43,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardBatchOutput,
+    PPProxyTensors,
+)
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
@@ -234,9 +238,7 @@ class TpModelWorker:
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ) -> Tuple[
-        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
-    ]:
+    ) -> ForwardBatchOutput:
         # update the consumer index of hicache to the running batch
         self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)

@@ -271,13 +273,20 @@ class TpModelWorker:
             else:
                 next_token_ids = self.model_runner.sample(logits_output, forward_batch)

-            return logits_output, next_token_ids, can_run_cuda_graph
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )
         else:
             pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return pp_proxy_tensors.tensors, None, can_run_cuda_graph
+            return ForwardBatchOutput(
+                pp_proxy_tensors=pp_proxy_tensors,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
@@ -39,6 +39,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import DynamicGradMode
 from sglang.utils import get_exception_traceback
@@ -160,13 +161,17 @@ class TpModelWorkerClient:
             self.future_map.resolve_future(model_worker_batch)

             # Run forward
+            forward_batch_output = self.worker.forward_batch_generation(
+                model_worker_batch,
+                model_worker_batch.launch_done,
+                # Skip sampling for prefill-only requests
+                skip_sample=model_worker_batch.is_prefill_only,
+            )
+
             logits_output, next_token_ids, can_run_cuda_graph = (
-                self.worker.forward_batch_generation(
-                    model_worker_batch,
-                    model_worker_batch.launch_done,
-                    # Skip sampling for prefill-only requests
-                    skip_sample=model_worker_batch.is_prefill_only,
-                )
+                forward_batch_output.logits_output,
+                forward_batch_output.next_token_ids,
+                forward_batch_output.can_run_cuda_graph,
             )

             # Update the future token ids map
@@ -227,7 +232,7 @@ class TpModelWorkerClient:

     def forward_batch_generation(
         self, model_worker_batch: ModelWorkerBatch
-    ) -> Tuple[None, torch.Tensor, bool]:
+    ) -> ForwardBatchOutput:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()
@@ -250,7 +255,10 @@ class TpModelWorkerClient:
         future_next_token_ids = self.future_map.update_next_future(
             cur_future_map_ct, bs
         )
-        return None, future_next_token_ids, False
+        return ForwardBatchOutput(
+            next_token_ids=future_next_token_ids,
+            can_run_cuda_graph=False,
+        )

     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
@@ -2,11 +2,10 @@ from __future__ import annotations

 import logging
 import multiprocessing as mp
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Dict, List, Optional

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.model_executor.forward_batch_info import PPProxyTensors

 if TYPE_CHECKING:
@@ -900,6 +900,17 @@ class ForwardBatch:
         return self.tbo_split_seq_index is not None


+@dataclass
+class ForwardBatchOutput:
+    # FIXME(lsyin): unify the forward batch output between different spec and parallelism
+    # need to be more organized
+    logits_output: Optional[torch.Tensor] = None
+    next_token_ids: Optional[torch.Tensor] = None
+    num_accepted_tokens: Optional[int] = None
+    pp_proxy_tensors: Optional[PPProxyTensors] = None
+    can_run_cuda_graph: bool = False
+
+
 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1

@@ -14,7 +14,6 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
-from sglang.srt.managers.mm_utils import embed_mm_inputs
 from sglang.srt.managers.schedule_batch import (
     ScheduleBatch,
     get_last_loc,
@@ -24,6 +23,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
+    ForwardBatchOutput,
     ForwardMode,
 )
 from sglang.srt.server_args import ServerArgs
@@ -422,9 +422,7 @@ class EAGLEWorker(TpModelWorker):
     def draft_model_runner(self):
         return self.model_runner

-    def forward_batch_speculative_generation(
-        self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
+    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
        """Run speculative decoding forward.

        NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -437,14 +435,19 @@ class EAGLEWorker(TpModelWorker):
        the batch id (used for overlap schedule), and number of accepted tokens.
        """
        if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
-            logits_output, next_token_ids, bid, seq_lens_cpu = (
-                self.forward_target_extend(batch)
+            logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
+                batch
             )
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                 )
-            return logits_output, next_token_ids, bid, 0, False
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                num_accepted_tokens=0,
+                can_run_cuda_graph=False,
+            )
         else:
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
@@ -462,12 +465,11 @@ class EAGLEWorker(TpModelWorker):
                 # decode is not finished
                 self.forward_draft_extend_after_decode(batch)

-            return (
-                logits_output,
-                verify_output.verified_id,
-                model_worker_batch.bid,
-                sum(verify_output.accept_length_per_req_cpu),
-                can_run_cuda_graph,
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=verify_output.verified_id,
+                num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph=can_run_cuda_graph,
             )

     def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
@@ -499,19 +501,21 @@ class EAGLEWorker(TpModelWorker):
         Returns:
             logits_output: The output of logits. It will contain the full hidden states.
             next_token_ids: Next token ids generated.
-            bid: The model batch ID. Used for overlap schedule.
         """
         # Forward with the target model and get hidden states.
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
+        forward_batch_output = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
+        logits_output, next_token_ids = (
+            forward_batch_output.logits_output,
+            forward_batch_output.next_token_ids,
+        )
         return (
             logits_output,
             next_token_ids,
-            model_worker_batch.bid,
             model_worker_batch.seq_lens_cpu,
         )

@@ -811,10 +815,12 @@ class EAGLEWorker(TpModelWorker):
         ).cpu()

         # Forward
-        logits_output, _, can_run_cuda_graph = (
-            self.target_worker.forward_batch_generation(
-                model_worker_batch, skip_sample=True
-            )
+        forward_batch_output = self.target_worker.forward_batch_generation(
+            model_worker_batch, skip_sample=True
+        )
+        logits_output, can_run_cuda_graph = (
+            forward_batch_output.logits_output,
+            forward_batch_output.can_run_cuda_graph,
         )

         vocab_mask = None
@@ -7,7 +7,7 @@ from sgl_kernel.speculative import reconstruct_indices_from_tree_mask

 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, ForwardMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
 from sglang.srt.speculative.ngram_utils import NgramVerifyInput
@@ -207,17 +207,18 @@ class NGRAMWorker:
             batch_tokens.append(put_ids)
         self.ngram_cache.batch_put(batch_tokens)

-    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
+    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
         self._prepare_for_speculative_decoding(batch)
         model_worker_batch = batch.get_model_worker_batch()
-        bid = model_worker_batch.bid
         num_accepted_tokens = 0

         if model_worker_batch.forward_mode.is_target_verify():
-            logits_output, _, can_run_cuda_graph = (
-                self.target_worker.forward_batch_generation(
-                    model_worker_batch, skip_sample=True
-                )
+            forward_batch_output = self.target_worker.forward_batch_generation(
+                model_worker_batch, skip_sample=True
+            )
+            logits_output, can_run_cuda_graph = (
+                forward_batch_output.logits_output,
+                forward_batch_output.can_run_cuda_graph,
             )
             verify_input = model_worker_batch.spec_info
             logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
@@ -227,14 +228,18 @@ class NGRAMWorker:
             batch.forward_mode = ForwardMode.DECODE

         else:
+            forward_batch_output = self.target_worker.forward_batch_generation(
+                model_worker_batch
+            )
             logits_output, next_token_ids, can_run_cuda_graph = (
-                self.target_worker.forward_batch_generation(model_worker_batch)
+                forward_batch_output.logits_output,
+                forward_batch_output.next_token_ids,
+                forward_batch_output.can_run_cuda_graph,
             )

-        return (
-            logits_output,
-            next_token_ids,
-            bid,
-            num_accepted_tokens,
-            can_run_cuda_graph,
+        return ForwardBatchOutput(
+            logits_output=logits_output,
+            next_token_ids=next_token_ids,
+            num_accepted_tokens=num_accepted_tokens,
+            can_run_cuda_graph=can_run_cuda_graph,
         )