Clean up wrapper in flashinfer backend (#2638)
This commit is contained in:
@@ -331,6 +331,7 @@ def throughput_test(
|
|||||||
extra_request_body=extra_request_body,
|
extra_request_body=extra_request_body,
|
||||||
profile=bench_args.profile,
|
profile=bench_args.profile,
|
||||||
)
|
)
|
||||||
|
backend.shutdown()
|
||||||
|
|
||||||
if bench_args.result_filename:
|
if bench_args.result_filename:
|
||||||
with open(bench_args.result_filename, "a") as fout:
|
with open(bench_args.result_filename, "a") as fout:
|
||||||
|
|||||||
@@ -131,10 +131,8 @@ class ModelConfig:
|
|||||||
# Veirfy quantization
|
# Veirfy quantization
|
||||||
self._verify_quantization()
|
self._verify_quantization()
|
||||||
|
|
||||||
# Text attrs
|
# Cache attributes
|
||||||
self.hf_eos_token_id = self.get_hf_eos_token_id()
|
self.hf_eos_token_id = self.get_hf_eos_token_id()
|
||||||
|
|
||||||
# Multimodel attrs
|
|
||||||
self.image_token_id = getattr(self.hf_config, "image_token_id", None)
|
self.image_token_id = getattr(self.hf_config, "image_token_id", None)
|
||||||
|
|
||||||
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|||||||
@@ -8,8 +8,9 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import TYPE_CHECKING, List
|
from typing import TYPE_CHECKING, List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
@@ -38,12 +39,25 @@ class WrapperDispatch(Enum):
|
|||||||
CROSS_ATTENTION = auto()
|
CROSS_ATTENTION = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DecodeMetadata:
|
||||||
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PrefillMetadata:
|
||||||
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
|
||||||
|
use_ragged: bool
|
||||||
|
extend_no_prefix: bool
|
||||||
|
|
||||||
|
|
||||||
class FlashInferAttnBackend(AttentionBackend):
|
class FlashInferAttnBackend(AttentionBackend):
|
||||||
"""Flashinfer attention kernels."""
|
"""Flashinfer attention kernels."""
|
||||||
|
|
||||||
def __init__(self, model_runner: ModelRunner):
|
def __init__(self, model_runner: ModelRunner):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
# Parse constants
|
||||||
self.decode_use_tensor_cores = should_use_tensor_core(
|
self.decode_use_tensor_cores = should_use_tensor_core(
|
||||||
kv_cache_dtype=model_runner.kv_cache_dtype,
|
kv_cache_dtype=model_runner.kv_cache_dtype,
|
||||||
num_attention_heads=model_runner.model_config.num_attention_heads
|
num_attention_heads=model_runner.model_config.num_attention_heads
|
||||||
@@ -52,7 +66,6 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
model_runner.tp_size
|
model_runner.tp_size
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
|
|
||||||
assert not (
|
assert not (
|
||||||
@@ -120,8 +133,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Other metadata
|
# Other metadata
|
||||||
self.forward_metadata = None
|
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
|
||||||
self.cuda_graph_metadata = {}
|
self.decode_cuda_graph_metadata = {}
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
if forward_batch.forward_mode.is_decode():
|
if forward_batch.forward_mode.is_decode():
|
||||||
@@ -129,10 +142,10 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
forward_batch.seq_lens_sum,
|
forward_batch.seq_lens_sum,
|
||||||
decode_wrappers=None,
|
decode_wrappers=self.decode_wrappers,
|
||||||
encoder_lens=forward_batch.encoder_lens,
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
)
|
)
|
||||||
self.forward_metadata = (self.decode_wrappers,)
|
self.forward_metadata = DecodeMetadata(self.decode_wrappers)
|
||||||
else:
|
else:
|
||||||
prefix_lens = forward_batch.extend_prefix_lens
|
prefix_lens = forward_batch.extend_prefix_lens
|
||||||
|
|
||||||
@@ -149,11 +162,13 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
forward_batch.seq_lens_sum,
|
forward_batch.seq_lens_sum,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
|
prefill_wrappers=self.prefill_wrappers_paged,
|
||||||
use_ragged=use_ragged,
|
use_ragged=use_ragged,
|
||||||
encoder_lens=forward_batch.encoder_lens,
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
)
|
)
|
||||||
|
self.forward_metadata = PrefillMetadata(
|
||||||
self.forward_metadata = (use_ragged, extend_no_prefix)
|
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
|
||||||
|
)
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int):
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
cuda_graph_kv_indices = torch.zeros(
|
cuda_graph_kv_indices = torch.zeros(
|
||||||
@@ -194,8 +209,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
decode_wrappers=decode_wrappers,
|
decode_wrappers=decode_wrappers,
|
||||||
encoder_lens=encoder_lens,
|
encoder_lens=encoder_lens,
|
||||||
)
|
)
|
||||||
self.cuda_graph_metadata[bs] = decode_wrappers
|
self.decode_cuda_graph_metadata[bs] = decode_wrappers
|
||||||
self.forward_metadata = (decode_wrappers,)
|
self.forward_metadata = DecodeMetadata(decode_wrappers)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self,
|
self,
|
||||||
@@ -209,7 +224,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
req_pool_indices[:bs],
|
req_pool_indices[:bs],
|
||||||
seq_lens[:bs],
|
seq_lens[:bs],
|
||||||
seq_lens_sum,
|
seq_lens_sum,
|
||||||
decode_wrappers=self.cuda_graph_metadata[bs],
|
decode_wrappers=self.decode_cuda_graph_metadata[bs],
|
||||||
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -225,18 +240,16 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
):
|
):
|
||||||
prefill_wrapper_paged = self.prefill_wrappers_paged[
|
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
|
||||||
self._get_wrapper_idx(layer)
|
self._get_wrapper_idx(layer)
|
||||||
]
|
]
|
||||||
|
|
||||||
use_ragged, extend_no_prefix = self.forward_metadata
|
|
||||||
cache_loc = (
|
cache_loc = (
|
||||||
forward_batch.out_cache_loc
|
forward_batch.out_cache_loc
|
||||||
if not layer.is_cross_attention
|
if not layer.is_cross_attention
|
||||||
else forward_batch.encoder_out_cache_loc
|
else forward_batch.encoder_out_cache_loc
|
||||||
)
|
)
|
||||||
|
|
||||||
if not use_ragged:
|
if not self.forward_metadata.use_ragged:
|
||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
@@ -260,7 +273,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
logits_soft_cap=layer.logit_cap,
|
logits_soft_cap=layer.logit_cap,
|
||||||
)
|
)
|
||||||
|
|
||||||
if extend_no_prefix:
|
if self.forward_metadata.extend_no_prefix:
|
||||||
o = o1
|
o = o1
|
||||||
else:
|
else:
|
||||||
o2, s2 = prefill_wrapper_paged.forward_return_lse(
|
o2, s2 = prefill_wrapper_paged.forward_return_lse(
|
||||||
@@ -287,7 +300,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
):
|
):
|
||||||
decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
|
decode_wrapper = self.forward_metadata.decode_wrappers[
|
||||||
|
self._get_wrapper_idx(layer)
|
||||||
|
]
|
||||||
cache_loc = (
|
cache_loc = (
|
||||||
forward_batch.out_cache_loc
|
forward_batch.out_cache_loc
|
||||||
if not layer.is_cross_attention
|
if not layer.is_cross_attention
|
||||||
@@ -322,7 +337,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
class FlashInferIndicesUpdaterDecode:
|
class FlashInferIndicesUpdaterDecode:
|
||||||
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
|
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
|
||||||
# Constants
|
# Parse Constants
|
||||||
self.num_qo_heads = (
|
self.num_qo_heads = (
|
||||||
model_runner.model_config.num_attention_heads // model_runner.tp_size
|
model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||||
)
|
)
|
||||||
@@ -340,9 +355,8 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
self.kv_indptr = attn_backend.kv_indptr
|
self.kv_indptr = attn_backend.kv_indptr
|
||||||
self.kv_last_page_len = attn_backend.kv_last_page_len
|
self.kv_last_page_len = attn_backend.kv_last_page_len
|
||||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
self.decode_wrappers = attn_backend.decode_wrappers
|
|
||||||
|
|
||||||
# Dispatch
|
# Dispatch the update function
|
||||||
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||||
self.update = self.update_sliding_window
|
self.update = self.update_sliding_window
|
||||||
elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
||||||
@@ -356,7 +370,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List,
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: torch.Tensor,
|
||||||
):
|
):
|
||||||
# Keep the signature for type checking. It will be assigned during runtime.
|
# Keep the signature for type checking. It will be assigned during runtime.
|
||||||
@@ -367,7 +381,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List,
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: torch.Tensor,
|
||||||
):
|
):
|
||||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||||
@@ -385,11 +399,9 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List,
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: torch.Tensor,
|
||||||
):
|
):
|
||||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
|
||||||
|
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
# Sliding window attention
|
# Sliding window attention
|
||||||
@@ -419,11 +431,9 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List,
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: torch.Tensor,
|
||||||
):
|
):
|
||||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
|
||||||
|
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
# Normal attention
|
# Normal attention
|
||||||
@@ -446,7 +456,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
|
|
||||||
def call_begin_forward(
|
def call_begin_forward(
|
||||||
self,
|
self,
|
||||||
wrapper,
|
wrapper: BatchDecodeWithPagedKVCacheWrapper,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
paged_kernel_lens: torch.Tensor,
|
paged_kernel_lens: torch.Tensor,
|
||||||
paged_kernel_lens_sum: int,
|
paged_kernel_lens_sum: int,
|
||||||
@@ -486,7 +496,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
|
|
||||||
class FlashInferIndicesUpdaterPrefill:
|
class FlashInferIndicesUpdaterPrefill:
|
||||||
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
|
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
|
||||||
# Constants
|
# Parse Constants
|
||||||
self.num_qo_heads = (
|
self.num_qo_heads = (
|
||||||
model_runner.model_config.num_attention_heads // model_runner.tp_size
|
model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||||
)
|
)
|
||||||
@@ -505,10 +515,9 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
self.kv_last_page_len = attn_backend.kv_last_page_len
|
self.kv_last_page_len = attn_backend.kv_last_page_len
|
||||||
self.qo_indptr = attn_backend.qo_indptr
|
self.qo_indptr = attn_backend.qo_indptr
|
||||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
self.wrapper_ragged = attn_backend.prefill_wrapper_ragged
|
self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
|
||||||
self.wrappers_paged = attn_backend.prefill_wrappers_paged
|
|
||||||
|
|
||||||
# Dispatch
|
# Dispatch the update function
|
||||||
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||||
self.update = self.update_sliding_window
|
self.update = self.update_sliding_window
|
||||||
elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
||||||
@@ -523,6 +532,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: torch.Tensor,
|
||||||
):
|
):
|
||||||
@@ -535,6 +545,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: torch.Tensor,
|
||||||
):
|
):
|
||||||
@@ -546,8 +557,8 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
paged_kernel_lens_sum = seq_lens_sum
|
paged_kernel_lens_sum = seq_lens_sum
|
||||||
|
|
||||||
self.call_begin_forward(
|
self.call_begin_forward(
|
||||||
self.wrapper_ragged,
|
self.prefill_wrapper_ragged,
|
||||||
self.wrappers_paged[0],
|
prefill_wrappers[0],
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
paged_kernel_lens,
|
paged_kernel_lens,
|
||||||
paged_kernel_lens_sum,
|
paged_kernel_lens_sum,
|
||||||
@@ -565,6 +576,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: torch.Tensor,
|
||||||
):
|
):
|
||||||
@@ -584,8 +596,8 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
kv_start_idx = seq_lens - paged_kernel_lens
|
kv_start_idx = seq_lens - paged_kernel_lens
|
||||||
|
|
||||||
self.call_begin_forward(
|
self.call_begin_forward(
|
||||||
self.wrapper_ragged,
|
self.prefill_wrapper_ragged,
|
||||||
self.wrappers_paged[wrapper_id],
|
prefill_wrappers[wrapper_id],
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
paged_kernel_lens,
|
paged_kernel_lens,
|
||||||
paged_kernel_lens_sum,
|
paged_kernel_lens_sum,
|
||||||
@@ -603,6 +615,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: torch.Tensor,
|
||||||
):
|
):
|
||||||
@@ -619,8 +632,8 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
|
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
|
||||||
|
|
||||||
self.call_begin_forward(
|
self.call_begin_forward(
|
||||||
self.wrapper_ragged,
|
self.prefill_wrapper_ragged,
|
||||||
self.wrappers_paged[wrapper_id],
|
prefill_wrappers[wrapper_id],
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
paged_kernel_lens,
|
paged_kernel_lens,
|
||||||
paged_kernel_lens_sum,
|
paged_kernel_lens_sum,
|
||||||
@@ -634,8 +647,8 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
|
|
||||||
def call_begin_forward(
|
def call_begin_forward(
|
||||||
self,
|
self,
|
||||||
wrapper_ragged,
|
wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
|
||||||
wrapper_paged,
|
wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
paged_kernel_lens: torch.Tensor,
|
paged_kernel_lens: torch.Tensor,
|
||||||
paged_kernel_lens_sum: int,
|
paged_kernel_lens_sum: int,
|
||||||
|
|||||||
@@ -24,7 +24,11 @@ from vllm.distributed import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
CaptureHiddenMode,
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@@ -46,6 +50,10 @@ class LogitsProcessorOutput:
|
|||||||
output_top_logprobs_val: List = None
|
output_top_logprobs_val: List = None
|
||||||
output_top_logprobs_idx: List = None
|
output_top_logprobs_idx: List = None
|
||||||
|
|
||||||
|
# Used by speculative decoding (EAGLE)
|
||||||
|
# The output of transformer layers
|
||||||
|
hidden_states: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class LogitsMetadata:
|
class LogitsMetadata:
|
||||||
@@ -61,6 +69,8 @@ class LogitsMetadata:
|
|||||||
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
||||||
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
|
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
|
||||||
|
|
||||||
|
capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
||||||
extend_logprob_pruned_lens_cpu = None
|
extend_logprob_pruned_lens_cpu = None
|
||||||
@@ -78,6 +88,11 @@ class LogitsMetadata:
|
|||||||
else:
|
else:
|
||||||
return_top_logprob = False
|
return_top_logprob = False
|
||||||
|
|
||||||
|
if forward_batch.spec_info:
|
||||||
|
capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
|
||||||
|
else:
|
||||||
|
capture_hidden_mode = CaptureHiddenMode.NULL
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
forward_mode=forward_batch.forward_mode,
|
forward_mode=forward_batch.forward_mode,
|
||||||
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
||||||
@@ -87,6 +102,7 @@ class LogitsMetadata:
|
|||||||
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
||||||
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
|
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
|
||||||
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
|
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
|
||||||
|
capture_hidden_mode=capture_hidden_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -116,7 +132,10 @@ class LogitsProcessor(nn.Module):
|
|||||||
assert isinstance(logits_metadata, LogitsMetadata)
|
assert isinstance(logits_metadata, LogitsMetadata)
|
||||||
|
|
||||||
# Get the last hidden states and last logits for the next token prediction
|
# Get the last hidden states and last logits for the next token prediction
|
||||||
if logits_metadata.forward_mode.is_decode():
|
if (
|
||||||
|
logits_metadata.forward_mode.is_decode()
|
||||||
|
or logits_metadata.forward_mode.is_target_verify()
|
||||||
|
):
|
||||||
last_index = None
|
last_index = None
|
||||||
last_hidden = hidden_states
|
last_hidden = hidden_states
|
||||||
else:
|
else:
|
||||||
@@ -137,6 +156,15 @@ class LogitsProcessor(nn.Module):
|
|||||||
if not logits_metadata.return_logprob:
|
if not logits_metadata.return_logprob:
|
||||||
return LogitsProcessorOutput(
|
return LogitsProcessorOutput(
|
||||||
next_token_logits=last_logits,
|
next_token_logits=last_logits,
|
||||||
|
hidden_states=(
|
||||||
|
hidden_states
|
||||||
|
if logits_metadata.capture_hidden_mode.is_full()
|
||||||
|
else (
|
||||||
|
last_hidden
|
||||||
|
if logits_metadata.capture_hidden_mode.is_last()
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
last_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
last_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
||||||
|
|||||||
@@ -843,8 +843,8 @@ class ScheduleBatch:
|
|||||||
# TODO (lianmin): Revisit this. It should be seq_len - 1
|
# TODO (lianmin): Revisit this. It should be seq_len - 1
|
||||||
self.extend_logprob_start_lens.extend([0] * running_bs)
|
self.extend_logprob_start_lens.extend([0] * running_bs)
|
||||||
|
|
||||||
def check_decode_mem(self):
|
def check_decode_mem(self, buf_multiplier=1):
|
||||||
bs = len(self.reqs)
|
bs = len(self.reqs) * buf_multiplier
|
||||||
if self.token_to_kv_pool.available_size() >= bs:
|
if self.token_to_kv_pool.available_size() >= bs:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ from sglang.utils import get_exception_traceback
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Test retract decode
|
# Test retract decode for debugging purposes
|
||||||
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
|
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
|
||||||
|
|
||||||
|
|
||||||
@@ -129,12 +129,12 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
# Directly send to the tokenizer/api
|
# Directly send to the TokenizerManager
|
||||||
self.send_to_detokenizer = get_zmq_socket(
|
self.send_to_detokenizer = get_zmq_socket(
|
||||||
context, zmq.PUSH, port_args.tokenizer_ipc_name
|
context, zmq.PUSH, port_args.tokenizer_ipc_name
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Send to the detokenizer
|
# Send to the DetokenizerManager
|
||||||
self.send_to_detokenizer = get_zmq_socket(
|
self.send_to_detokenizer = get_zmq_socket(
|
||||||
context, zmq.PUSH, port_args.detokenizer_ipc_name
|
context, zmq.PUSH, port_args.detokenizer_ipc_name
|
||||||
)
|
)
|
||||||
@@ -385,7 +385,8 @@ class Scheduler:
|
|||||||
self.process_input_requests(recv_reqs)
|
self.process_input_requests(recv_reqs)
|
||||||
|
|
||||||
batch = self.get_next_batch_to_run()
|
batch = self.get_next_batch_to_run()
|
||||||
if self.server_args.enable_dp_attention:
|
|
||||||
|
if self.server_args.enable_dp_attention: # TODO: simplify this
|
||||||
batch = self.prepare_dp_attn_batch(batch)
|
batch = self.prepare_dp_attn_batch(batch)
|
||||||
|
|
||||||
self.cur_batch = batch
|
self.cur_batch = batch
|
||||||
@@ -394,7 +395,7 @@ class Scheduler:
|
|||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
self.process_batch_result(batch, result)
|
self.process_batch_result(batch, result)
|
||||||
else:
|
else:
|
||||||
# Self-check and re-init some states when the server is idle
|
# When the server is idle, so self-check and re-init some states
|
||||||
self.check_memory()
|
self.check_memory()
|
||||||
self.new_token_ratio = self.init_new_token_ratio
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
@@ -411,12 +412,13 @@ class Scheduler:
|
|||||||
|
|
||||||
batch = self.get_next_batch_to_run()
|
batch = self.get_next_batch_to_run()
|
||||||
self.cur_batch = batch
|
self.cur_batch = batch
|
||||||
|
|
||||||
if batch:
|
if batch:
|
||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
result_queue.append((batch.copy(), result))
|
result_queue.append((batch.copy(), result))
|
||||||
|
|
||||||
if self.last_batch is None:
|
if self.last_batch is None:
|
||||||
# A dummy first batch to start the pipeline for overlap scheduler.
|
# Create a dummy first batch to start the pipeline for overlap scheduler.
|
||||||
# It is now used for triggering the sampling_info_done event.
|
# It is now used for triggering the sampling_info_done event.
|
||||||
tmp_batch = ScheduleBatch(
|
tmp_batch = ScheduleBatch(
|
||||||
reqs=None,
|
reqs=None,
|
||||||
@@ -426,19 +428,21 @@ class Scheduler:
|
|||||||
self.process_batch_result(tmp_batch, None)
|
self.process_batch_result(tmp_batch, None)
|
||||||
|
|
||||||
if self.last_batch:
|
if self.last_batch:
|
||||||
|
# Process the results of the last batch
|
||||||
tmp_batch, tmp_result = result_queue.popleft()
|
tmp_batch, tmp_result = result_queue.popleft()
|
||||||
tmp_batch.next_batch_sampling_info = (
|
tmp_batch.next_batch_sampling_info = (
|
||||||
self.tp_worker.cur_sampling_info if batch else None
|
self.tp_worker.cur_sampling_info if batch else None
|
||||||
)
|
)
|
||||||
self.process_batch_result(tmp_batch, tmp_result)
|
self.process_batch_result(tmp_batch, tmp_result)
|
||||||
elif batch is None:
|
elif batch is None:
|
||||||
# Self-check and re-init some states when the server is idle
|
# When the server is idle, so self-check and re-init some states
|
||||||
self.check_memory()
|
self.check_memory()
|
||||||
self.new_token_ratio = self.init_new_token_ratio
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
def recv_requests(self):
|
def recv_requests(self) -> List[Req]:
|
||||||
|
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
||||||
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
|
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
|
||||||
recv_reqs = []
|
recv_reqs = []
|
||||||
|
|
||||||
@@ -812,6 +816,8 @@ class Scheduler:
|
|||||||
if res == AddReqResult.NO_TOKEN:
|
if res == AddReqResult.NO_TOKEN:
|
||||||
self.batch_is_full = True
|
self.batch_is_full = True
|
||||||
break
|
break
|
||||||
|
if self.server_args.prefill_only_one_req:
|
||||||
|
break
|
||||||
|
|
||||||
# Update waiting queue
|
# Update waiting queue
|
||||||
can_run_list = adder.can_run_list
|
can_run_list = adder.can_run_list
|
||||||
@@ -1528,18 +1534,20 @@ def run_scheduler_process(
|
|||||||
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
|
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
|
||||||
dp_rank = int(os.environ["SGLANG_DP_RANK"])
|
dp_rank = int(os.environ["SGLANG_DP_RANK"])
|
||||||
|
|
||||||
|
# Configue the logger
|
||||||
if dp_rank is None:
|
if dp_rank is None:
|
||||||
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
||||||
else:
|
else:
|
||||||
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
|
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
|
||||||
|
suppress_other_loggers()
|
||||||
|
|
||||||
# set cpu affinity to this gpu process
|
# Set cpu affinity to this gpu process
|
||||||
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
|
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
|
||||||
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
|
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
|
||||||
|
|
||||||
suppress_other_loggers()
|
|
||||||
parent_process = psutil.Process().parent()
|
parent_process = psutil.Process().parent()
|
||||||
|
|
||||||
|
# Create a scheduler and run the event loop
|
||||||
try:
|
try:
|
||||||
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
||||||
pipe_writer.send(
|
pipe_writer.send(
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
|
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
|
||||||
|
|
||||||
|
|
||||||
class ForwardMode(IntEnum):
|
class ForwardMode(IntEnum):
|
||||||
@@ -59,6 +60,11 @@ class ForwardMode(IntEnum):
|
|||||||
# No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated.
|
# No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated.
|
||||||
IDLE = auto()
|
IDLE = auto()
|
||||||
|
|
||||||
|
# Used in speculative decoding: verify a batch in the target model.
|
||||||
|
TARGET_VERIFY = auto()
|
||||||
|
# Used in speculative decoding: extend a batch in the draft model.
|
||||||
|
DRAFT_EXTEND = auto()
|
||||||
|
|
||||||
# A dummy first batch to start the pipeline for overlap scheduler.
|
# A dummy first batch to start the pipeline for overlap scheduler.
|
||||||
# It is now used for triggering the sampling_info_done event for the first prefill batch.
|
# It is now used for triggering the sampling_info_done event for the first prefill batch.
|
||||||
DUMMY_FIRST = auto()
|
DUMMY_FIRST = auto()
|
||||||
@@ -67,7 +73,12 @@ class ForwardMode(IntEnum):
|
|||||||
return self == ForwardMode.PREFILL
|
return self == ForwardMode.PREFILL
|
||||||
|
|
||||||
def is_extend(self):
|
def is_extend(self):
|
||||||
return self == ForwardMode.EXTEND or self == ForwardMode.MIXED
|
return (
|
||||||
|
self == ForwardMode.EXTEND
|
||||||
|
or self == ForwardMode.MIXED
|
||||||
|
or self == ForwardMode.DRAFT_EXTEND
|
||||||
|
or self == self.TARGET_VERIFY
|
||||||
|
)
|
||||||
|
|
||||||
def is_decode(self):
|
def is_decode(self):
|
||||||
return self == ForwardMode.DECODE
|
return self == ForwardMode.DECODE
|
||||||
@@ -78,6 +89,15 @@ class ForwardMode(IntEnum):
|
|||||||
def is_idle(self):
|
def is_idle(self):
|
||||||
return self == ForwardMode.IDLE
|
return self == ForwardMode.IDLE
|
||||||
|
|
||||||
|
def is_target_verify(self):
|
||||||
|
return self == ForwardMode.TARGET_VERIFY
|
||||||
|
|
||||||
|
def is_draft_extend(self):
|
||||||
|
return self == ForwardMode.DRAFT_EXTEND
|
||||||
|
|
||||||
|
def is_cuda_graph(self):
|
||||||
|
return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)
|
||||||
|
|
||||||
def is_dummy_first(self):
|
def is_dummy_first(self):
|
||||||
return self == ForwardMode.DUMMY_FIRST
|
return self == ForwardMode.DUMMY_FIRST
|
||||||
|
|
||||||
@@ -141,14 +161,18 @@ class ForwardBatch:
|
|||||||
token_to_kv_pool: BaseTokenToKVPool = None
|
token_to_kv_pool: BaseTokenToKVPool = None
|
||||||
attn_backend: AttentionBackend = None
|
attn_backend: AttentionBackend = None
|
||||||
|
|
||||||
# For Qwen2-VL
|
# Speculative decoding
|
||||||
mrope_positions: torch.Tensor = None
|
spec_info: SpecInfo = None
|
||||||
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
|
|
||||||
# For DP attention
|
# For DP attention
|
||||||
global_num_tokens: Optional[List[int]] = None
|
global_num_tokens: Optional[List[int]] = None
|
||||||
gathered_buffer: Optional[torch.Tensor] = None
|
gathered_buffer: Optional[torch.Tensor] = None
|
||||||
can_run_dp_cuda_graph: bool = False
|
can_run_dp_cuda_graph: bool = False
|
||||||
|
|
||||||
|
# For Qwen2-VL
|
||||||
|
mrope_positions: torch.Tensor = None
|
||||||
|
|
||||||
def compute_mrope_positions(
|
def compute_mrope_positions(
|
||||||
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
||||||
):
|
):
|
||||||
@@ -351,3 +375,18 @@ def compute_position_torch(
|
|||||||
extend_start_loc = torch.zeros_like(extend_seq_lens)
|
extend_start_loc = torch.zeros_like(extend_seq_lens)
|
||||||
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
|
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
|
||||||
return positions.to(torch.int64), extend_start_loc
|
return positions.to(torch.int64), extend_start_loc
|
||||||
|
|
||||||
|
|
||||||
|
class CaptureHiddenMode(IntEnum):
|
||||||
|
NULL = auto()
|
||||||
|
FULL = auto()
|
||||||
|
LAST = auto()
|
||||||
|
|
||||||
|
def need_capture(self):
|
||||||
|
return self != CaptureHiddenMode.NULL
|
||||||
|
|
||||||
|
def is_full(self):
|
||||||
|
return self == CaptureHiddenMode.FULL
|
||||||
|
|
||||||
|
def is_last(self):
|
||||||
|
return self == CaptureHiddenMode.LAST
|
||||||
|
|||||||
@@ -516,6 +516,17 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_embed_and_head(self):
|
||||||
|
return self.model.embed_tokens.weight, self.lm_head.weight
|
||||||
|
|
||||||
|
def set_embed_and_head(self, embed, head):
|
||||||
|
del self.model.embed_tokens.weight
|
||||||
|
del self.lm_head.weight
|
||||||
|
self.model.embed_tokens.weight = embed
|
||||||
|
self.lm_head.weight = head
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
class Phi3ForCausalLM(LlamaForCausalLM):
|
class Phi3ForCausalLM(LlamaForCausalLM):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -503,7 +503,7 @@ def launch_engine(
|
|||||||
)
|
)
|
||||||
scheduler_infos.append(data)
|
scheduler_infos.append(data)
|
||||||
|
|
||||||
# Assume all schedulers have same max_total_num_tokens
|
# Assume all schedulers have same scheduler_info
|
||||||
scheduler_info = scheduler_infos[0]
|
scheduler_info = scheduler_infos[0]
|
||||||
|
|
||||||
|
|
||||||
@@ -890,7 +890,7 @@ class Runtime:
|
|||||||
using the commond line interface.
|
using the commond line interface.
|
||||||
|
|
||||||
It is mainly used for the frontend language.
|
It is mainly used for the frontend language.
|
||||||
You should use the Engine class if you want to do normal offline processing.
|
You should use the Engine class above if you want to do normal offline processing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class ServerArgs:
|
|||||||
is_embedding: bool = False
|
is_embedding: bool = False
|
||||||
revision: Optional[str] = None
|
revision: Optional[str] = None
|
||||||
|
|
||||||
# Port
|
# Port for the HTTP server
|
||||||
host: str = "127.0.0.1"
|
host: str = "127.0.0.1"
|
||||||
port: int = 30000
|
port: int = 30000
|
||||||
|
|
||||||
@@ -68,6 +68,7 @@ class ServerArgs:
|
|||||||
schedule_policy: str = "lpm"
|
schedule_policy: str = "lpm"
|
||||||
schedule_conservativeness: float = 1.0
|
schedule_conservativeness: float = 1.0
|
||||||
cpu_offload_gb: int = 0
|
cpu_offload_gb: int = 0
|
||||||
|
prefill_only_one_req: bool = False
|
||||||
|
|
||||||
# Other runtime options
|
# Other runtime options
|
||||||
tp_size: int = 1
|
tp_size: int = 1
|
||||||
@@ -94,6 +95,7 @@ class ServerArgs:
|
|||||||
# Data parallelism
|
# Data parallelism
|
||||||
dp_size: int = 1
|
dp_size: int = 1
|
||||||
load_balance_method: str = "round_robin"
|
load_balance_method: str = "round_robin"
|
||||||
|
|
||||||
# Expert parallelism
|
# Expert parallelism
|
||||||
ep_size: int = 1
|
ep_size: int = 1
|
||||||
|
|
||||||
@@ -217,6 +219,13 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
self.disable_cuda_graph = True
|
self.disable_cuda_graph = True
|
||||||
|
|
||||||
|
# Expert parallelism
|
||||||
|
if self.enable_ep_moe:
|
||||||
|
self.ep_size = self.tp_size
|
||||||
|
logger.info(
|
||||||
|
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
|
||||||
|
)
|
||||||
|
|
||||||
# Others
|
# Others
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
self.dp_size = self.tp_size
|
self.dp_size = self.tp_size
|
||||||
@@ -229,12 +238,6 @@ class ServerArgs:
|
|||||||
"Data parallel size is adjusted to be the same as tensor parallel size. "
|
"Data parallel size is adjusted to be the same as tensor parallel size. "
|
||||||
"Overlap scheduler is disabled."
|
"Overlap scheduler is disabled."
|
||||||
)
|
)
|
||||||
# Expert parallelism
|
|
||||||
if self.enable_ep_moe:
|
|
||||||
self.ep_size = self.tp_size
|
|
||||||
logger.info(
|
|
||||||
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
|
|
||||||
)
|
|
||||||
|
|
||||||
# GGUF
|
# GGUF
|
||||||
if (
|
if (
|
||||||
@@ -430,13 +433,18 @@ class ServerArgs:
|
|||||||
default=ServerArgs.schedule_conservativeness,
|
default=ServerArgs.schedule_conservativeness,
|
||||||
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
|
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cpu-offload-gb",
|
"--cpu-offload-gb",
|
||||||
type=int,
|
type=int,
|
||||||
default=ServerArgs.cpu_offload_gb,
|
default=ServerArgs.cpu_offload_gb,
|
||||||
help="How many GBs of RAM to reserve for CPU offloading",
|
help="How many GBs of RAM to reserve for CPU offloading",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefill-only-one-req",
|
||||||
|
type=bool,
|
||||||
|
help="If true, we only prefill one request at one prefill batch",
|
||||||
|
default=ServerArgs.prefill_only_one_req,
|
||||||
|
)
|
||||||
|
|
||||||
# Other runtime options
|
# Other runtime options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -555,6 +563,7 @@ class ServerArgs:
|
|||||||
"shortest_queue",
|
"shortest_queue",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Expert parallelism
|
# Expert parallelism
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--expert-parallel-size",
|
"--expert-parallel-size",
|
||||||
@@ -777,28 +786,6 @@ class ServerArgs:
|
|||||||
help="Delete the model checkpoint after loading the model.",
|
help="Delete the model checkpoint after loading the model.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Deprecated arguments
|
|
||||||
parser.add_argument(
|
|
||||||
"--enable-overlap-schedule",
|
|
||||||
action=DeprecatedAction,
|
|
||||||
help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--disable-flashinfer",
|
|
||||||
action=DeprecatedAction,
|
|
||||||
help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--disable-flashinfer-sampling",
|
|
||||||
action=DeprecatedAction,
|
|
||||||
help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--disable-disk-cache",
|
|
||||||
action=DeprecatedAction,
|
|
||||||
help="'--disable-disk-cache' is deprecated. Please use '--disable-outlines-disk-cache' instead.",
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
args.tp_size = args.tensor_parallel_size
|
args.tp_size = args.tensor_parallel_size
|
||||||
|
|||||||
19
python/sglang/srt/speculative/spec_info.py
Normal file
19
python/sglang/srt/speculative/spec_info.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from enum import IntEnum, auto
|
||||||
|
|
||||||
|
|
||||||
|
class SpeculativeAlgorithm(IntEnum):
|
||||||
|
EAGLE = auto()
|
||||||
|
|
||||||
|
def is_eagle(self):
|
||||||
|
return self == SpeculativeAlgorithm.EAGLE
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_string(name: str):
|
||||||
|
name_map = {
|
||||||
|
"EAGLE": SpeculativeAlgorithm.EAGLE,
|
||||||
|
}
|
||||||
|
return name_map[name]
|
||||||
|
|
||||||
|
|
||||||
|
class SpecInfo:
|
||||||
|
pass
|
||||||
Reference in New Issue
Block a user