Remove unused vars in the triton backend (#2401)
@@ -35,11 +35,6 @@ class TritonAttnBackend(AttentionBackend):
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
 
-        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
-            self.reduce_dtype = torch.float32
-        else:
-            self.reduce_dtype = torch.float16
-
         self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
 
@@ -53,9 +48,6 @@ class TritonAttnBackend(AttentionBackend):
         """Init auxiliary variables for triton attention backend."""
 
         if forward_batch.forward_mode.is_decode():
-            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
-            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
-
             attn_logits = torch.empty(
                 (
                     forward_batch.batch_size,
@@ -67,13 +59,12 @@ class TritonAttnBackend(AttentionBackend):
                 device=self.device,
             )
 
-            max_seq_len = torch.max(forward_batch.seq_lens).item()
             max_extend_len = None
         else:
-            start_loc = attn_logits = max_seq_len = None
+            attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
-        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+        self.forward_metadata = attn_logits, max_extend_len
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
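Note: with `start_loc` and `max_seq_len` dropped, `forward_metadata` is now a two-element tuple. A minimal, self-contained sketch of the new contract (the tensor sizes below are illustrative assumptions; the diff itself only shows the batch dimension of `attn_logits`):

```python
import torch

# Illustrative sizes only; the real values come from the model config and the
# num_kv_splits / v_head_dim fields initialized in __init__ above.
batch_size, num_head, num_kv_splits, v_head_dim = 2, 8, 8, 64

# Decode path: pre-allocate the split-KV logits buffer. start_loc and
# max_seq_len are no longer stored.
attn_logits = torch.empty(
    (batch_size, num_head, num_kv_splits, v_head_dim + 1), dtype=torch.float32
)
forward_metadata = attn_logits, None  # (attn_logits, max_extend_len)

attn_logits, _ = forward_metadata     # what forward_decode unpacks
_, max_extend_len = forward_metadata  # what forward_extend unpacks
```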
@@ -96,9 +87,7 @@ class TritonAttnBackend(AttentionBackend):
     ):
         # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
-            self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
             None,
         )
 
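For context, the deleted `start_loc` (and its CUDA-graph counterpart above) was only an exclusive prefix sum of `seq_lens`, and nothing in the decode kernels reads it anymore. A runnable illustration of what the removed lines computed:

```python
import torch

# Per-request start offsets as an exclusive cumulative sum of sequence lengths,
# exactly as the removed lines in init_forward_metadata computed them.
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # example batch
start_loc = torch.zeros_like(seq_lens)
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
print(start_loc.tolist())  # [0, 3, 8]
```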
@@ -137,7 +126,7 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        _, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -175,7 +164,7 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        attn_logits, _ = self.forward_metadata
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -189,10 +178,8 @@ class TritonAttnBackend(AttentionBackend):
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.req_pool_indices,
-            start_loc,
             forward_batch.seq_lens,
             attn_logits,
-            max_seq_len,
             self.num_kv_splits,
             layer.scaling,
             layer.logit_cap,
@@ -19,6 +19,9 @@ It supports page size = 1.
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
+
+import logging
+
 import triton
 import triton.language as tl
 
@@ -26,6 +29,13 @@ from sglang.srt.utils import is_hip
 
 is_hip_ = is_hip()
 
+logger = logging.getLogger(__name__)
+
+# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
+logger.warn(
+    "The following error message 'operation scheduled before its operands' can be ignored."
+)
+
 
 @triton.jit
 def tanh(x):
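The TODO above says the notice can be dropped once Triton 3.2.0 is required. A hypothetical follow-up, not part of this commit, could gate the message on the installed Triton version (the `packaging` dependency and the gating itself are assumptions; only the version bound comes from the TODO):

```python
# Hypothetical sketch: emit the notice only on Triton releases older than 3.2.0.
import logging

from packaging.version import Version

import triton

logger = logging.getLogger(__name__)

if Version(triton.__version__) < Version("3.2.0"):
    logger.warning(
        "The following error message 'operation scheduled before its operands' can be ignored."
    )
```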
@@ -166,7 +176,6 @@ def _decode_att_m_fwd(
     Req_to_tokens,
     B_req_idx,
     B_Seqlen,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap,
@@ -389,7 +398,6 @@ def _decode_grouped_att_m_fwd(
     Req_to_tokens,
     B_req_idx,
     B_Seqlen,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap,
@@ -556,7 +564,6 @@ def decode_attention_fwd_normal(
     b_req_idx,
     b_seq_len,
     attn_logits,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap=0.0,
@@ -569,7 +576,6 @@ def decode_attention_fwd_normal(
         req_to_token,
         b_req_idx,
         b_seq_len,
-        max_len_in_batch,
         num_kv_splits,
         sm_scale,
         logit_cap,
@@ -586,7 +592,6 @@ def decode_attention_fwd_grouped(
     b_req_idx,
     b_seq_len,
     attn_logits,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap=0.0,
@@ -599,7 +604,6 @@ def decode_attention_fwd_grouped(
         req_to_token,
         b_req_idx,
         b_seq_len,
-        max_len_in_batch,
         num_kv_splits,
         sm_scale,
         logit_cap,
@@ -614,10 +618,8 @@ def decode_attention_fwd(
     o,
     req_to_token,
     b_req_idx,
-    b_start_loc,
     b_seq_len,
     attn_logits,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap=0.0,
@@ -636,7 +638,6 @@ def decode_attention_fwd(
             b_req_idx,
             b_seq_len,
             attn_logits,
-            max_len_in_batch,
             num_kv_splits,
             sm_scale,
             logit_cap,
@@ -652,7 +653,6 @@ def decode_attention_fwd(
             b_req_idx,
             b_seq_len,
             attn_logits,
-            max_len_in_batch,
             num_kv_splits,
             sm_scale,
             logit_cap,
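Taken together, the kernel-side hunks shrink the public entry point by two parameters. A reconstructed stub of the post-change signature (the leading `q`/`k_buffer`/`v_buffer` parameters are assumptions not shown in this diff; everything from `o` onward follows the context lines above):

```python
def decode_attention_fwd(
    q,         # assumption: query tensor (not shown in these hunks)
    k_buffer,  # assumption: key cache buffer
    v_buffer,  # assumption: value cache buffer
    o,
    req_to_token,
    b_req_idx,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    logit_cap=0.0,
):
    # Dispatches to decode_attention_fwd_normal (MHA) or
    # decode_attention_fwd_grouped (GQA/MQA), per the two hunks above;
    # body elided in this sketch.
    ...
```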
|
@@ -196,7 +196,6 @@ class TestTritonAttention(unittest.TestCase):
 
         req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
         b_req_idx = torch.arange(B, device="cuda")
-        b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
         attn_logits = torch.empty(
@@ -212,10 +211,8 @@ class TestTritonAttention(unittest.TestCase):
             o,
             req_to_token,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
             attn_logits,
-            seq_len,
             num_kv_splits,
             sm_scale,
         )
@@ -255,7 +252,6 @@ class TestTritonAttention(unittest.TestCase):
 
         req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
         b_req_idx = torch.arange(B, device="cuda")
-        b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
         attn_logits = torch.empty(
@@ -273,7 +269,6 @@ class TestTritonAttention(unittest.TestCase):
             b_req_idx,
             b_seq_len,
             attn_logits,
-            seq_len,
             num_kv_splits,
             sm_scale,
         )
@@ -293,7 +288,6 @@ class TestTritonAttention(unittest.TestCase):
             b_req_idx,
             b_seq_len,
             attn_logits1,
-            seq_len,
             num_kv_splits,
             sm_scale,
         )
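The tests follow the same pattern: `b_start_loc` disappears from the setup and the max-length argument (`seq_len`) disappears from the calls. A condensed sketch of the updated input setup (requires a CUDA device, mirroring the test; the concrete sizes are illustrative assumptions, while the tensor-building lines follow the diff):

```python
import torch

# Illustrative sizes; the real test derives these from its parameters.
B, seq_len = 4, 16
total_tokens = B * seq_len

req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
b_req_idx = torch.arange(B, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda")
# b_start_loc is gone, and decode_attention_fwd is no longer passed a max length:
# after the query/KV tensors, the call now passes o, req_to_token, b_req_idx,
# b_seq_len, attn_logits, num_kv_splits, sm_scale.
```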
|