diff --git a/python/sglang/srt/layers/attention/torch_native_backend.py b/python/sglang/srt/layers/attention/torch_native_backend.py index 95237a595..78ed042de 100644 --- a/python/sglang/srt/layers/attention/torch_native_backend.py +++ b/python/sglang/srt/layers/attention/torch_native_backend.py @@ -6,6 +6,7 @@ import torch from torch.nn.functional import scaled_dot_product_attention from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.model_executor.forward_batch_info import ForwardBatch if TYPE_CHECKING: @@ -202,6 +203,10 @@ class TorchNativeAttnBackend(AttentionBackend): q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim) + causal = True + if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY: + causal = False + self._run_sdpa_forward_extend( q_, o_, @@ -214,7 +219,7 @@ class TorchNativeAttnBackend(AttentionBackend): forward_batch.extend_seq_lens, scaling=layer.scaling, enable_gqa=use_gqa, - causal=not layer.is_cross_attention, + causal=causal, ) return o diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 29547ed43..0aa3a695e 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -10,6 +10,7 @@ import triton.language as tl from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import get_bool_env_var, get_device_core_count @@ -528,6 +529,10 @@ class TritonAttnBackend(AttentionBackend): layer, forward_batch.out_cache_loc, k, v ) + causal = True + if layer.attn_type == AttentionType.ENCODER_ONLY: + causal = False + self.extend_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), k.contiguous(), @@ -539,6 +544,7 @@ class TritonAttnBackend(AttentionBackend): self.forward_metadata.kv_indptr, self.forward_metadata.kv_indices, self.forward_metadata.custom_mask, + causal, self.forward_metadata.mask_indptr, self.forward_metadata.max_extend_len, layer.scaling, diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index 17e45599d..f6c0173da 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -74,6 +74,7 @@ def _fwd_kernel( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, USE_CUSTOM_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, SKIP_PREFIX_CUSTOM_MASK: tl.constexpr, STORE_TRANSPOSE: tl.constexpr, ): @@ -129,6 +130,7 @@ def _fwd_kernel( for start_n in range(0, cur_seq_len_prefix, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) mask_n = (start_n + offs_n) < cur_seq_len_prefix + offs_kv_loc = tl.load( kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0 ) @@ -196,7 +198,11 @@ def _fwd_kernel( # stage 2: compute the triangle part - cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M) + cur_block_m_end = ( + cur_seq_len_extend + if not IS_CAUSAL + else tl.minimum(cur_seq_len_extend, (cur_block_m + 
1) * BLOCK_M) + ) for start_n in range(0, cur_block_m_end, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) mask_n = (start_n + offs_n) < cur_block_m_end @@ -243,12 +249,15 @@ def _fwd_kernel( ) custom_mask &= mask_m[:, None] & mask_n[None, :] qk = tl.where(custom_mask, qk, float("-inf")) - else: + elif IS_CAUSAL: mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= ( start_n + offs_n[None, :] ) mask_causual &= mask_m[:, None] & mask_n[None, :] qk = tl.where(mask_causual, qk, float("-inf")) + else: + mask_non_causal = mask_m[:, None] & mask_n[None, :] + qk = tl.where(mask_non_causal, qk, float("-inf")) n_e_max = tl.maximum(tl.max(qk, 1), e_max) re_scale = tl.exp(e_max - n_e_max) @@ -299,6 +308,7 @@ def extend_attention_fwd( kv_indptr, kv_indices, custom_mask, + is_causal, mask_indptr, max_len_extend, sm_scale=None, @@ -411,6 +421,7 @@ def extend_attention_fwd( Lq=Lq, Lv=Lv, USE_CUSTOM_MASK=USE_CUSTOM_MASK, + IS_CAUSAL=is_causal, SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK, STORE_TRANSPOSE=_is_hip, num_warps=num_warps, diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 3bb30bc15..3c10a3924 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -13,6 +13,7 @@ # ============================================================================== """Radix attention.""" +from enum import Enum from typing import Optional from torch import nn @@ -22,6 +23,18 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch +class AttentionType(Enum): + """ + Attention type. + Use string to be compatible with `torch.compile`. + """ + + # Decoder attention between previous layer Q/K/V + DECODER = "decoder" + # Encoder attention between previous layer Q/K/V + ENCODER_ONLY = "encoder_only" + + class RadixAttention(nn.Module): """ The attention layer implementation. 
@@ -39,6 +52,7 @@ class RadixAttention(nn.Module): sliding_window_size: int = -1, is_cross_attention: bool = False, quant_config: Optional[QuantizationConfig] = None, + attn_type=AttentionType.DECODER, prefix: str = "", use_irope: bool = False, ): @@ -64,6 +78,7 @@ class RadixAttention(nn.Module): self.quant_method = quant_config.get_quant_method(self, prefix=prefix) if self.quant_method is not None: self.quant_method.create_weights(self) + self.attn_type = attn_type def forward( self, diff --git a/python/sglang/srt/models/bert.py b/python/sglang/srt/models/bert.py new file mode 100644 index 000000000..46d2e7265 --- /dev/null +++ b/python/sglang/srt/models/bert.py @@ -0,0 +1,398 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, Iterable, Optional, Set, Tuple + +import torch +from torch import nn + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import get_act_fn +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import AttentionType, RadixAttention +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader + +BertConfig = None + + +class BertEmbedding(nn.Module): + + def __init__(self, config: BertConfig): + + super().__init__() + self.size = config.hidden_size + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) + self.position_embeddings = VocabParallelEmbedding( + config.max_position_embeddings, config.hidden_size + ) + self.token_type_embeddings = VocabParallelEmbedding( + config.type_vocab_size, config.hidden_size + ) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.position_ids = nn.Parameter( + torch.empty((1, config.max_position_embeddings)), + ) + + self.position_embedding_type = config.position_embedding_type + if self.position_embedding_type != "absolute": + raise ValueError( + "Only 'absolute' position_embedding_type" + " is supported" + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + input_shape = input_ids.size() + + # Input embeddings. + inputs_embeds = self.word_embeddings(input_ids) + + # Position embeddings. 
+ position_embeddings = self.position_embeddings(position_ids) + + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=inputs_embeds.device + ) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class BertEncoder(nn.Module): + + def __init__( + self, + config: BertConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.quant_config = quant_config + self.layer = nn.ModuleList( + [ + BertLayer( + config=config, + layer_id=layer_idx, + quant_config=quant_config, + prefix=f"{prefix}.layer.{layer_idx}", + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + def forward( + self, hidden_states: torch.Tensor, forward_batch: ForwardBatch + ) -> torch.Tensor: + for layer in self.layer: + hidden_states = layer(hidden_states, forward_batch) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__( + self, + config: BertConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.attention = BertAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + layer_id=layer_id, + layer_norm_eps=config.layer_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.attention", + ) + + self.intermediate = BertIntermediate( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.intermediate", + ) + + self.output = BertOutput( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + layer_norm_eps=config.layer_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.output", + ) + + def forward(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch): + attn_output = self.attention(hidden_states, forward_batch) + intermediate_output = self.intermediate(attn_output) + output = self.output(intermediate_output, attn_output) + return output + + +class BertAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + layer_norm_eps: float, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.self_attn = BertSelfAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + layer_id=layer_id, + quant_config=quant_config, + prefix=f"{prefix}.output", + ) + + self.output = BertSelfOutput( + hidden_size=hidden_size, + layer_norm_eps=layer_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.output", + ) + + def forward( + self, hidden_states: torch.Tensor, forward_batch: ForwardBatch + ) -> torch.Tensor: + self_output = self.self_attn(hidden_states, forward_batch) + return self.output(self_output, hidden_states) + + +class BertSelfAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = num_attention_heads + assert self.total_num_heads % tp_size == 0 + + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = self.total_num_heads + self.head_dim = self.hidden_size 
// self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.attn = RadixAttention( + num_heads=self.num_heads, + head_dim=self.head_dim, + scaling=self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_ONLY, + ) + + def forward( + self, hidden_states: torch.Tensor, forward_batch: ForwardBatch + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + output = self.attn(q, k, v, forward_batch) + return output + + +class BertSelfOutput(nn.Module): + + def __init__( + self, + hidden_size: int, + layer_norm_eps: float, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.dense = RowParallelLinear( + input_size=hidden_size, + output_size=hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertIntermediate(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.dense = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + self.intermediate_act_fn = get_act_fn(hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + layer_norm_eps: float, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.dense = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertModel(nn.Module): + + def __init__( + self, + *, + config: BertConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.embeddings = BertEmbedding(config) + self.encoder = BertEncoder( + config=config, quant_config=quant_config, prefix=f"encoder" + ) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + # self.pooler = BertPooler(config) + + 
@torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = False, + ) -> torch.Tensor: + assert get_embedding == True + # Your tokenized IDs + + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=positions, + ) + + hidden_states = self.encoder(hidden_states, forward_batch=forward_batch) + return self.pooler(hidden_states, forward_batch) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "query", "q"), + ("qkv_proj", "key", "k"), + ("qkv_proj", "value", "v"), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + name = name.replace("self", "self_attn") + if "pooler" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +class Contriever(BertModel): + pass + + +EntryClass = [BertModel, Contriever] diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index a7eaebdb2..dc70e8a16 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -51,6 +51,8 @@ NUM_TOP_LOGPROBS = 5 def get_dtype_str(torch_dtype): if torch_dtype is torch.float16: return "float16" + if torch_dtype is torch.float32: + return "float32" else: raise NotImplementedError() @@ -447,6 +449,7 @@ class SRTRunner: port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, lora_paths: List[str] = None, max_loras_per_batch: int = 4, + attention_backend: Optional[str] = None, lora_backend: str = "triton", disable_cuda_graph: bool = False, disable_radix_cache: bool = False, @@ -487,6 +490,7 @@ class SRTRunner: lora_paths=lora_paths, max_loras_per_batch=max_loras_per_batch, lora_backend=lora_backend, + attention_backend=attention_backend, disable_cuda_graph=disable_cuda_graph, disable_radix_cache=disable_radix_cache, chunked_prefill_size=chunked_prefill_size, diff --git a/test/srt/models/test_encoder_embedding_models.py b/test/srt/models/test_encoder_embedding_models.py new file mode 100644 index 000000000..4dad0be15 --- /dev/null +++ b/test/srt/models/test_encoder_embedding_models.py @@ -0,0 +1,149 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +# python -m unittest test_encoder_embedding_models.TestEncoderEmbeddingModels.test_prefill_logits + +import multiprocessing as mp +import random +import time +import unittest + +import torch +from transformers import AutoConfig, AutoTokenizer + +from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner +from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci + +MODELS = [("BAAI/bge-small-en", 1, 1e-5)] + +ATTENTION_BACKEND = ["torch_native", "triton"] +BATCH_SIZE = [30] +TORCH_DTYPES = [torch.float32] +sgl_to_st_ratio = [] + + +class TestEncoderEmbeddingModels(CustomTestCase): + + @classmethod + def setUpClass(cls): + mp.set_start_method("spawn", force=True) + + def _truncate_prompts(self, prompts, model_path): + config = AutoConfig.from_pretrained(model_path) + max_length = getattr(config, "max_position_embeddings", 512) - 20 + + tokenizer = AutoTokenizer.from_pretrained(model_path) + + truncated_prompts = [] + for prompt in prompts: + tokens = tokenizer(prompt, return_tensors="pt", truncation=False) + if len(tokens.input_ids[0]) > max_length: + truncated_text = tokenizer.decode( + tokens.input_ids[0][: max_length - 1], skip_special_tokens=True + ) + truncated_prompts.append(truncated_text) + else: + truncated_prompts.append(prompt) + + return truncated_prompts + + def assert_close_prefill_logits( + self, + prompts, + model_path, + tp_size, + torch_dtype, + prefill_tolerance, + attention_backend, + batch_size, + ) -> None: + truncated_prompts = self._truncate_prompts(prompts, model_path) + truncated_prompts = truncated_prompts * batch_size + + with HFRunner( + model_path, + torch_dtype=torch_dtype, + model_type="embedding", + ) as hf_runner: + # warm up + hf_outputs = hf_runner.forward(truncated_prompts) + + st_start_time = time.time() + hf_outputs = hf_runner.forward(truncated_prompts) + st_end_time = time.time() + + with SRTRunner( + model_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + model_type="embedding", + attention_backend=attention_backend, + chunked_prefill_size=-1, + disable_radix_cache=True, + ) as srt_runner: + # warm up + srt_outputs = srt_runner.forward(truncated_prompts) + + sgl_start_time = time.time() + srt_outputs = srt_runner.forward(truncated_prompts) + sgl_end_time = time.time() + + transformer_time = st_end_time - st_start_time + sgl_time = sgl_end_time - sgl_start_time + sgl_to_st_ratio.append(sgl_time / transformer_time) + + for i in range(len(truncated_prompts)): + hf_logits = torch.Tensor(hf_outputs.embed_logits[i]) + srt_logits = torch.Tensor(srt_outputs.embed_logits[i]) + + similarity = torch.tensor(get_similarities(hf_logits, srt_logits)) + # If something is wrong, uncomment this to observe similarity. 
+ # print("similarity diff", abs(similarity - 1)) + + if len(truncated_prompts[i]) <= 1000: + assert torch.all( + abs(similarity - 1) < prefill_tolerance + ), "embeddings are not all close" + + def test_prefill_logits(self): + models_to_test = MODELS + + if is_in_ci(): + models_to_test = [random.choice(MODELS)] + + for model, tp_size, prefill_tolerance in models_to_test: + for attention_backend in ATTENTION_BACKEND: + for batch_size in BATCH_SIZE: + for torch_dtype in TORCH_DTYPES: + self.assert_close_prefill_logits( + DEFAULT_PROMPTS, + model, + tp_size, + torch_dtype, + prefill_tolerance, + attention_backend, + batch_size, + ) + + for i in range(len(BATCH_SIZE)): + print( + "bacth size: ", + BATCH_SIZE[i] * 5, + "sgl_time/st_time", + round(sgl_to_st_ratio[i], 3), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index 184733e7f..47eb16a9b 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -116,6 +116,7 @@ class TestTritonAttention(CustomTestCase): kv_indptr, kv_indices, custom_mask, + True, mask_indptr, max_len_extend, ) @@ -150,6 +151,7 @@ class TestTritonAttention(CustomTestCase): kv_indptr, kv_indices, custom_mask, + True, mask_indptr, max_len_extend, )
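
Usage note (not part of the patch): a minimal sketch of how a model layer opts into the encoder-only, non-causal path that this diff adds. Only `AttentionType.ENCODER_ONLY` and the `attn_type` keyword of `RadixAttention` come from the patch itself; the class name, `hidden_size`, `num_heads`, and the plain `nn.Linear` QKV projection below are illustrative assumptions, not code from the repository. When `attn_type=AttentionType.ENCODER_ONLY`, the torch-native and Triton extend backends set `causal=False` before calling into the kernels.

# Hypothetical encoder-only self-attention layer, assuming the RadixAttention
# constructor keywords used in the new bert.py (num_heads, head_dim, scaling,
# num_kv_heads, layer_id, attn_type).
import torch
from torch import nn

from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class EncoderOnlySelfAttention(nn.Module):
    """Bidirectional self-attention: every token attends to the full sequence."""

    def __init__(self, hidden_size: int, num_heads: int, layer_id: int = 0):
        super().__init__()
        self.head_dim = hidden_size // num_heads
        # Simple fused QKV projection for illustration; the real model uses
        # QKVParallelLinear for tensor parallelism.
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.attn = RadixAttention(
            num_heads=num_heads,
            head_dim=self.head_dim,
            scaling=self.head_dim**-0.5,
            num_kv_heads=num_heads,
            layer_id=layer_id,
            # The new flag: backends translate this into a non-causal mask.
            attn_type=AttentionType.ENCODER_ONLY,
        )

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        q, k, v = self.qkv_proj(hidden_states).chunk(3, dim=-1)
        return self.attn(q, k, v, forward_batch)

Callers of the Triton kernel wrapper `extend_attention_fwd` should note that the new `is_causal` argument is positional and sits between `custom_mask` and `mask_indptr`, which is why the existing kernel tests above pass `True` in that slot.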