Clean up wrapper in flashinfer backend (#2638)
@@ -331,6 +331,7 @@ def throughput_test(
         extra_request_body=extra_request_body,
+        profile=bench_args.profile,
     )
     backend.shutdown()

     if bench_args.result_filename:
         with open(bench_args.result_filename, "a") as fout:
@@ -131,10 +131,8 @@ class ModelConfig:
         # Verify quantization
         self._verify_quantization()

-        # Text attrs
+        # Cache attributes
         self.hf_eos_token_id = self.get_hf_eos_token_id()

         # Multimodal attrs
         self.image_token_id = getattr(self.hf_config, "image_token_id", None)
-
-    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
 from typing import Optional

 import torch
 from torch import nn

 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -8,8 +8,9 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """

 import os
+from dataclasses import dataclass
 from enum import Enum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Union

 import torch
 import triton
@@ -38,12 +39,25 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()


+@dataclass
+class DecodeMetadata:
+    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
+
+
+@dataclass
+class PrefillMetadata:
+    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
+    use_ragged: bool
+    extend_no_prefix: bool
+
+
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""

     def __init__(self, model_runner: ModelRunner):
         super().__init__()

         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
             num_attention_heads=model_runner.model_config.num_attention_heads
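The two dataclasses above are the core of the cleanup: forward_metadata was previously a bare tuple whose shape differed between decode and prefill, so call sites had to unpack it positionally. A minimal self-contained sketch of the before/after pattern, using stand-in types in place of the real flashinfer wrappers:

from dataclasses import dataclass
from typing import List, Union

@dataclass
class DecodeMeta:
    decode_wrappers: List[object]  # stand-in for BatchDecodeWithPagedKVCacheWrapper

@dataclass
class PrefillMeta:
    prefill_wrappers: List[object]  # stand-in for BatchPrefillWithPagedKVCacheWrapper
    use_ragged: bool
    extend_no_prefix: bool

meta: Union[DecodeMeta, PrefillMeta] = PrefillMeta([], use_ragged=True, extend_no_prefix=False)
# Before: use_ragged, extend_no_prefix = forward_metadata  (positional, mode-dependent)
# After: named attributes that a type checker can inspect.
print(meta.use_ragged, meta.extend_no_prefix)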
@@ -52,7 +66,6 @@ class FlashInferAttnBackend(AttentionBackend):
                 model_runner.tp_size
             ),
         )

         self.max_context_len = model_runner.model_config.context_len

         assert not (

@@ -120,8 +133,8 @@ class FlashInferAttnBackend(AttentionBackend):
         )

         # Other metadata
-        self.forward_metadata = None
-        self.cuda_graph_metadata = {}
+        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+        self.decode_cuda_graph_metadata = {}

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
@@ -129,10 +142,10 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
-                decode_wrappers=None,
+                decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
             )
-            self.forward_metadata = (self.decode_wrappers,)
+            self.forward_metadata = DecodeMetadata(self.decode_wrappers)
         else:
             prefix_lens = forward_batch.extend_prefix_lens

@@ -149,11 +162,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
                 prefix_lens,
+                prefill_wrappers=self.prefill_wrappers_paged,
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
             )
-
-            self.forward_metadata = (use_ragged, extend_no_prefix)
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+            )

     def init_cuda_graph_state(self, max_bs: int):
         cuda_graph_kv_indices = torch.zeros(
@@ -194,8 +209,8 @@ class FlashInferAttnBackend(AttentionBackend):
             decode_wrappers=decode_wrappers,
             encoder_lens=encoder_lens,
         )
-        self.cuda_graph_metadata[bs] = decode_wrappers
-        self.forward_metadata = (decode_wrappers,)
+        self.decode_cuda_graph_metadata[bs] = decode_wrappers
+        self.forward_metadata = DecodeMetadata(decode_wrappers)

     def init_forward_metadata_replay_cuda_graph(
         self,

@@ -209,7 +224,7 @@ class FlashInferAttnBackend(AttentionBackend):
             req_pool_indices[:bs],
             seq_lens[:bs],
             seq_lens_sum,
-            decode_wrappers=self.cuda_graph_metadata[bs],
+            decode_wrappers=self.decode_cuda_graph_metadata[bs],
             encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
         )
@@ -225,18 +240,16 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        prefill_wrapper_paged = self.prefill_wrappers_paged[
+        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
-
-        use_ragged, extend_no_prefix = self.forward_metadata
         cache_loc = (
             forward_batch.out_cache_loc
             if not layer.is_cross_attention
             else forward_batch.encoder_out_cache_loc
         )

-        if not use_ragged:
+        if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
                 if save_kv_cache:

@@ -260,7 +273,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 logits_soft_cap=layer.logit_cap,
             )

-            if extend_no_prefix:
+            if self.forward_metadata.extend_no_prefix:
                 o = o1
             else:
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
@@ -287,7 +300,7 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
+        decode_wrapper = self.forward_metadata.decode_wrappers[
+            self._get_wrapper_idx(layer)
+        ]
         cache_loc = (
             forward_batch.out_cache_loc
             if not layer.is_cross_attention

@@ -322,7 +337,7 @@ class FlashInferAttnBackend(AttentionBackend):

 class FlashInferIndicesUpdaterDecode:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
-        # Constants
+        # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
@@ -340,9 +355,8 @@ class FlashInferIndicesUpdaterDecode:
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.decode_wrappers = attn_backend.decode_wrappers
-
-        # Dispatch
+        # Dispatch the update function
         if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
         elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:

@@ -356,7 +370,7 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.

@@ -367,7 +381,7 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
@@ -385,11 +399,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
-        decode_wrappers = decode_wrappers or self.decode_wrappers
-
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Sliding window attention

@@ -419,11 +431,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
-        decode_wrappers = decode_wrappers or self.decode_wrappers
-
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Normal attention
@@ -446,7 +456,7 @@ class FlashInferIndicesUpdaterDecode:

     def call_begin_forward(
         self,
-        wrapper,
+        wrapper: BatchDecodeWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,

@@ -486,7 +496,7 @@ class FlashInferIndicesUpdaterDecode:

 class FlashInferIndicesUpdaterPrefill:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
-        # Constants
+        # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
@@ -505,10 +515,9 @@ class FlashInferIndicesUpdaterPrefill:
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.qo_indptr = attn_backend.qo_indptr
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.wrapper_ragged = attn_backend.prefill_wrapper_ragged
-        self.wrappers_paged = attn_backend.prefill_wrappers_paged
+        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged

-        # Dispatch
+        # Dispatch the update function
         if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
         elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
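Both updater classes keep the same shape: constants are parsed in __init__, then self.update is bound once to the right implementation based on dispatch_reason, and the wrappers are now passed in explicitly by the backend rather than read back off the updater. A small self-contained sketch of this bind-once dispatch idiom (names are hypothetical):

class Updater:
    def __init__(self, dispatch_reason: str):
        # Pick the implementation once at construction time instead of
        # branching on every update call.
        if dispatch_reason == "sliding_window":
            self.update = self.update_sliding_window
        elif dispatch_reason == "cross_attention":
            self.update = self.update_cross_attention
        else:
            self.update = self.update_single_wrapper

    def update_single_wrapper(self, wrappers):
        return f"single: {len(wrappers)} wrapper(s)"

    def update_sliding_window(self, wrappers):
        return f"sliding window: {len(wrappers)} wrapper(s)"

    def update_cross_attention(self, wrappers):
        return f"cross attention: {len(wrappers)} wrapper(s)"

print(Updater("sliding_window").update([object(), object()]))  # sliding window: 2 wrapper(s)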
@@ -523,6 +532,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):

@@ -535,6 +545,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):

@@ -546,8 +557,8 @@ class FlashInferIndicesUpdaterPrefill:
         paged_kernel_lens_sum = seq_lens_sum

         self.call_begin_forward(
-            self.wrapper_ragged,
-            self.wrappers_paged[0],
+            self.prefill_wrapper_ragged,
+            prefill_wrappers[0],
             req_pool_indices,
             paged_kernel_lens,
             paged_kernel_lens_sum,
@@ -565,6 +576,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):

@@ -584,8 +596,8 @@ class FlashInferIndicesUpdaterPrefill:
             kv_start_idx = seq_lens - paged_kernel_lens

             self.call_begin_forward(
-                self.wrapper_ragged,
-                self.wrappers_paged[wrapper_id],
+                self.prefill_wrapper_ragged,
+                prefill_wrappers[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
                 paged_kernel_lens_sum,

@@ -603,6 +615,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -619,8 +632,8 @@ class FlashInferIndicesUpdaterPrefill:
             paged_kernel_lens_sum = paged_kernel_lens.sum().item()

             self.call_begin_forward(
-                self.wrapper_ragged,
-                self.wrappers_paged[wrapper_id],
+                self.prefill_wrapper_ragged,
+                prefill_wrappers[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
                 paged_kernel_lens_sum,

@@ -634,8 +647,8 @@ class FlashInferIndicesUpdaterPrefill:

     def call_begin_forward(
         self,
-        wrapper_ragged,
-        wrapper_paged,
+        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
+        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -24,7 +24,11 @@ from vllm.distributed import (
 )

 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)


 @dataclasses.dataclass

@@ -46,6 +50,10 @@ class LogitsProcessorOutput:
     output_top_logprobs_val: List = None
     output_top_logprobs_idx: List = None

+    # Used by speculative decoding (EAGLE)
+    # The output of transformer layers
+    hidden_states: Optional[torch.Tensor] = None
+

 @dataclasses.dataclass
 class LogitsMetadata:
@@ -61,6 +69,8 @@ class LogitsMetadata:
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None

+    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         extend_logprob_pruned_lens_cpu = None

@@ -78,6 +88,11 @@ class LogitsMetadata:
         else:
             return_top_logprob = False

+        if forward_batch.spec_info:
+            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
+        else:
+            capture_hidden_mode = CaptureHiddenMode.NULL
+
         return cls(
             forward_mode=forward_batch.forward_mode,
             top_logprobs_nums=forward_batch.top_logprobs_nums,
@@ -87,6 +102,7 @@ class LogitsMetadata:
             extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
             extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
+            capture_hidden_mode=capture_hidden_mode,
         )
@@ -116,7 +132,10 @@ class LogitsProcessor(nn.Module):
         assert isinstance(logits_metadata, LogitsMetadata)

         # Get the last hidden states and last logits for the next token prediction
-        if logits_metadata.forward_mode.is_decode():
+        if (
+            logits_metadata.forward_mode.is_decode()
+            or logits_metadata.forward_mode.is_target_verify()
+        ):
             last_index = None
             last_hidden = hidden_states
         else:

@@ -137,6 +156,15 @@ class LogitsProcessor(nn.Module):
         if not logits_metadata.return_logprob:
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
+                hidden_states=(
+                    hidden_states
+                    if logits_metadata.capture_hidden_mode.is_full()
+                    else (
+                        last_hidden
+                        if logits_metadata.capture_hidden_mode.is_last()
+                        else None
+                    )
+                ),
             )
         else:
             last_logprobs = self.compute_temp_top_p_normalized_logprobs(
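The nested conditional above selects what travels back on LogitsProcessorOutput.hidden_states: everything for FULL, one row per sequence for LAST, nothing for NULL. Spelled out as a runnable sketch (shapes are illustrative):

import torch

def pick_hidden_states(mode: str, hidden_states: torch.Tensor, last_hidden: torch.Tensor):
    # FULL keeps the hidden states of every token, LAST keeps only the last
    # position of each sequence, NULL captures nothing.
    if mode == "FULL":
        return hidden_states
    if mode == "LAST":
        return last_hidden
    return None

h = torch.zeros(10, 4096)  # [num_tokens, hidden_size]
print(pick_hidden_states("LAST", h, h[-1:]).shape)  # torch.Size([1, 4096])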
@@ -843,8 +843,8 @@ class ScheduleBatch:
         # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)

-    def check_decode_mem(self):
-        bs = len(self.reqs)
+    def check_decode_mem(self, buf_multiplier=1):
+        bs = len(self.reqs) * buf_multiplier
         if self.token_to_kv_pool.available_size() >= bs:
             return True
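buf_multiplier lets a caller reserve several KV-cache slots per request for a single decode step; ordinary decoding appends one token per request, while a speculative step must hold a whole draft of tokens at once. A toy illustration of the check (numbers invented):

def check_decode_mem(available_tokens: int, num_reqs: int, buf_multiplier: int = 1) -> bool:
    # Each running request needs buf_multiplier new KV-cache slots this step.
    return available_tokens >= num_reqs * buf_multiplier

print(check_decode_mem(available_tokens=100, num_reqs=16))                    # True
print(check_decode_mem(available_tokens=100, num_reqs=16, buf_multiplier=8))  # False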
@@ -90,7 +90,7 @@ from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

-# Test retract decode
+# Test retract decode for debugging purposes
 test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")

@@ -129,12 +129,12 @@ class Scheduler:
         )

         if server_args.skip_tokenizer_init:
-            # Directly send to the tokenizer/api
+            # Directly send to the TokenizerManager
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_ipc_name
             )
         else:
-            # Send to the detokenizer
+            # Send to the DetokenizerManager
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name
             )
@@ -385,7 +385,8 @@ class Scheduler:
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
-            if self.server_args.enable_dp_attention:
+
+            if self.server_args.enable_dp_attention:  # TODO: simplify this
                 batch = self.prepare_dp_attn_batch(batch)

             self.cur_batch = batch

@@ -394,7 +395,7 @@ class Scheduler:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
             else:
-                # Self-check and re-init some states when the server is idle
+                # When the server is idle, self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio

@@ -411,12 +412,13 @@ class Scheduler:
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch

             if batch:
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))

                 if self.last_batch is None:
-                    # A dummy first batch to start the pipeline for overlap scheduler.
+                    # Create a dummy first batch to start the pipeline for overlap scheduler.
+                    # It is now used for triggering the sampling_info_done event.
                     tmp_batch = ScheduleBatch(
                         reqs=None,
@@ -426,19 +428,21 @@ class Scheduler:
                     self.process_batch_result(tmp_batch, None)

             if self.last_batch:
+                # Process the results of the last batch
                 tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch.next_batch_sampling_info = (
+                    self.tp_worker.cur_sampling_info if batch else None
+                )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
-                # Self-check and re-init some states when the server is idle
+                # When the server is idle, self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio

             self.last_batch = batch

-    def recv_requests(self):
+    def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
@@ -812,6 +816,8 @@ class Scheduler:
             if res == AddReqResult.NO_TOKEN:
                 self.batch_is_full = True
                 break
+            if self.server_args.prefill_only_one_req:
+                break

         # Update waiting queue
         can_run_list = adder.can_run_list

@@ -1528,18 +1534,20 @@ def run_scheduler_process(
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])

+    # Configure the logger
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
+    suppress_other_loggers()

-    # set cpu affinity to this gpu process
+    # Set cpu affinity to this gpu process
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

-    suppress_other_loggers()
     parent_process = psutil.Process().parent()

     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send(
@@ -45,6 +45,7 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm


 class ForwardMode(IntEnum):

@@ -59,6 +60,11 @@ class ForwardMode(IntEnum):
     # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequences are allocated.
     IDLE = auto()

+    # Used in speculative decoding: verify a batch in the target model.
+    TARGET_VERIFY = auto()
+    # Used in speculative decoding: extend a batch in the draft model.
+    DRAFT_EXTEND = auto()
+
     # A dummy first batch to start the pipeline for overlap scheduler.
     # It is now used for triggering the sampling_info_done event for the first prefill batch.
     DUMMY_FIRST = auto()
@@ -67,7 +73,12 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.PREFILL

     def is_extend(self):
-        return self == ForwardMode.EXTEND or self == ForwardMode.MIXED
+        return (
+            self == ForwardMode.EXTEND
+            or self == ForwardMode.MIXED
+            or self == ForwardMode.DRAFT_EXTEND
+            or self == self.TARGET_VERIFY
+        )

     def is_decode(self):
         return self == ForwardMode.DECODE

@@ -78,6 +89,15 @@ class ForwardMode(IntEnum):
     def is_idle(self):
         return self == ForwardMode.IDLE

+    def is_target_verify(self):
+        return self == ForwardMode.TARGET_VERIFY
+
+    def is_draft_extend(self):
+        return self == ForwardMode.DRAFT_EXTEND
+
+    def is_cuda_graph(self):
+        return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)
+
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
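With the two new modes, is_extend answers "does this batch take the extend kernel path" (which TARGET_VERIFY and DRAFT_EXTEND both do), while is_cuda_graph marks the modes whose batch shapes stay fixed enough to replay from a captured CUDA graph. A condensed, runnable restatement of the predicates:

from enum import IntEnum, auto

class Mode(IntEnum):
    EXTEND = auto()
    MIXED = auto()
    DECODE = auto()
    TARGET_VERIFY = auto()
    DRAFT_EXTEND = auto()

    def is_extend(self) -> bool:
        # Target verification and draft extension reuse the extend kernels.
        return self in (Mode.EXTEND, Mode.MIXED, Mode.TARGET_VERIFY, Mode.DRAFT_EXTEND)

    def is_cuda_graph(self) -> bool:
        # Fixed-shape modes that can be captured into a CUDA graph.
        return self in (Mode.DECODE, Mode.TARGET_VERIFY)

print(Mode.TARGET_VERIFY.is_extend(), Mode.TARGET_VERIFY.is_cuda_graph())  # True True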
@@ -141,14 +161,18 @@ class ForwardBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     attn_backend: AttentionBackend = None

-    # For Qwen2-VL
-    mrope_positions: torch.Tensor = None
+    # Speculative decoding
+    spec_info: SpecInfo = None
+    spec_algorithm: SpeculativeAlgorithm = None

     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False

+    # For Qwen2-VL
+    mrope_positions: torch.Tensor = None
+
     def compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
@@ -351,3 +375,18 @@ def compute_position_torch(
     extend_start_loc = torch.zeros_like(extend_seq_lens)
     extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
     return positions.to(torch.int64), extend_start_loc
+
+
+class CaptureHiddenMode(IntEnum):
+    NULL = auto()
+    FULL = auto()
+    LAST = auto()
+
+    def need_capture(self):
+        return self != CaptureHiddenMode.NULL
+
+    def is_full(self):
+        return self == CaptureHiddenMode.FULL
+
+    def is_last(self):
+        return self == CaptureHiddenMode.LAST
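CaptureHiddenMode rides along on LogitsMetadata (and, per the earlier hunk, on forward_batch.spec_info), so the logits processor can decide locally whether to keep hidden states at all. Usage, given the enum defined above:

mode = CaptureHiddenMode.LAST
assert mode.need_capture() and mode.is_last() and not mode.is_full()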
@@ -516,6 +516,17 @@ class LlamaForCausalLM(nn.Module):
             )
         return None

+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+

 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
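get_embed_and_head / set_embed_and_head let a driver hand one model's embedding and LM-head parameters to another so both share a single copy on the GPU; EAGLE-style draft models reuse the target model's vocabulary projections this way. A self-contained sketch with a toy module (TinyLM is hypothetical, not sglang code):

import torch
from torch import nn

class TinyLM(nn.Module):
    def __init__(self, vocab: int = 8, hidden: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def get_embed_and_head(self):
        return self.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        # Drop our own copies, then adopt the shared parameters.
        del self.embed_tokens.weight
        del self.lm_head.weight
        self.embed_tokens.weight = embed
        self.lm_head.weight = head

target, draft = TinyLM(), TinyLM()
draft.set_embed_and_head(*target.get_embed_and_head())
assert draft.lm_head.weight is target.lm_head.weight  # one shared tensor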
@@ -503,7 +503,7 @@ def launch_engine(
         )
         scheduler_infos.append(data)

-    # Assume all schedulers have same max_total_num_tokens
+    # Assume all schedulers have same scheduler_info
     scheduler_info = scheduler_infos[0]

@@ -890,7 +890,7 @@ class Runtime:
     using the command line interface.

     It is mainly used for the frontend language.
-    You should use the Engine class if you want to do normal offline processing.
+    You should use the Engine class above if you want to do normal offline processing.
     """

     def __init__(
@@ -55,7 +55,7 @@ class ServerArgs:
     is_embedding: bool = False
     revision: Optional[str] = None

-    # Port
+    # Port for the HTTP server
     host: str = "127.0.0.1"
     port: int = 30000

@@ -68,6 +68,7 @@ class ServerArgs:
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
+    prefill_only_one_req: bool = False

     # Other runtime options
     tp_size: int = 1

@@ -94,6 +95,7 @@ class ServerArgs:
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"

+    # Expert parallelism
     ep_size: int = 1
@@ -217,6 +219,13 @@ class ServerArgs:
             )
             self.disable_cuda_graph = True

+        # Expert parallelism
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            logger.info(
+                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+            )
+
         # Others
         if self.enable_dp_attention:
             self.dp_size = self.tp_size

@@ -229,12 +238,6 @@ class ServerArgs:
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
                 "Overlap scheduler is disabled."
             )
-        # Expert parallelism
-        if self.enable_ep_moe:
-            self.ep_size = self.tp_size
-            logger.info(
-                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
-            )

         # GGUF
         if (
@@ -430,13 +433,18 @@ class ServerArgs:
         default=ServerArgs.schedule_conservativeness,
         help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
     )
-
     parser.add_argument(
         "--cpu-offload-gb",
         type=int,
         default=ServerArgs.cpu_offload_gb,
         help="How many GBs of RAM to reserve for CPU offloading",
     )
+    parser.add_argument(
+        "--prefill-only-one-req",
+        type=bool,
+        help="If true, we only prefill one request at one prefill batch",
+        default=ServerArgs.prefill_only_one_req,
+    )

     # Other runtime options
     parser.add_argument(

@@ -555,6 +563,7 @@ class ServerArgs:
             "shortest_queue",
         ],
     )

+    # Expert parallelism
     parser.add_argument(
         "--expert-parallel-size",
@@ -777,28 +786,6 @@ class ServerArgs:
         help="Delete the model checkpoint after loading the model.",
     )

-    # Deprecated arguments
-    parser.add_argument(
-        "--enable-overlap-schedule",
-        action=DeprecatedAction,
-        help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
-    )
-    parser.add_argument(
-        "--disable-flashinfer",
-        action=DeprecatedAction,
-        help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
-    )
-    parser.add_argument(
-        "--disable-flashinfer-sampling",
-        action=DeprecatedAction,
-        help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
-    )
-    parser.add_argument(
-        "--disable-disk-cache",
-        action=DeprecatedAction,
-        help="'--disable-disk-cache' is deprecated. Please use '--disable-outlines-disk-cache' instead.",
-    )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
python/sglang/srt/speculative/spec_info.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+from enum import IntEnum, auto
+
+
+class SpeculativeAlgorithm(IntEnum):
+    EAGLE = auto()
+
+    def is_eagle(self):
+        return self == SpeculativeAlgorithm.EAGLE
+
+    @staticmethod
+    def from_string(name: str):
+        name_map = {
+            "EAGLE": SpeculativeAlgorithm.EAGLE,
+        }
+        return name_map[name]
+
+
+class SpecInfo:
+    pass
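The new module is deliberately small: SpecInfo is an empty base class that mode-specific speculative metadata will subclass (the EAGLE implementation lands separately), and from_string maps a configuration string onto the enum:

algo = SpeculativeAlgorithm.from_string("EAGLE")
assert algo.is_eagle()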