[fix] fix illegal mem access and clean up triton attention backend (#4571)
@@ -26,7 +26,6 @@ import tqdm
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
-from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -196,9 +195,6 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // get_attention_tp_size()
-        )
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -507,15 +503,9 @@ class CudaGraphRunner:
         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
 
-        num_kv_heads = self.num_head
-        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
-            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
-                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
-
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
-            num_kv_heads,
             self.req_pool_indices,
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
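
What the hunks show: CudaGraphRunner no longer computes num_kv_heads (previously read from token_to_kv_pool.k_buffer[0].shape[1], falling back to num_attention_heads // get_attention_tp_size()) and no longer passes it to init_forward_metadata_replay_cuda_graph. Per the commit title, the triton attention backend presumably recovers the KV-head count on its own during replay-metadata setup. Below is a minimal sketch of that derivation, mirroring the deleted logic; the helper name _num_kv_heads and the fallback parameter are illustrative assumptions, not sglang's actual API.

    # Hypothetical sketch (not the actual sglang code): derive the KV-head
    # count inside the attention backend instead of receiving it as an
    # argument from CudaGraphRunner.
    def _num_kv_heads(token_to_kv_pool, fallback: int) -> int:
        # MHA/GQA pools keep one K buffer per layer; each buffer is assumed
        # to be laid out as [num_tokens, num_kv_heads, head_dim], as in the
        # code the diff deletes, so dim 1 is the KV-head count.
        k_buffer = getattr(token_to_kv_pool, "k_buffer", None)
        if isinstance(k_buffer, list) and k_buffer:
            return k_buffer[0].shape[1]
        # Otherwise fall back to the model-level count, e.g. attention heads
        # divided by the attention tensor-parallel size.
        return fallback

Deriving the head count once, where the KV pool actually lives, appears to be the cleanup the commit title describes: the replay-metadata call keeps one signature across attention backends instead of threading a backend-specific value through CudaGraphRunner.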