Fix grammar abort & Minor style fixes (#7204)
This commit is contained in:
@@ -15,7 +15,6 @@ from functools import partial
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
|
||||
import logging
|
||||
@@ -33,7 +32,7 @@ from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
from sglang.srt.utils import is_flashinfer_available, next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
@@ -756,7 +755,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
|
||||
if topk > 1:
|
||||
raise ValueError(
|
||||
f"Currently Flashinfer MLA only supports topk=1 for speculative decoding"
|
||||
"Currently Flashinfer MLA only supports topk=1 for speculative decoding"
|
||||
)
|
||||
self.topk = topk
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
@@ -815,9 +814,9 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
self.pool_len,
|
||||
kv_indices_buffer.shape[1],
|
||||
self.kv_indptr.shape[1],
|
||||
triton.next_power_of_2(num_seqs),
|
||||
triton.next_power_of_2(self.speculative_num_steps),
|
||||
triton.next_power_of_2(bs),
|
||||
next_power_of_2(num_seqs),
|
||||
next_power_of_2(self.speculative_num_steps),
|
||||
next_power_of_2(bs),
|
||||
)
|
||||
|
||||
assert forward_batch.spec_info is not None
|
||||
|
||||
@@ -464,11 +464,9 @@ class FlashMLAMultiStepDraftBackend:
|
||||
topk: int,
|
||||
speculative_num_steps: int,
|
||||
):
|
||||
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
|
||||
|
||||
if topk > 1:
|
||||
raise ValueError(
|
||||
f"Currently FlashMLA only supports topk=1 for speculative decoding"
|
||||
"Currently FlashMLA only supports topk=1 for speculative decoding"
|
||||
)
|
||||
self.topk = topk
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
|
||||
@@ -12,7 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.layers.radix_attention import AttentionType
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.utils import get_bool_env_var, get_device_core_count
|
||||
from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
@@ -766,6 +766,7 @@ class TritonMultiStepDraftBackend:
|
||||
self.device = model_runner.device
|
||||
# Cached variables for generate_draft_decode_kv_indices
|
||||
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
||||
self.page_size = model_runner.server_args.page_size
|
||||
|
||||
def common_template(
|
||||
self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
|
||||
@@ -788,9 +789,9 @@ class TritonMultiStepDraftBackend:
|
||||
self.pool_len,
|
||||
kv_indices_buffer.shape[1],
|
||||
self.kv_indptr.shape[1],
|
||||
triton.next_power_of_2(num_seqs),
|
||||
triton.next_power_of_2(self.speculative_num_steps),
|
||||
triton.next_power_of_2(bs),
|
||||
next_power_of_2(num_seqs),
|
||||
next_power_of_2(self.speculative_num_steps),
|
||||
next_power_of_2(bs),
|
||||
)
|
||||
|
||||
for i in range(self.speculative_num_steps):
|
||||
|
||||
@@ -708,7 +708,7 @@ def decode_attention_fwd(
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
logit_cap=logit_cap,
|
||||
)
|
||||
else:
|
||||
# GQA/MQA/MLA
|
||||
@@ -724,5 +724,5 @@ def decode_attention_fwd(
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
logit_cap=logit_cap,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user