Eagle speculative decoding part 3: small modifications to the general scheduler (#2709)
Co-authored-by: kavioyu <kavioyu@tencent.com>
This commit is contained in:
@@ -61,10 +61,10 @@ For example, there are two H20 nodes, each with 8 GPUs. The first node's IP is `
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# node 1
|
# node 1
|
||||||
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --nccl-init 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code
|
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code
|
||||||
|
|
||||||
# node 2
|
# node 2
|
||||||
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --nccl-init 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
|
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
|
||||||
```
|
```
|
||||||
|
|
||||||
If you have two H100 nodes, the usage is similar to the aforementioned H20.
|
If you have two H100 nodes, the usage is similar to the aforementioned H20.
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner
|
|||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server import _set_envs_and_config
|
from sglang.srt.server import _set_envs_and_config
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers
|
from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers
|
||||||
|
|
||||||
|
|
||||||
@@ -214,6 +215,7 @@ def extend(reqs, model_runner):
|
|||||||
tree_cache=None,
|
tree_cache=None,
|
||||||
model_config=model_runner.model_config,
|
model_config=model_runner.model_config,
|
||||||
enable_overlap=False,
|
enable_overlap=False,
|
||||||
|
spec_algorithm=SpeculativeAlgorithm.NONE,
|
||||||
)
|
)
|
||||||
batch.prepare_for_extend()
|
batch.prepare_for_extend()
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class AttentionBackend(ABC):
|
|||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self,
|
self,
|
||||||
bs: int,
|
bs: int,
|
||||||
num_token: int,
|
num_tokens: int,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
|||||||
@@ -227,7 +227,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self,
|
self,
|
||||||
bs: int,
|
bs: int,
|
||||||
num_token: int,
|
num_tokens: int,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
@@ -243,9 +243,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
"NHD",
|
"NHD",
|
||||||
use_cuda_graph=True,
|
use_cuda_graph=True,
|
||||||
use_tensor_cores=self.decode_use_tensor_cores,
|
use_tensor_cores=self.decode_use_tensor_cores,
|
||||||
paged_kv_indptr_buffer=self.kv_indptr[i][: num_token + 1],
|
paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
|
||||||
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
|
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
|
||||||
paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_token],
|
paged_kv_last_page_len_buffer=self.kv_last_page_len[
|
||||||
|
:num_tokens
|
||||||
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
seq_lens_sum = seq_lens.sum().item()
|
seq_lens_sum = seq_lens.sum().item()
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self,
|
self,
|
||||||
bs: int,
|
bs: int,
|
||||||
num_token: int,
|
num_tokens: int,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
|||||||
@@ -575,8 +575,8 @@ class ScheduleBatch:
|
|||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
spec_info: Optional[SpecInfo] = None
|
spec_info: Optional[SpecInfo] = None
|
||||||
spec_algorithm: Optional[SpeculativeAlgorithm] = None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_new(
|
def init_new(
|
||||||
@@ -587,7 +587,7 @@ class ScheduleBatch:
|
|||||||
tree_cache: BasePrefixCache,
|
tree_cache: BasePrefixCache,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
enable_overlap: bool,
|
enable_overlap: bool,
|
||||||
speculative_algorithm: Optional[SpeculativeAlgorithm] = None,
|
spec_algorithm: SpeculativeAlgorithm,
|
||||||
):
|
):
|
||||||
return cls(
|
return cls(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
@@ -600,7 +600,7 @@ class ScheduleBatch:
|
|||||||
has_stream=any(req.stream for req in reqs),
|
has_stream=any(req.stream for req in reqs),
|
||||||
has_grammar=any(req.grammar for req in reqs),
|
has_grammar=any(req.grammar for req in reqs),
|
||||||
device=req_to_token_pool.device,
|
device=req_to_token_pool.device,
|
||||||
spec_algorithm=speculative_algorithm,
|
spec_algorithm=spec_algorithm,
|
||||||
)
|
)
|
||||||
|
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
@@ -1010,6 +1010,8 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
def prepare_for_decode(self):
|
def prepare_for_decode(self):
|
||||||
self.forward_mode = ForwardMode.DECODE
|
self.forward_mode = ForwardMode.DECODE
|
||||||
|
if self.spec_algorithm.is_eagle():
|
||||||
|
return
|
||||||
|
|
||||||
self.input_ids = self.output_ids
|
self.input_ids = self.output_ids
|
||||||
self.output_ids = None
|
self.output_ids = None
|
||||||
@@ -1172,6 +1174,7 @@ class ScheduleBatch:
|
|||||||
out_cache_loc=self.out_cache_loc,
|
out_cache_loc=self.out_cache_loc,
|
||||||
return_logprob=self.return_logprob,
|
return_logprob=self.return_logprob,
|
||||||
decoding_reqs=self.decoding_reqs,
|
decoding_reqs=self.decoding_reqs,
|
||||||
|
spec_algorithm=self.spec_algorithm,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
@@ -1232,8 +1235,8 @@ class ModelWorkerBatch:
|
|||||||
input_embeds: Optional[torch.tensor] = None
|
input_embeds: Optional[torch.tensor] = None
|
||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
spec_info: Optional[SpecInfo] = None
|
spec_info: Optional[SpecInfo] = None
|
||||||
spec_algorithm: Optional[SpeculativeAlgorithm] = None
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
|
|||||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
broadcast_pyobj,
|
broadcast_pyobj,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
@@ -116,6 +117,14 @@ class Scheduler:
|
|||||||
self.enable_overlap = not server_args.disable_overlap_schedule
|
self.enable_overlap = not server_args.disable_overlap_schedule
|
||||||
self.skip_tokenizer_init = server_args.skip_tokenizer_init
|
self.skip_tokenizer_init = server_args.skip_tokenizer_init
|
||||||
self.enable_metrics = server_args.enable_metrics
|
self.enable_metrics = server_args.enable_metrics
|
||||||
|
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
||||||
|
server_args.speculative_algorithm
|
||||||
|
)
|
||||||
|
self.decode_mem_cache_buf_multiplier = (
|
||||||
|
self.server_args.speculative_num_draft_tokens
|
||||||
|
if not self.spec_algorithm.is_none()
|
||||||
|
else 1
|
||||||
|
)
|
||||||
|
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.Context(2)
|
context = zmq.Context(2)
|
||||||
@@ -199,6 +208,21 @@ class Scheduler:
|
|||||||
nccl_port=port_args.nccl_port,
|
nccl_port=port_args.nccl_port,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Launch worker for speculative decoding if need
|
||||||
|
if self.spec_algorithm.is_eagle():
|
||||||
|
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
||||||
|
|
||||||
|
self.draft_worker = EAGLEWorker(
|
||||||
|
gpu_id=gpu_id,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
server_args=server_args,
|
||||||
|
nccl_port=port_args.nccl_port,
|
||||||
|
target_worker=self.tp_worker,
|
||||||
|
dp_rank=dp_rank,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.draft_worker = None
|
||||||
|
|
||||||
# Get token and memory info from the model worker
|
# Get token and memory info from the model worker
|
||||||
(
|
(
|
||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
@@ -855,6 +879,7 @@ class Scheduler:
|
|||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
self.model_config,
|
self.model_config,
|
||||||
self.enable_overlap,
|
self.enable_overlap,
|
||||||
|
self.spec_algorithm,
|
||||||
)
|
)
|
||||||
new_batch.prepare_for_extend()
|
new_batch.prepare_for_extend()
|
||||||
|
|
||||||
@@ -888,11 +913,15 @@ class Scheduler:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Check if decode out of memory
|
# Check if decode out of memory
|
||||||
if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
|
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
|
||||||
|
test_retract and batch.batch_size() > 10
|
||||||
|
):
|
||||||
old_ratio = self.new_token_ratio
|
old_ratio = self.new_token_ratio
|
||||||
|
|
||||||
retracted_reqs, new_token_ratio = batch.retract_decode()
|
retracted_reqs, new_token_ratio = batch.retract_decode()
|
||||||
self.new_token_ratio = new_token_ratio
|
self.new_token_ratio = new_token_ratio
|
||||||
|
if self.draft_worker:
|
||||||
|
self.draft_worker.finish_request(retracted_reqs)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Decode out of memory happened. "
|
"Decode out of memory happened. "
|
||||||
@@ -926,11 +955,17 @@ class Scheduler:
|
|||||||
self.forward_ct += 1
|
self.forward_ct += 1
|
||||||
|
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
|
||||||
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
|
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
|
||||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
if self.spec_algorithm.is_none():
|
||||||
model_worker_batch
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
)
|
logits_output, next_token_ids = (
|
||||||
|
self.tp_worker.forward_batch_generation(model_worker_batch)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits_output, next_token_ids, model_worker_batch, spec_info = (
|
||||||
|
self.draft_worker.forward_batch_speculative_generation(batch)
|
||||||
|
)
|
||||||
|
batch.spec_info = spec_info
|
||||||
elif batch.forward_mode.is_idle():
|
elif batch.forward_mode.is_idle():
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
self.tp_worker.forward_batch_idle(model_worker_batch)
|
self.tp_worker.forward_batch_idle(model_worker_batch)
|
||||||
@@ -1077,7 +1112,10 @@ class Scheduler:
|
|||||||
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
|
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
|
||||||
continue
|
continue
|
||||||
|
|
||||||
req.output_ids.append(next_token_id)
|
if batch.spec_algorithm.is_none():
|
||||||
|
# speculative worker will solve the output_ids in speculative decoding
|
||||||
|
req.output_ids.append(next_token_id)
|
||||||
|
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if req.finished():
|
if req.finished():
|
||||||
@@ -1252,6 +1290,9 @@ class Scheduler:
|
|||||||
# If not stream, we still want to output some tokens to get the benefit of incremental decoding.
|
# If not stream, we still want to output some tokens to get the benefit of incremental decoding.
|
||||||
or (not req.stream and len(req.output_ids) % 50 == 0)
|
or (not req.stream and len(req.output_ids) % 50 == 0)
|
||||||
):
|
):
|
||||||
|
if self.draft_worker and req.finished():
|
||||||
|
self.draft_worker.finish_request(req)
|
||||||
|
|
||||||
rids.append(req.rid)
|
rids.append(req.rid)
|
||||||
finished_reasons.append(
|
finished_reasons.append(
|
||||||
req.finished_reason.to_json() if req.finished_reason else None
|
req.finished_reason.to_json() if req.finished_reason else None
|
||||||
@@ -1383,6 +1424,7 @@ class Scheduler:
|
|||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
self.model_config,
|
self.model_config,
|
||||||
self.enable_overlap,
|
self.enable_overlap,
|
||||||
|
self.spec_algorithm,
|
||||||
)
|
)
|
||||||
idle_batch.prepare_for_idle()
|
idle_batch.prepare_for_idle()
|
||||||
return idle_batch
|
return idle_batch
|
||||||
|
|||||||
@@ -45,13 +45,18 @@ class TpModelWorker:
|
|||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
dp_rank: Optional[int],
|
dp_rank: Optional[int],
|
||||||
nccl_port: int,
|
nccl_port: int,
|
||||||
|
is_draft_worker: bool = False,
|
||||||
):
|
):
|
||||||
# Parse args
|
# Parse args
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
|
|
||||||
# Init model and tokenizer
|
# Init model and tokenizer
|
||||||
self.model_config = ModelConfig(
|
self.model_config = ModelConfig(
|
||||||
server_args.model_path,
|
(
|
||||||
|
server_args.model_path
|
||||||
|
if not is_draft_worker
|
||||||
|
else server_args.speculative_draft_model_path
|
||||||
|
),
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
revision=server_args.revision,
|
revision=server_args.revision,
|
||||||
context_length=server_args.context_length,
|
context_length=server_args.context_length,
|
||||||
@@ -68,6 +73,7 @@ class TpModelWorker:
|
|||||||
tp_size=server_args.tp_size,
|
tp_size=server_args.tp_size,
|
||||||
nccl_port=nccl_port,
|
nccl_port=nccl_port,
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
|
is_draft_worker=is_draft_worker,
|
||||||
)
|
)
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
self.tokenizer = self.processor = None
|
self.tokenizer = self.processor = None
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
ForwardBatch,
|
ForwardBatch,
|
||||||
ForwardMode,
|
ForwardMode,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
|
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
@@ -106,11 +106,6 @@ def set_torch_compile_config():
|
|||||||
torch._dynamo.config.cache_size_limit = 1024
|
torch._dynamo.config.cache_size_limit = 1024
|
||||||
|
|
||||||
|
|
||||||
@maybe_torch_compile(dynamic=True)
|
|
||||||
def clamp_position(seq_lens):
|
|
||||||
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
|
|
||||||
|
|
||||||
|
|
||||||
class CudaGraphRunner:
|
class CudaGraphRunner:
|
||||||
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
|
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
|
||||||
|
|
||||||
@@ -157,6 +152,17 @@ class CudaGraphRunner:
|
|||||||
self.capture_forward_mode = ForwardMode.DECODE
|
self.capture_forward_mode = ForwardMode.DECODE
|
||||||
self.num_tokens_per_bs = 1
|
self.num_tokens_per_bs = 1
|
||||||
|
|
||||||
|
if model_runner.spec_algorithm.is_eagle():
|
||||||
|
if self.model_runner.is_draft_worker:
|
||||||
|
self.num_tokens_per_bs = (
|
||||||
|
self.model_runner.server_args.speculative_eagle_topk
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
|
||||||
|
self.num_tokens_per_bs = (
|
||||||
|
self.model_runner.server_args.speculative_num_draft_tokens
|
||||||
|
)
|
||||||
|
|
||||||
self.compile_bs = (
|
self.compile_bs = (
|
||||||
[
|
[
|
||||||
bs
|
bs
|
||||||
@@ -192,6 +198,13 @@ class CudaGraphRunner:
|
|||||||
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||||
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
|
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
|
||||||
|
|
||||||
|
# Speculative_inference
|
||||||
|
if model_runner.spec_algorithm.is_eagle():
|
||||||
|
self.hidden_states = torch.zeros(
|
||||||
|
(self.max_num_token, self.model_runner.model_config.hidden_size),
|
||||||
|
dtype=self.model_runner.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
if self.is_encoder_decoder:
|
if self.is_encoder_decoder:
|
||||||
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
|
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
|
||||||
self.encoder_lens = torch.full(
|
self.encoder_lens = torch.full(
|
||||||
@@ -234,9 +247,6 @@ class CudaGraphRunner:
|
|||||||
self.model_runner.model.capture_mode = False
|
self.model_runner.model.capture_mode = False
|
||||||
|
|
||||||
def can_run(self, forward_batch: ForwardBatch):
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
if not forward_batch.forward_mode.is_cuda_graph():
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
|
min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
|
||||||
forward_batch.global_num_tokens
|
forward_batch.global_num_tokens
|
||||||
@@ -291,21 +301,18 @@ class CudaGraphRunner:
|
|||||||
def capture_one_batch_size(self, bs: int, forward: Callable):
|
def capture_one_batch_size(self, bs: int, forward: Callable):
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
stream = self.stream
|
stream = self.stream
|
||||||
num_token = bs * self.num_tokens_per_bs
|
num_tokens = bs * self.num_tokens_per_bs
|
||||||
|
|
||||||
# Common inputs
|
# Common inputs
|
||||||
input_ids = self.input_ids[:num_token]
|
input_ids = self.input_ids[:num_tokens]
|
||||||
req_pool_indices = self.req_pool_indices[:bs]
|
req_pool_indices = self.req_pool_indices[:bs]
|
||||||
seq_lens = self.seq_lens[:bs]
|
seq_lens = self.seq_lens[:bs]
|
||||||
out_cache_loc = self.out_cache_loc[:num_token]
|
out_cache_loc = self.out_cache_loc[:num_tokens]
|
||||||
positions = self.positions[:num_token]
|
positions = self.positions[:num_tokens]
|
||||||
|
|
||||||
if self.is_encoder_decoder:
|
if self.is_encoder_decoder:
|
||||||
encoder_lens = self.encoder_lens[:bs]
|
encoder_lens = self.encoder_lens[:bs]
|
||||||
else:
|
else:
|
||||||
encoder_lens = None
|
encoder_lens = None
|
||||||
|
|
||||||
seq_lens_sum = seq_lens.sum().item()
|
|
||||||
mrope_positions = self.mrope_positions[:, :bs]
|
mrope_positions = self.mrope_positions[:, :bs]
|
||||||
|
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
@@ -325,20 +332,22 @@ class CudaGraphRunner:
|
|||||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||||
attn_backend=self.model_runner.attn_backend,
|
attn_backend=self.model_runner.attn_backend,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
seq_lens_sum=seq_lens_sum,
|
seq_lens_sum=seq_lens.sum(),
|
||||||
encoder_lens=encoder_lens,
|
encoder_lens=encoder_lens,
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
top_logprobs_nums=[0] * num_token,
|
top_logprobs_nums=[0] * bs,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
global_num_tokens=global_num_tokens,
|
global_num_tokens=global_num_tokens,
|
||||||
mrope_positions=mrope_positions,
|
mrope_positions=mrope_positions,
|
||||||
gathered_buffer=gathered_buffer,
|
gathered_buffer=gathered_buffer,
|
||||||
|
spec_algorithm=self.model_runner.spec_algorithm,
|
||||||
|
spec_info=self.get_spec_info(num_tokens, positions),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
num_token,
|
num_tokens,
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
encoder_lens,
|
encoder_lens,
|
||||||
@@ -394,14 +403,16 @@ class CudaGraphRunner:
|
|||||||
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
||||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||||
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
||||||
positions = clamp_position(forward_batch.seq_lens)
|
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
||||||
self.positions[:raw_num_token].copy_(positions)
|
|
||||||
|
|
||||||
if self.is_encoder_decoder:
|
if self.is_encoder_decoder:
|
||||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||||
if forward_batch.mrope_positions is not None:
|
if forward_batch.mrope_positions is not None:
|
||||||
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
||||||
|
|
||||||
|
if hasattr(forward_batch.spec_info, "hidden_states"):
|
||||||
|
self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
@@ -424,3 +435,36 @@ class CudaGraphRunner:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
return logits_output
|
return logits_output
|
||||||
|
|
||||||
|
def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
|
||||||
|
spec_info = None
|
||||||
|
if self.model_runner.spec_algorithm.is_eagle():
|
||||||
|
from sglang.srt.speculative.eagle_utils import (
|
||||||
|
EAGLEDraftInput,
|
||||||
|
EagleVerifyInput,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.model_runner.is_draft_worker:
|
||||||
|
spec_info = EAGLEDraftInput()
|
||||||
|
spec_info.hidden_states = self.hidden_states[:num_tokens]
|
||||||
|
spec_info.positions = positions
|
||||||
|
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
|
spec_info.init(self.model_runner.server_args)
|
||||||
|
else:
|
||||||
|
spec_info = EagleVerifyInput(
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
self.model_runner.server_args.speculative_num_draft_tokens,
|
||||||
|
)
|
||||||
|
spec_info.custom_mask = torch.zeros(
|
||||||
|
(num_tokens * self.model_runner.model_config.context_len),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
|
|
||||||
|
return spec_info
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||||
|
from sglang.srt.utils import maybe_torch_compile
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention import AttentionBackend
|
||||||
@@ -276,10 +277,21 @@ class ForwardBatch:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if ret.forward_mode.is_idle():
|
if ret.forward_mode.is_idle():
|
||||||
|
ret.positions = torch.empty((0,), device=device)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
# Override the positions with spec_info
|
||||||
|
if (
|
||||||
|
ret.spec_info is not None
|
||||||
|
and getattr(ret.spec_info, "positions", None) is not None
|
||||||
|
):
|
||||||
|
ret.positions = ret.spec_info.positions
|
||||||
|
|
||||||
# Init position information
|
# Init position information
|
||||||
if not ret.forward_mode.is_decode():
|
if ret.forward_mode.is_decode():
|
||||||
|
if ret.positions is None:
|
||||||
|
ret.positions = clamp_position(batch.seq_lens)
|
||||||
|
else:
|
||||||
ret.extend_seq_lens = torch.tensor(
|
ret.extend_seq_lens = torch.tensor(
|
||||||
batch.extend_seq_lens, dtype=torch.int32
|
batch.extend_seq_lens, dtype=torch.int32
|
||||||
).to(device, non_blocking=True)
|
).to(device, non_blocking=True)
|
||||||
@@ -288,13 +300,15 @@ class ForwardBatch:
|
|||||||
).to(device, non_blocking=True)
|
).to(device, non_blocking=True)
|
||||||
if model_runner.server_args.attention_backend != "torch_native":
|
if model_runner.server_args.attention_backend != "torch_native":
|
||||||
ret.extend_num_tokens = batch.extend_num_tokens
|
ret.extend_num_tokens = batch.extend_num_tokens
|
||||||
ret.positions, ret.extend_start_loc = compute_position_triton(
|
positions, ret.extend_start_loc = compute_position_triton(
|
||||||
ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
|
ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ret.positions, ret.extend_start_loc = compute_position_torch(
|
positions, ret.extend_start_loc = compute_position_torch(
|
||||||
ret.extend_prefix_lens, ret.extend_seq_lens
|
ret.extend_prefix_lens, ret.extend_seq_lens
|
||||||
)
|
)
|
||||||
|
if ret.positions is None:
|
||||||
|
ret.positions = positions
|
||||||
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
|
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
|
||||||
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
||||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
||||||
@@ -383,6 +397,11 @@ def compute_position_torch(
|
|||||||
return positions.to(torch.int64), extend_start_loc
|
return positions.to(torch.int64), extend_start_loc
|
||||||
|
|
||||||
|
|
||||||
|
@maybe_torch_compile(dynamic=True)
|
||||||
|
def clamp_position(seq_lens):
|
||||||
|
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
|
||||||
|
|
||||||
|
|
||||||
class CaptureHiddenMode(IntEnum):
|
class CaptureHiddenMode(IntEnum):
|
||||||
NULL = auto()
|
NULL = auto()
|
||||||
FULL = auto()
|
FULL = auto()
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader import get_model
|
from sglang.srt.model_loader import get_model
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
@@ -74,6 +75,7 @@ class ModelRunner:
|
|||||||
tp_size: int,
|
tp_size: int,
|
||||||
nccl_port: int,
|
nccl_port: int,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
|
is_draft_worker: bool = False,
|
||||||
):
|
):
|
||||||
# Parse args
|
# Parse args
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
@@ -84,8 +86,12 @@ class ModelRunner:
|
|||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
self.dist_port = nccl_port
|
self.dist_port = nccl_port
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
|
self.is_draft_worker = is_draft_worker
|
||||||
self.is_generation = model_config.is_generation
|
self.is_generation = model_config.is_generation
|
||||||
self.is_multimodal = model_config.is_multimodal
|
self.is_multimodal = model_config.is_multimodal
|
||||||
|
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
||||||
|
server_args.speculative_algorithm
|
||||||
|
)
|
||||||
|
|
||||||
# Model-specific adjustment
|
# Model-specific adjustment
|
||||||
if (
|
if (
|
||||||
@@ -205,14 +211,18 @@ class ModelRunner:
|
|||||||
else:
|
else:
|
||||||
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
|
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
|
||||||
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
|
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
|
||||||
init_distributed_environment(
|
|
||||||
backend=backend,
|
if not self.is_draft_worker:
|
||||||
world_size=self.tp_size,
|
# Only initilzie the distributed environment on the target model worker.
|
||||||
rank=self.tp_rank,
|
init_distributed_environment(
|
||||||
local_rank=self.gpu_id,
|
backend=backend,
|
||||||
distributed_init_method=dist_init_method,
|
world_size=self.tp_size,
|
||||||
)
|
rank=self.tp_rank,
|
||||||
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
local_rank=self.gpu_id,
|
||||||
|
distributed_init_method=dist_init_method,
|
||||||
|
)
|
||||||
|
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
||||||
|
|
||||||
min_per_gpu_memory = get_available_gpu_memory(
|
min_per_gpu_memory = get_available_gpu_memory(
|
||||||
self.device, self.gpu_id, distributed=self.tp_size > 1
|
self.device, self.gpu_id, distributed=self.tp_size > 1
|
||||||
)
|
)
|
||||||
@@ -407,7 +417,6 @@ class ModelRunner:
|
|||||||
target_dtype = (
|
target_dtype = (
|
||||||
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
|
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
|
||||||
)
|
)
|
||||||
current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype
|
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
self._model_update_group is not None
|
self._model_update_group is not None
|
||||||
@@ -506,6 +515,28 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
|
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
|
||||||
|
|
||||||
|
if max_num_reqs is None:
|
||||||
|
max_num_reqs = min(
|
||||||
|
max(
|
||||||
|
int(
|
||||||
|
self.max_total_num_tokens / self.model_config.context_len * 512
|
||||||
|
),
|
||||||
|
2048,
|
||||||
|
),
|
||||||
|
4096,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.spec_algorithm.is_none():
|
||||||
|
if self.is_draft_worker:
|
||||||
|
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
|
||||||
|
else:
|
||||||
|
self.server_args.draft_runner_cache_size = (
|
||||||
|
self.max_total_num_tokens
|
||||||
|
+ max_num_reqs * self.server_args.speculative_num_steps
|
||||||
|
+ 100
|
||||||
|
)
|
||||||
|
|
||||||
if max_total_tokens is not None:
|
if max_total_tokens is not None:
|
||||||
if max_total_tokens > self.max_total_num_tokens:
|
if max_total_tokens > self.max_total_num_tokens:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
@@ -520,17 +551,6 @@ class ModelRunner:
|
|||||||
"Not enough memory. Please try to increase --mem-fraction-static."
|
"Not enough memory. Please try to increase --mem-fraction-static."
|
||||||
)
|
)
|
||||||
|
|
||||||
if max_num_reqs is None:
|
|
||||||
max_num_reqs = min(
|
|
||||||
max(
|
|
||||||
int(
|
|
||||||
self.max_total_num_tokens / self.model_config.context_len * 512
|
|
||||||
),
|
|
||||||
2048,
|
|
||||||
),
|
|
||||||
4096,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.req_to_token_pool = ReqToTokenPool(
|
self.req_to_token_pool = ReqToTokenPool(
|
||||||
size=max_num_reqs + 1,
|
size=max_num_reqs + 1,
|
||||||
max_context_len=self.model_config.context_len + 4,
|
max_context_len=self.model_config.context_len + 4,
|
||||||
@@ -650,10 +670,6 @@ class ModelRunner:
|
|||||||
tensor_parallel(self.model, device_mesh)
|
tensor_parallel(self.model, device_mesh)
|
||||||
|
|
||||||
def forward_decode(self, forward_batch: ForwardBatch):
|
def forward_decode(self, forward_batch: ForwardBatch):
|
||||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
|
|
||||||
return self.cuda_graph_runner.replay(forward_batch)
|
|
||||||
|
|
||||||
forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
|
|
||||||
self.attn_backend.init_forward_metadata(forward_batch)
|
self.attn_backend.init_forward_metadata(forward_batch)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||||
@@ -683,14 +699,18 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward_idle(self, forward_batch: ForwardBatch):
|
def forward_idle(self, forward_batch: ForwardBatch):
|
||||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
|
|
||||||
return self.cuda_graph_runner.replay(forward_batch)
|
|
||||||
|
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
|
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
|
||||||
|
if (
|
||||||
|
forward_batch.forward_mode.is_cuda_graph()
|
||||||
|
and self.cuda_graph_runner
|
||||||
|
and self.cuda_graph_runner.can_run(forward_batch)
|
||||||
|
):
|
||||||
|
return self.cuda_graph_runner.replay(forward_batch)
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_decode():
|
if forward_batch.forward_mode.is_decode():
|
||||||
return self.forward_decode(forward_batch)
|
return self.forward_decode(forward_batch)
|
||||||
elif forward_batch.forward_mode.is_extend():
|
elif forward_batch.forward_mode.is_extend():
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from typing import List, Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import check_gguf_file
|
from sglang.srt.hf_transformers_utils import check_gguf_file
|
||||||
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_amdgpu_memory_capacity,
|
get_amdgpu_memory_capacity,
|
||||||
get_hpu_memory_capacity,
|
get_hpu_memory_capacity,
|
||||||
@@ -247,6 +248,17 @@ class ServerArgs:
|
|||||||
"Overlap scheduler is disabled."
|
"Overlap scheduler is disabled."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Speculative Decoding
|
||||||
|
if self.speculative_algorithm == "EAGLE":
|
||||||
|
self.prefill_only_one_req = True
|
||||||
|
self.disable_cuda_graph_padding = True
|
||||||
|
self.disable_radix_cache = True
|
||||||
|
self.disable_overlap_schedule = True
|
||||||
|
self.chunked_prefill_size = -1
|
||||||
|
logger.info(
|
||||||
|
"The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding."
|
||||||
|
)
|
||||||
|
|
||||||
# GGUF
|
# GGUF
|
||||||
if (
|
if (
|
||||||
self.load_format == "auto" or self.load_format == "gguf"
|
self.load_format == "auto" or self.load_format == "gguf"
|
||||||
|
|||||||
@@ -2,8 +2,12 @@ from enum import IntEnum, auto
|
|||||||
|
|
||||||
|
|
||||||
class SpeculativeAlgorithm(IntEnum):
|
class SpeculativeAlgorithm(IntEnum):
|
||||||
|
NONE = auto()
|
||||||
EAGLE = auto()
|
EAGLE = auto()
|
||||||
|
|
||||||
|
def is_none(self):
|
||||||
|
return self == SpeculativeAlgorithm.NONE
|
||||||
|
|
||||||
def is_eagle(self):
|
def is_eagle(self):
|
||||||
return self == SpeculativeAlgorithm.EAGLE
|
return self == SpeculativeAlgorithm.EAGLE
|
||||||
|
|
||||||
@@ -11,6 +15,7 @@ class SpeculativeAlgorithm(IntEnum):
|
|||||||
def from_string(name: str):
|
def from_string(name: str):
|
||||||
name_map = {
|
name_map = {
|
||||||
"EAGLE": SpeculativeAlgorithm.EAGLE,
|
"EAGLE": SpeculativeAlgorithm.EAGLE,
|
||||||
|
None: SpeculativeAlgorithm.NONE,
|
||||||
}
|
}
|
||||||
return name_map[name]
|
return name_map[name]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user