Eagle speculative decoding part 3: small modifications to the general scheduler (#2709)
Co-authored-by: kavioyu <kavioyu@tencent.com>
This commit is contained in:
@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader import get_model
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import (
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
@@ -74,6 +75,7 @@ class ModelRunner:
|
||||
tp_size: int,
|
||||
nccl_port: int,
|
||||
server_args: ServerArgs,
|
||||
is_draft_worker: bool = False,
|
||||
):
|
||||
# Parse args
|
||||
self.model_config = model_config
|
||||
@@ -84,8 +86,12 @@ class ModelRunner:
|
||||
self.tp_size = tp_size
|
||||
self.dist_port = nccl_port
|
||||
self.server_args = server_args
|
||||
self.is_draft_worker = is_draft_worker
|
||||
self.is_generation = model_config.is_generation
|
||||
self.is_multimodal = model_config.is_multimodal
|
||||
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
||||
server_args.speculative_algorithm
|
||||
)
|
||||
|
||||
# Model-specific adjustment
|
||||
if (
|
||||
@@ -205,14 +211,18 @@ class ModelRunner:
|
||||
else:
|
||||
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
|
||||
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
|
||||
init_distributed_environment(
|
||||
backend=backend,
|
||||
world_size=self.tp_size,
|
||||
rank=self.tp_rank,
|
||||
local_rank=self.gpu_id,
|
||||
distributed_init_method=dist_init_method,
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
||||
|
||||
if not self.is_draft_worker:
|
||||
# Only initilzie the distributed environment on the target model worker.
|
||||
init_distributed_environment(
|
||||
backend=backend,
|
||||
world_size=self.tp_size,
|
||||
rank=self.tp_rank,
|
||||
local_rank=self.gpu_id,
|
||||
distributed_init_method=dist_init_method,
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
||||
|
||||
min_per_gpu_memory = get_available_gpu_memory(
|
||||
self.device, self.gpu_id, distributed=self.tp_size > 1
|
||||
)
|
||||
@@ -407,7 +417,6 @@ class ModelRunner:
|
||||
target_dtype = (
|
||||
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
|
||||
)
|
||||
current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype
|
||||
|
||||
assert (
|
||||
self._model_update_group is not None
|
||||
@@ -506,6 +515,28 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
|
||||
|
||||
if max_num_reqs is None:
|
||||
max_num_reqs = min(
|
||||
max(
|
||||
int(
|
||||
self.max_total_num_tokens / self.model_config.context_len * 512
|
||||
),
|
||||
2048,
|
||||
),
|
||||
4096,
|
||||
)
|
||||
|
||||
if not self.spec_algorithm.is_none():
|
||||
if self.is_draft_worker:
|
||||
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
|
||||
else:
|
||||
self.server_args.draft_runner_cache_size = (
|
||||
self.max_total_num_tokens
|
||||
+ max_num_reqs * self.server_args.speculative_num_steps
|
||||
+ 100
|
||||
)
|
||||
|
||||
if max_total_tokens is not None:
|
||||
if max_total_tokens > self.max_total_num_tokens:
|
||||
logging.warning(
|
||||
@@ -520,17 +551,6 @@ class ModelRunner:
|
||||
"Not enough memory. Please try to increase --mem-fraction-static."
|
||||
)
|
||||
|
||||
if max_num_reqs is None:
|
||||
max_num_reqs = min(
|
||||
max(
|
||||
int(
|
||||
self.max_total_num_tokens / self.model_config.context_len * 512
|
||||
),
|
||||
2048,
|
||||
),
|
||||
4096,
|
||||
)
|
||||
|
||||
self.req_to_token_pool = ReqToTokenPool(
|
||||
size=max_num_reqs + 1,
|
||||
max_context_len=self.model_config.context_len + 4,
|
||||
@@ -650,10 +670,6 @@ class ModelRunner:
|
||||
tensor_parallel(self.model, device_mesh)
|
||||
|
||||
def forward_decode(self, forward_batch: ForwardBatch):
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
|
||||
return self.cuda_graph_runner.replay(forward_batch)
|
||||
|
||||
forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
|
||||
self.attn_backend.init_forward_metadata(forward_batch)
|
||||
return self.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
@@ -683,14 +699,18 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
def forward_idle(self, forward_batch: ForwardBatch):
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
|
||||
return self.cuda_graph_runner.replay(forward_batch)
|
||||
|
||||
return self.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
|
||||
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
|
||||
if (
|
||||
forward_batch.forward_mode.is_cuda_graph()
|
||||
and self.cuda_graph_runner
|
||||
and self.cuda_graph_runner.can_run(forward_batch)
|
||||
):
|
||||
return self.cuda_graph_runner.replay(forward_batch)
|
||||
|
||||
if forward_batch.forward_mode.is_decode():
|
||||
return self.forward_decode(forward_batch)
|
||||
elif forward_batch.forward_mode.is_extend():
|
||||
|
||||
Reference in New Issue
Block a user