Fix grammar abort & Minor style fixes (#7204)
This commit is contained in:
@@ -15,7 +15,6 @@ from functools import partial
|
|||||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
|
||||||
|
|
||||||
if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
|
if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
|
||||||
import logging
|
import logging
|
||||||
@@ -33,7 +32,7 @@ from sglang.srt.layers.utils import is_sm100_supported
|
|||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.utils import is_flashinfer_available
|
from sglang.srt.utils import is_flashinfer_available, next_power_of_2
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
@@ -756,7 +755,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
|||||||
|
|
||||||
if topk > 1:
|
if topk > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Currently Flashinfer MLA only supports topk=1 for speculative decoding"
|
"Currently Flashinfer MLA only supports topk=1 for speculative decoding"
|
||||||
)
|
)
|
||||||
self.topk = topk
|
self.topk = topk
|
||||||
self.speculative_num_steps = speculative_num_steps
|
self.speculative_num_steps = speculative_num_steps
|
||||||
@@ -815,9 +814,9 @@ class FlashInferMLAMultiStepDraftBackend:
|
|||||||
self.pool_len,
|
self.pool_len,
|
||||||
kv_indices_buffer.shape[1],
|
kv_indices_buffer.shape[1],
|
||||||
self.kv_indptr.shape[1],
|
self.kv_indptr.shape[1],
|
||||||
triton.next_power_of_2(num_seqs),
|
next_power_of_2(num_seqs),
|
||||||
triton.next_power_of_2(self.speculative_num_steps),
|
next_power_of_2(self.speculative_num_steps),
|
||||||
triton.next_power_of_2(bs),
|
next_power_of_2(bs),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert forward_batch.spec_info is not None
|
assert forward_batch.spec_info is not None
|
||||||
|
|||||||
@@ -464,11 +464,9 @@ class FlashMLAMultiStepDraftBackend:
|
|||||||
topk: int,
|
topk: int,
|
||||||
speculative_num_steps: int,
|
speculative_num_steps: int,
|
||||||
):
|
):
|
||||||
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
|
|
||||||
|
|
||||||
if topk > 1:
|
if topk > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Currently FlashMLA only supports topk=1 for speculative decoding"
|
"Currently FlashMLA only supports topk=1 for speculative decoding"
|
||||||
)
|
)
|
||||||
self.topk = topk
|
self.topk = topk
|
||||||
self.speculative_num_steps = speculative_num_steps
|
self.speculative_num_steps = speculative_num_steps
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
|
|||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.layers.radix_attention import AttentionType
|
from sglang.srt.layers.radix_attention import AttentionType
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.utils import get_bool_env_var, get_device_core_count
|
from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
@@ -766,6 +766,7 @@ class TritonMultiStepDraftBackend:
|
|||||||
self.device = model_runner.device
|
self.device = model_runner.device
|
||||||
# Cached variables for generate_draft_decode_kv_indices
|
# Cached variables for generate_draft_decode_kv_indices
|
||||||
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
||||||
|
self.page_size = model_runner.server_args.page_size
|
||||||
|
|
||||||
def common_template(
|
def common_template(
|
||||||
self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
|
self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
|
||||||
@@ -788,9 +789,9 @@ class TritonMultiStepDraftBackend:
|
|||||||
self.pool_len,
|
self.pool_len,
|
||||||
kv_indices_buffer.shape[1],
|
kv_indices_buffer.shape[1],
|
||||||
self.kv_indptr.shape[1],
|
self.kv_indptr.shape[1],
|
||||||
triton.next_power_of_2(num_seqs),
|
next_power_of_2(num_seqs),
|
||||||
triton.next_power_of_2(self.speculative_num_steps),
|
next_power_of_2(self.speculative_num_steps),
|
||||||
triton.next_power_of_2(bs),
|
next_power_of_2(bs),
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in range(self.speculative_num_steps):
|
for i in range(self.speculative_num_steps):
|
||||||
|
|||||||
@@ -708,7 +708,7 @@ def decode_attention_fwd(
|
|||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap=logit_cap,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# GQA/MQA/MLA
|
# GQA/MQA/MLA
|
||||||
@@ -724,5 +724,5 @@ def decode_attention_fwd(
|
|||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap=logit_cap,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from typing import Optional
|
|||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from sglang.srt.layers.linear import UnquantizedLinearMethod
|
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
@@ -52,9 +51,9 @@ class RadixAttention(nn.Module):
|
|||||||
sliding_window_size: int = -1,
|
sliding_window_size: int = -1,
|
||||||
is_cross_attention: bool = False,
|
is_cross_attention: bool = False,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
attn_type=AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
prefix: str = "",
|
|
||||||
use_irope: bool = False,
|
use_irope: bool = False,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_q_head_num = num_heads
|
self.tp_q_head_num = num_heads
|
||||||
|
|||||||
@@ -2108,7 +2108,8 @@ class Scheduler(
|
|||||||
# In this case, we change the input_ids to be only one token to make this prefill cheap.
|
# In this case, we change the input_ids to be only one token to make this prefill cheap.
|
||||||
if req.rid.startswith(recv_req.rid):
|
if req.rid.startswith(recv_req.rid):
|
||||||
logger.debug(f"Abort grammar queue request. {req.rid=}")
|
logger.debug(f"Abort grammar queue request. {req.rid=}")
|
||||||
req.grammar.cancel()
|
if req.grammar:
|
||||||
|
req.grammar.cancel()
|
||||||
req.set_finish_with_abort("Aborted by AbortReq.")
|
req.set_finish_with_abort("Aborted by AbortReq.")
|
||||||
|
|
||||||
# Delete requests in the running batch
|
# Delete requests in the running batch
|
||||||
|
|||||||
@@ -141,15 +141,12 @@ class KVCache(abc.ABC):
|
|||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def get_flat_data(self, indices):
|
def get_flat_data(self, indices):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def transfer(self, indices, flat_data):
|
def transfer(self, indices, flat_data):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
def transfer_per_layer(self, indices, flat_data, layer_id):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|||||||
@@ -86,8 +86,8 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
|
|
||||||
self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
|
self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
|
||||||
self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
|
self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
|
||||||
self.accept_length = (
|
self.accept_length = torch.full(
|
||||||
torch.ones((self.max_bs,), dtype=torch.int32) * self.num_tokens_per_bs
|
(self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
|
||||||
)
|
)
|
||||||
|
|
||||||
# Capture
|
# Capture
|
||||||
|
|||||||
Reference in New Issue
Block a user