Add unit test for page_size > 1 and MLA, plus integration test for Flash Attention 3 (#4760)
This commit is contained in:
@@ -548,8 +548,9 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
# Use Flash Attention for prefill
|
||||
if not self.use_mla:
|
||||
# Do multi-head attention
|
||||
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
|
||||
layer.layer_id
|
||||
)
|
||||
key_cache = key_cache.view(
|
||||
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
|
||||
)
|
||||
@@ -592,7 +593,6 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
c_kv_cache = c_kv.view(
|
||||
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
|
||||
)
|
||||
|
||||
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||
@@ -659,8 +659,10 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
if not self.use_mla:
|
||||
# Do multi-head attention
|
||||
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
|
||||
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
|
||||
layer.layer_id
|
||||
)
|
||||
key_cache = key_cache.view(
|
||||
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
|
||||
)
|
||||
|
||||
@@ -63,10 +63,6 @@ from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
|
||||
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
UnquantizedEmbeddingMethod,
|
||||
)
|
||||
|
||||
# Base quantization methods that don't depend on vllm
|
||||
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
@@ -176,6 +172,13 @@ def get_linear_quant_method(
|
||||
prefix: str,
|
||||
linear_method_cls: type,
|
||||
):
|
||||
# Move import here to avoid circular import. This is only used in monkey patching
|
||||
# of vllm's QuantizationConfig.
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
UnquantizedEmbeddingMethod,
|
||||
)
|
||||
|
||||
cloned_config = deepcopy(config)
|
||||
parallel_lm_head_quantized = (
|
||||
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
|
||||
|
||||
Reference in New Issue
Block a user