Backport of #7474.
This PR adds C8 (INT8) KV cache quantization support for standard GQA
attention models (e.g., Qwen3-32B W8A8C8). C8 uses static per-channel
quantization scales to store KV cache in INT8, reducing KV cache memory
by ~50% compared to BF16, enabling higher batch concurrency and longer
context lengths on the same hardware.
**Key changes:**
1. **`attention_v1.py`** — New `AscendC8AttentionBackendImpl` subclass
of `AscendAttentionBackendImpl`:
- `_prepare_c8_scales`: Shards per-channel scales/offsets to the current
TP rank and pre-computes BF16 BNSD-shaped antiquant tensors (one-time
per layer).
- `_quantize_kv_to_int8`: Quantizes BF16 K/V to INT8 before
`reshape_and_cache`, using pre-cached inverse scales.
- `_forward_c8_decode`: FIA V1 BNSD paged attention with native INT8 KV
and `perchannel` antiquant mode.
- `_forward_c8_chunked_prefill`: Splits decode (FIA V1 BNSD paged INT8)
and prefill (FIA V1 TND float) into two kernel calls.
- `_forward_c8_fused_infer_attention`: Handles `PrefillNoCache` and
`PrefillCacheHit` states.
2. **`quantization/methods/kv_c8.py`** — New
`AscendC8KVCacheAttentionMethod` scheme:
- Creates `k/v_cache_scale/offset` parameters via
`_c8_kv_scale_weight_loader`, which handles per-channel scale shapes and
lazy resizing.
- Sets `layer.kv_cache_torch_dtype = torch.int8` so
`get_kv_cache_spec()` returns INT8 dtype automatically.
- Upgrades `layer.impl` to `AscendC8AttentionBackendImpl` via class
surgery.
3. **`quantization/modelslim_config.py`** — C8 branch in
`get_quant_method()` activates when `kv_cache_type == "C8"` in
`quant_model_description.json`.
4. **`patch/worker/patch_qwen3_c8.py`** — Intercepts per-channel C8
scale/offset weights before `AutoWeightsLoader` discards them, routing
them to the parameters created by `AscendC8KVCacheAttentionMethod`.
5. **`tests/ut/quantization/test_kv_c8.py`** — Unit tests covering
`_c8_kv_scale_weight_loader`, `AscendC8KVCacheAttentionMethod`, and
`AscendC8AttentionBackendImpl` scale helpers.
User-facing change: users can now serve Qwen3-32B W8A8C8 quantized models with INT8 KV
cache on Ascend NPU. The model checkpoint must contain a
`quant_model_description.json` with `"kv_cache_type": "C8"` and
per-channel scale/offset tensors in safetensors.
No changes to the serving CLI — the feature activates automatically when
the quantization config is detected.
Benchmarked with `vllm serve` (TP=8, `max_num_seqs=256`,
`max_model_len=131072`, `enable_chunked_prefill=true`) + `random_bench`
(input_len=10240, output_len=2048, 960 prompts, max_concurrency=192):
```
============ Serving Benchmark Result ============
Successful requests: 960
Failed requests: 0
Maximum request concurrency: 192
Benchmark duration (s): 1359.81
Total input tokens: 9830400
Total generated tokens: 1966080
Request throughput (req/s): 0.71
Output token throughput (tok/s): 1445.85
Peak output token throughput (tok/s): 2304.00
Total token throughput (tok/s): 8675.12
---------------Time to First Token----------------
Mean TTFT (ms): 24598.51
Median TTFT (ms): 23167.02
P50 TTFT (ms): 23167.02
P90 TTFT (ms): 47717.08
P99 TTFT (ms): 84402.61
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 120.76
Median TPOT (ms): 121.50
P50 TPOT (ms): 121.50
P90 TPOT (ms): 127.05
P99 TPOT (ms): 130.13
---------------Inter-token Latency----------------
Mean ITL (ms): 120.70
Median ITL (ms): 90.34
P50 ITL (ms): 90.34
P90 ITL (ms): 93.79
P99 ITL (ms): 101.80
==================================================
```
All attention states verified: `PrefillNoCache`, `PrefillCacheHit`,
`ChunkedPrefill`, `DecodeOnly`.
- vLLM version: v0.17.0
- vLLM main:
8b6325758c
Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: LICO67373 <110013619+LICO1314@users.noreply.github.com>
This commit is contained in:
@@ -22,6 +22,7 @@ import torch
|
||||
import torch_npu
|
||||
import vllm.envs as envs_vllm
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import ( # type: ignore
|
||||
AttentionBackend,
|
||||
@@ -978,3 +979,364 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output)
|
||||
output[:num_tokens] = attn_output[:num_tokens]
|
||||
return output
|
||||
|
||||
|
||||
class AscendC8AttentionBackendImpl(AscendAttentionBackendImpl):
    """Attention backend implementation for INT8 KV cache (C8/QuaRot) models.

    This subclass handles static per-channel INT8 KV cache quantization.
    It is activated via class surgery in AscendC8KVCacheAttentionMethod.create_weights
    (vllm_ascend/quantization/methods/kv_c8.py)
    so that C8 attention layers automatically use this forward path.

    Per-layer quantization constants are read from ``layer.k_cache_scale``,
    ``layer.k_cache_offset``, ``layer.v_cache_scale`` and
    ``layer.v_cache_offset``; derived tensors (TP-sharded, BNSD-shaped,
    inverse scales) are cached on ``layer`` by :meth:`_prepare_c8_scales`.
    """

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize new K/V to INT8, write them into the paged cache, then
        dispatch to the decode / chunked-prefill / prefill kernel path
        according to ``attn_metadata.attn_state``.

        Returns ``output`` with the first ``num_tokens`` rows filled.
        """
        assert output is not None, "Output tensor must be provided."

        # Fused output quantization is not implemented for the C8 backend.
        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError("fused output quantization is not yet supported for AscendC8AttentionBackendImpl")

        num_tokens = query.shape[0]
        if attn_metadata is None:
            # No metadata (e.g. profiling/dummy run): return zeroed output.
            return output.fill_(0)

        # Keep a float copy of the current chunk's K/V for prefill paths,
        # which attend over it directly; decode-only steps never need it.
        float_key, float_value = None, None
        if key is not None and value is not None:
            if attn_metadata.attn_state != AscendAttentionState.DecodeOnly:
                float_key, float_value = key, value
            # Quantize with the static per-channel scales BEFORE the tokens
            # are written into the INT8 paged cache below.
            key, value = self._quantize_kv_to_int8(key, value, layer, attn_metadata.num_actual_tokens)
        query, key, value, _ = self.reshape_and_cache(query, key, value, kv_cache, attn_metadata, output)

        if attn_metadata.model_runner_type == "pooling":
            # NOTE(review): the encoder path receives the (possibly INT8) K/V
            # returned by reshape_and_cache — confirm _forward_encoder_attention
            # handles C8 layers correctly.
            attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output

        self._prepare_c8_scales(layer, query.device)
        if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            return self._forward_c8_decode(query, attn_metadata, output, layer)
        elif attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
            return self._forward_c8_chunked_prefill(query, float_key, float_value, attn_metadata, output, layer)
        else:
            # PrefillNoCache / PrefillCacheHit: prefer the float K/V of the
            # current chunk when available, else fall back to the cached KV.
            return self._forward_c8_fused_infer_attention(
                query,
                float_key if float_key is not None else key,
                float_value if float_value is not None else value,
                attn_metadata,
                output,
                layer,
            )

    def _prepare_c8_scales(self, layer: AttentionLayer, device: torch.device) -> None:
        """Shard per-channel C8 scales/offsets to this TP rank and pre-compute
        BF16 BNSD antiquant tensors for FIA V1 decode fast path.

        Idempotent: the work runs once per layer, guarded by the
        ``layer._c8_scales_prepared`` sentinel attribute.
        """
        if hasattr(layer, "_c8_scales_prepared"):
            return

        def _shard_and_reshape(raw: torch.Tensor) -> torch.Tensor:
            # Scalar (per-tensor) constant: nothing to shard or reshape.
            if raw.numel() == 1:
                return raw.to(device=device)
            expected = self.num_kv_heads * self.head_size
            if raw.numel() != expected:
                # Checkpoint holds scales for ALL KV heads; slice out this TP
                # rank's contiguous range of kv heads.
                total_kv_heads = raw.numel() // self.head_size
                tp_rank = get_tensor_model_parallel_rank()
                tp_size = get_tensor_model_parallel_world_size()
                kv_head_start = tp_rank * total_kv_heads // tp_size
                raw = raw.view(total_kv_heads, self.head_size)[
                    kv_head_start : kv_head_start + self.num_kv_heads
                ].contiguous()
            return raw.view(1, self.num_kv_heads, self.head_size).to(device=device)

        layer._c8_k_scale = _shard_and_reshape(layer.k_cache_scale.data)
        layer._c8_k_offset = _shard_and_reshape(layer.k_cache_offset.data)
        layer._c8_v_scale = _shard_and_reshape(layer.v_cache_scale.data)
        layer._c8_v_offset = _shard_and_reshape(layer.v_cache_offset.data)

        # (B=1, N=num_kv_heads, S=1, D=head_size) — the shape the FIA
        # key/value antiquant_scale/offset arguments take in BNSD layout.
        bnsd = (1, self.num_kv_heads, 1, self.head_size)
        layer._c8_k_aq_scale = layer._c8_k_scale.to(torch.bfloat16).view(bnsd).contiguous()
        layer._c8_k_aq_offset = layer._c8_k_offset.to(torch.bfloat16).view(bnsd).contiguous()
        layer._c8_v_aq_scale = layer._c8_v_scale.to(torch.bfloat16).view(bnsd).contiguous()
        layer._c8_v_aq_offset = layer._c8_v_offset.to(torch.bfloat16).view(bnsd).contiguous()

        # Pre-compute inverse scales so quantization is a multiply, not a divide.
        layer._c8_k_inv_scale_bf16 = (1.0 / layer._c8_k_scale).to(torch.bfloat16)
        layer._c8_k_offset_bf16 = layer._c8_k_offset.to(torch.bfloat16)
        layer._c8_v_inv_scale_bf16 = (1.0 / layer._c8_v_scale).to(torch.bfloat16)
        layer._c8_v_offset_bf16 = layer._c8_v_offset.to(torch.bfloat16)

        layer._c8_scales_prepared = True

    def _dequant_paged_kv_to_dense(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        block_table: torch.Tensor,
        seq_lens: list,
        target_dtype: torch.dtype,
        layer,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Gather paged INT8 KV blocks and dequantize to target_dtype.

        ``key``/``value`` are paged caches viewed as
        (num_blocks, block_size, H) with H = num_kv_heads * head_size.
        Returns dense (total_valid_tokens, num_kv_heads, head_size) tensors
        containing only the valid (non-padding) tokens.
        """
        batch_size = block_table.shape[0]
        block_size = key.shape[1]
        H = key.shape[2]
        max_blocks_per_seq = block_table.shape[1]
        max_tokens_padded = max_blocks_per_seq * block_size

        # Gather each sequence's blocks into a padded (B, T_max, H) layout.
        flat_ids = block_table.reshape(-1)
        gathered_k = key[flat_ids].view(batch_size, max_tokens_padded, H)
        gathered_v = value[flat_ids].view(batch_size, max_tokens_padded, H)

        # Mask off padding tokens past each sequence's true KV length.
        seq_lens_t = torch.tensor(seq_lens, dtype=torch.long, device=key.device)
        positions = torch.arange(max_tokens_padded, dtype=torch.long, device=key.device)
        valid_mask = (positions.unsqueeze(0) < seq_lens_t.unsqueeze(1)).view(-1)

        dense_k = gathered_k.view(-1, H)[valid_mask]
        dense_v = gathered_v.view(-1, H)[valid_mask]

        dense_k = dense_k.view(-1, self.num_kv_heads, self.head_size)
        dense_v = dense_v.view(-1, self.num_kv_heads, self.head_size)
        k_scale = layer._c8_k_scale.to(target_dtype)
        k_offset = layer._c8_k_offset.to(target_dtype)
        v_scale = layer._c8_v_scale.to(target_dtype)
        v_offset = layer._c8_v_offset.to(target_dtype)
        # Asymmetric dequant: x_float = (x_int8 - offset) * scale.
        dense_k = (dense_k.to(target_dtype) - k_offset) * k_scale
        dense_v = (dense_v.to(target_dtype) - v_offset) * v_scale
        return dense_k, dense_v

    def _quantize_kv_to_int8(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        layer: AttentionLayer,
        num_actual_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize K/V from float to INT8 using static per-channel C8 scales."""
        # Scales may not be materialized yet on the very first forward call.
        self._prepare_c8_scales(layer, key.device)

        # Only quantize real tokens; rows past num_actual_tokens are padding.
        actual_key = key[:num_actual_tokens]
        actual_value = value[:num_actual_tokens]

        # Asymmetric quant: round(x / scale + offset), clamped to INT8 range.
        # (Multiplies by the pre-computed inverse scale instead of dividing.)
        k_int8 = torch.clamp(
            torch.round(actual_key * layer._c8_k_inv_scale_bf16 + layer._c8_k_offset_bf16),
            -128,
            127,
        ).to(torch.int8)
        v_int8 = torch.clamp(
            torch.round(actual_value * layer._c8_v_inv_scale_bf16 + layer._c8_v_offset_bf16),
            -128,
            127,
        ).to(torch.int8)
        return k_int8, v_int8

    def _forward_c8_decode(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        layer: AttentionLayer,
    ) -> torch.Tensor:
        """C8 decode via FIA V1 BNSD with native paged INT8 KV + perchannel antiquant."""
        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
        assert block_size % 32 == 0, f"C8 INT8 KV cache requires block_size to be a multiple of 32, got {block_size}"
        # Flatten (num_blocks, block_size, N, D) -> (num_blocks, block_size, N*D).
        key = self.key_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
        value = self.value_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
        batch_size = len(attn_metadata.seq_lens_list)

        # One query token per sequence; unsqueeze(2) -> (B, N, S=1, D) for BNSD.
        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query[:batch_size].unsqueeze(2),
            key,
            value,
            key_antiquant_scale=layer._c8_k_aq_scale,
            key_antiquant_offset=layer._c8_k_aq_offset,
            value_antiquant_scale=layer._c8_v_aq_scale,
            value_antiquant_offset=layer._c8_v_aq_offset,
            block_table=attn_metadata.block_tables,
            actual_seq_lengths_kv=attn_metadata.seq_lens_list,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
            input_layout="BNSD",
            scale=self.scale,
            block_size=block_size,
            # NOTE(review): docstring says "perchannel" antiquant — confirm
            # mode 0 maps to per-channel in the target FIA operator version.
            key_antiquant_mode=0,
            value_antiquant_mode=0,
            sparse_mode=0,
        )
        attn_output = attn_output.squeeze(2)
        output[:batch_size] = attn_output
        return output

    def _forward_c8_chunked_prefill(
        self,
        query: torch.Tensor,
        float_key: torch.Tensor | None,
        float_value: torch.Tensor | None,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        layer: AttentionLayer,
    ) -> torch.Tensor:
        """C8 ChunkedPrefill: decode via FIA V1 BNSD paged INT8 (zero gather),
        prefill via FIA V1 TND with float KV (new) or gather+dequant (continuing).

        Requests are ordered decodes-first: tokens [0, num_decode_tokens) are
        decodes, tokens [num_decode_tokens, num_tokens) belong to prefills.
        """
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_decodes = attn_metadata.num_decodes
        # Cumulative query lengths across the whole batch.
        actual_seq_qlen = attn_metadata.actual_seq_lengths_q
        num_tokens = int(actual_seq_qlen[-1])  # type: ignore[index]

        if num_decode_tokens > 0:
            # --- Decode half: paged INT8 KV, BNSD single-token attention ---
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
            assert block_size % 32 == 0, (
                f"C8 INT8 KV cache requires block_size to be a multiple of 32, got {block_size}"
            )
            kv_k = self.key_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
            kv_v = self.value_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]

            attn_out, _ = torch_npu.npu_fused_infer_attention_score(
                query[:num_decode_tokens].unsqueeze(2),
                kv_k,
                kv_v,
                key_antiquant_scale=layer._c8_k_aq_scale,
                key_antiquant_offset=layer._c8_k_aq_offset,
                value_antiquant_scale=layer._c8_v_aq_scale,
                value_antiquant_offset=layer._c8_v_aq_offset,
                block_table=attn_metadata.block_tables[:num_decodes],
                actual_seq_lengths_kv=attn_metadata.seq_lens_list[:num_decodes],
                num_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                input_layout="BNSD",
                scale=self.scale,
                block_size=block_size,
                key_antiquant_mode=0,
                value_antiquant_mode=0,
                sparse_mode=0,
            )
            output[:num_decode_tokens] = attn_out.squeeze(2)

        if attn_metadata.num_prefills > 0:
            # --- Prefill half: TND varlen attention over float K/V ---
            prefill_q = query[num_decode_tokens:num_tokens]

            # Re-base cumulative query lengths to start at the first prefill token.
            prefill_seq_qlen = [
                actual_seq_qlen[i] - num_decode_tokens for i in range(num_decodes, len(actual_seq_qlen))
            ]

            # A prefill request is "new" iff its cached KV length does not
            # exceed this chunk's query length (no earlier chunk in cache).
            all_new_prefill = True
            for i in range(num_decodes, len(attn_metadata.seq_lens_list)):
                q_start = actual_seq_qlen[i - 1] if i > 0 else 0
                qlen_i = actual_seq_qlen[i] - q_start
                if attn_metadata.seq_lens_list[i] > qlen_i:
                    all_new_prefill = False
                    break

            if all_new_prefill and float_key is not None and float_value is not None:
                # Fast path: attend directly over the float K/V of this chunk.
                prefill_k = float_key[num_decode_tokens:num_tokens]
                prefill_v = float_value[num_decode_tokens:num_tokens]
                prefill_seq_kvlen = prefill_seq_qlen
            else:
                # Continuing chunk: gather + dequantize the paged INT8 history.
                num_block, blk_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
                paged_k = self.key_cache.view(num_block, blk_size, -1)  # type: ignore[attr-defined]
                paged_v = self.value_cache.view(num_block, blk_size, -1)  # type: ignore[attr-defined]
                prefill_bt = attn_metadata.block_tables[num_decodes:]
                prefill_sl = attn_metadata.seq_lens_list[num_decodes:]
                prefill_k, prefill_v = self._dequant_paged_kv_to_dense(
                    paged_k, paged_v, prefill_bt, prefill_sl, query.dtype, layer
                )
                prefill_seq_kvlen = torch.tensor(prefill_sl, dtype=torch.int32).cumsum(dim=0)

            # block_table is None for prefill; FIA ignores block_size in this case.
            # Use cache block_size for consistency rather than a magic number.
            cache_block_size = self.key_cache.shape[1]  # type: ignore[attr-defined]
            attn_out, _ = torch_npu.npu_fused_infer_attention_score(
                query=prefill_q,
                key=prefill_k,
                value=prefill_v,
                atten_mask=attn_metadata.attn_mask,
                block_table=None,
                input_layout="TND",
                block_size=cache_block_size,
                actual_seq_lengths=prefill_seq_qlen,
                actual_seq_lengths_kv=prefill_seq_kvlen,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale=self.scale,
                sparse_mode=3,
            )
            n_prefill = num_tokens - num_decode_tokens
            attn_out = attn_out.view(n_prefill, self.num_heads, self.head_size)
            output[num_decode_tokens:num_tokens] = attn_out[:n_prefill]

        return output

    def _forward_c8_fused_infer_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        layer: AttentionLayer,
    ):
        """C8 FIA V1 TND for prefill states (PrefillNoCache uses float KV directly,
        PrefillCacheHit gathers + dequants paged INT8 KV).
        """
        self._prepare_c8_scales(layer, query.device)
        key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)

        actual_seq_qlen = attn_metadata.actual_seq_lengths_q
        num_tokens = int(actual_seq_qlen[-1])  # type: ignore[index]
        query = query[:num_tokens]

        if (
            attn_metadata.attn_state == AscendAttentionState.PrefillNoCache
            and self.attn_type != AttentionType.ENCODER_DECODER
        ):
            # Self-attention prefill with no cache hit: K/V are the current
            # tokens themselves, so trim them to the real token count.
            key = key[:num_tokens]
            value = value[:num_tokens]

        if key.dtype == torch.int8:
            # K/V came from the INT8 paged cache and must be dequantized
            # before the float TND attention kernel.
            if block_table is not None:
                seq_lens = (
                    actual_seq_lengths_kv if isinstance(actual_seq_lengths_kv, list) else actual_seq_lengths_kv.tolist()
                )
                key, value = self._dequant_paged_kv_to_dense(key, value, block_table, seq_lens, query.dtype, layer)
                block_table = None
                # block_table is None after dequant; FIA ignores block_size.
                # Use cache block_size for consistency rather than a magic number.
                block_size = self.key_cache.shape[1]  # type: ignore[attr-defined]
                actual_seq_lengths_kv = torch.tensor(seq_lens, dtype=torch.int32).cumsum(dim=0)
            else:
                # Dense (non-paged) INT8 K/V: dequantize in place with the
                # per-channel scales in the query's dtype.
                qdt = query.dtype
                k_scale = layer._c8_k_scale.to(qdt)
                k_offset = layer._c8_k_offset.to(qdt)
                v_scale = layer._c8_v_scale.to(qdt)
                v_offset = layer._c8_v_offset.to(qdt)
                key = (key.to(qdt) - k_offset) * k_scale
                value = (value.to(qdt) - v_offset) * v_scale

        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=actual_seq_qlen,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
        )
        attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
        output[:num_tokens] = attn_output
        return output
|
||||
|
||||
Reference in New Issue
Block a user