# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
    get_paged_mqa_logits_metadata,
    is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    MultipleOf,
)
from vllm.v1.attention.backends.utils import (
    split_decodes_and_prefills,
    split_prefill_chunks,
)
from vllm.v1.worker.cp_utils import get_total_cp_world_size

logger = init_logger(__name__)


class DeepseekV32IndexerBackend(AttentionBackend):
    @staticmethod
    def get_name() -> str:
        return "DEEPSEEK_V32_INDEXER"

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [1 if current_platform.is_rocm() else 64]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 128]

    @staticmethod
    def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]:
        return DeepseekV32IndexerMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        assert num_kv_heads == 1
        return (num_blocks, block_size, head_size)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        if include_num_layers_dimension:
            return (0, 1, 2, 3)
        return (0, 1, 2)


@dataclass
class DeepseekV32IndexerPrefillChunkMetadata:
    block_table: torch.Tensor
    cu_seqlen_ks: torch.Tensor
    cu_seqlen_ke: torch.Tensor
    cu_seq_lens: torch.Tensor
    cu_seqlens_q: torch.Tensor
    token_to_seq: torch.Tensor
    total_seq_lens: int
    token_start: int
    token_end: int
    num_reqs: int
    max_context_len: int
    max_q_len: int  # Maximum query length for dsa_indexer_mqa_logits_with_blocks
    max_kv_len: int  # Maximum key-value length for dsa_indexer_mqa_logits_with_blocks


@dataclass
class DeepseekV32IndexerPrefillMetadata:
    chunks: list[DeepseekV32IndexerPrefillChunkMetadata]


@dataclass
class DeepSeekV32IndexerDecodeMetadata:
    block_table: torch.Tensor
    seq_lens: torch.Tensor
    decode_lens: torch.Tensor
    requires_padding: bool
    # schedule_metadata: torch.Tensor
    use_large_context_topk: bool
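    # In build(), offsets is set to torch.arange(1 + num_speculative_tokens)
    # (e.g. [0, 1, 2] with two speculative tokens) when decodes are not
    # flattened, and left as None otherwise.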
    offsets: torch.Tensor | None  # Precomputed offsets for speculative decoding
    cu_seqlen_ks: torch.Tensor
    cu_seqlen_ke: torch.Tensor
    cu_seqlens_kv: torch.Tensor
    cu_seqlens_q: torch.Tensor
    max_context_len: int
    max_q_len: int
    max_kv_len: int


@dataclass
class DeepseekV32IndexerMetadata:
    # FIXME (zyongye)
    # hacky way to access the data now, need to be in chunked meta
    seq_lens: torch.Tensor

    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor
    # The dimension of the attention heads
    head_dim: int

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    decode: DeepSeekV32IndexerDecodeMetadata | None = None
    prefill: DeepseekV32IndexerPrefillMetadata | None = None


# TODO (zyongye) optimize this, this is now vibe coded
def kv_spans_from_batches(
    start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        start_seq_loc: 1D long tensor [B+1], cumulative counts of
            selected tokens per batch.
            Example: [0, 2, 4, 7] ->
            batch sizes (selected) [2, 2, 3], N=7 tokens total.
        seq_len_per_batch: 1D long tensor [B],
            full sequence length (KV length) of each batch.
            Example: [5, 9, 4].

    Returns:
        start_tensor: 1D long tensor [N], start offset in the
            concatenated KV cache for each token's batch.
        end_location: 1D long tensor [N],
            **exclusive** end = start + token's local position.
            (So the attended KV slice is kv[start:end].)

    Assumes each batch contributes its full `seq_len_per_batch[i]`
    keys to the KV cache, and the selected tokens within a batch
    are the **last** `counts[i]` positions of that sequence.
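
    Worked example (derived from the example values above; for illustration):
        start_seq_loc = [0, 2, 4, 7], seq_len_per_batch = [5, 9, 4]
        -> start_tensor = [0, 0, 5, 5, 14, 14, 14]
           end_location = [4, 5, 13, 14, 16, 17, 18]
        e.g. the two selected tokens of batch 0 attend kv[0:4] and kv[0:5].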
    """
    q = start_seq_loc.to(dtype=torch.long)
    L = seq_len_per_batch.to(dtype=torch.long)
    assert q.dim() == 1 and L.dim() == 1
    assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"

    # Selected tokens per batch and totals
    counts = q[1:] - q[:-1]  # [B]
    N = int(q[-1].item())  # total selected tokens
    B = L.numel()

    if N == 0:
        return (
            torch.empty(0, dtype=torch.long, device=device),
            torch.empty(0, dtype=torch.long, device=device),
        )

    # KV start offsets per batch in the concatenated KV cache
    kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]

    # For each selected token, which batch does it belong to?
    batch_id = torch.repeat_interleave(torch.arange(B), counts)  # [N]

    # Map batch KV start to each token
    start_tensor = kv_starts_per_batch[batch_id]  # [N]

    # End-align local positions inside each batch:
    # local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b
    L_expand = torch.repeat_interleave(L, counts)  # [N]
    m_expand = torch.repeat_interleave(counts, counts)  # [N]
    # position within the selected block: 1..counts[b]
    pos_within = (
        torch.arange(N, dtype=torch.long) - torch.repeat_interleave(q[:-1], counts) + 1
    )

    local_pos = L_expand - m_expand + pos_within  # [N], 1-based
    end_location = start_tensor + local_pos  # exclusive end

    return start_tensor.int().to(device), end_location.int().to(device)


def get_max_prefill_buffer_size(vllm_config: VllmConfig):
    max_model_len = vllm_config.model_config.max_model_len
    # NOTE(Chen): 40 is a magic number for controlling the prefill buffer size.
    # Each entry is 128 fp8 bytes and 4 scale bytes for a total of 132 bytes.
    # The flashmla_sparse backend uses a workspace size of 5 * max_model_len.
    # The memory usage of that workspace is 576 * 2 bytes per token, so we pick
    # (576 * 2 // 132) * 5 = 40 to make this buffer as large as possible while
    # still fitting within the flashmla_sparse workspace's memory budget.
    # For DeepSeek-V3.2, the max_model_len is 163840.
    # 40 * 163840 * 132 = 865075200 bytes = 825 MB
    return max_model_len * 40


class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: int = 1

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        scheduler_config = self.vllm_config.scheduler_config
        # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
        self.max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config)
        self.num_speculative_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config
            else 0
        )
        self.reorder_batch_threshold += self.num_speculative_tokens

        sm_count = num_compute_units(self.device.index)
        self.num_sms = sm_count

        self.decode_lens_buffer = torch.empty(
            (scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=self.device,
        )

        # Pre-allocated buffers for flattening (spec decode).
        self.arange_buffer = torch.arange(
            scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
            dtype=torch.int32,
            device=self.device,
        )
        self.expanded_seq_lens_buffer = torch.zeros(
            (scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=self.device,
        )
        max_num_blocks_per_req = cdiv(
            self.vllm_config.model_config.max_model_len,
            self.kv_cache_spec.block_size * get_total_cp_world_size(),
        )
        self.expanded_block_table_buffer = torch.zeros(
            (
                scheduler_config.max_num_batched_tokens,
                max_num_blocks_per_req,
            ),
            dtype=torch.int32,
            device=self.device,
        )

        # See: DeepGEMM/csrc/apis/attention.hpp
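        # This buffer is filled per batch in build() with the output of
        # get_paged_mqa_logits_metadata() for the decode requests.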
        self.scheduler_metadata_buffer = torch.empty(
            (self.num_sms + 1, 2), dtype=torch.int32, device=self.device
        )

    def build_one_prefill_chunk(
        self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
    ):
        prefill_query_start_loc = (
            query_start_loc_cpu[reqs_start : reqs_end + 1]
            - query_start_loc_cpu[reqs_start]
        )
        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
            prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
        )
        token_start = query_start_loc_cpu[reqs_start].item()
        token_end = query_start_loc_cpu[reqs_end].item()
        total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
        seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32)
        token_to_seq = torch.repeat_interleave(
            seq_idx, seq_lens_cpu[reqs_start:reqs_end]
        ).to(self.device)
        assert total_seq_lens <= self.max_prefill_buffer_size
        cu_seq_lens = (
            torch.cat(
                [
                    torch.zeros(1, dtype=torch.int32),
                    seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
                ]
            )
            .to(torch.int32)
            .to(self.device)
        )
        cu_seqlens_q = prefill_query_start_loc.to(torch.int32).to(self.device)
        max_context_len = seq_lens_cpu[reqs_start:reqs_end].max().item()
        # max_q_len is the maximum query length among all batches in this chunk
        # prefill_query_start_loc is cumsum of lengths with shape [batch+1]
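        # e.g. query lengths [3, 5, 2] give prefill_query_start_loc
        # [0, 3, 8, 10], so max_q_len is 5.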
        max_q_len = (
            (prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]).max().item()
        )
        return DeepseekV32IndexerPrefillChunkMetadata(
            cu_seqlen_ks=cu_seqlen_ks,
            cu_seqlen_ke=cu_seqlen_ke,
            cu_seq_lens=cu_seq_lens,
            token_to_seq=token_to_seq,
            total_seq_lens=total_seq_lens,
            cu_seqlens_q=cu_seqlens_q,
            block_table=block_table[reqs_start:reqs_end],
            token_start=token_start,
            token_end=token_end,
            num_reqs=reqs_end - reqs_start,
            max_context_len=max_context_len,
            max_q_len=max_q_len,
            max_kv_len=max_context_len,
        )

    def build_decode_metadata(
        self,
        common_attn_metadata,
        num_decodes,
        decode_lens,
        use_large_context_topk,
        offsets,
    ):
        decode_lens_cpu = torch.diff(
            common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
        )
        assert (
            decode_lens_cpu.max().item()
            == decode_lens_cpu.min().item()
            == 1
        ), "Only support single token decode in dsa_indexer backend"

        # Calculate decode metadata parameters
        seq_lens_decode = common_attn_metadata.seq_lens_cpu[:num_decodes]
        max_context_len = seq_lens_decode.max().item()
        max_kv_len = max_context_len
        max_q_len = 1  # Single token decode

        # Create cu_seqlens_q: cumulative sum of query lengths (all 1s)
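        # e.g. num_decodes = 3 -> cu_seqlens_q = [0, 1, 2, 3]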
        cu_seqlens_q = torch.arange(
            num_decodes + 1, dtype=torch.int32, device=self.device
        )

        # Create cu_seqlens_kv and related tensors using kv_spans_from_batches
        decode_query_start_loc = torch.arange(
            num_decodes + 1, dtype=torch.long
        )
        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
            decode_query_start_loc, seq_lens_decode, self.device
        )

        cu_seqlens_kv = torch.cat(
            [
                torch.zeros(1, dtype=torch.int32, device=self.device),
                torch.cumsum(seq_lens_decode.to(self.device), dim=0)
                .to(torch.int32),
            ]
        )

        decode_metadata = DeepSeekV32IndexerDecodeMetadata(
            block_table=common_attn_metadata.block_table_tensor[
                :num_decodes, ...
            ],
            seq_lens=common_attn_metadata.seq_lens[:num_decodes],
            decode_lens=decode_lens,
            requires_padding=(
                decode_lens_cpu.max() > decode_lens_cpu.min()
            ).item(),
            use_large_context_topk=use_large_context_topk,
            offsets=offsets,
            cu_seqlen_ks=cu_seqlen_ks,
            cu_seqlen_ke=cu_seqlen_ke,
            cu_seqlens_kv=cu_seqlens_kv,
            cu_seqlens_q=cu_seqlens_q,
            max_context_len=max_context_len,
            max_q_len=max_q_len,
            max_kv_len=max_kv_len,
            # schedule_metadata=self.scheduler_metadata_buffer,
        )
        return decode_metadata

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> DeepseekV32IndexerMetadata:
        num_reqs = common_attn_metadata.num_reqs
        num_tokens = common_attn_metadata.num_actual_tokens

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
            )
        )

        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        prefill_metadata = None
        if num_prefills > 0:
            chunk_seq_ids = split_prefill_chunks(
                common_attn_metadata.seq_lens_cpu[num_decodes:],
                self.max_prefill_buffer_size,
                request_offset=num_decodes,
            )
            chunks = [
                self.build_one_prefill_chunk(
                    reqs_start,
                    reqs_end,
                    query_start_loc_cpu,
                    common_attn_metadata.seq_lens_cpu,
                    common_attn_metadata.block_table_tensor,
                )
                for reqs_start, reqs_end in chunk_seq_ids
            ]
            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
                chunks=chunks,
            )

        decode_metadata = None
        if num_decodes > 0:
            torch.diff(
                common_attn_metadata.query_start_loc[: num_decodes + 1],
                out=self.decode_lens_buffer[:num_decodes],
            )
            decode_lens = self.decode_lens_buffer[:num_decodes]
            decode_lens_cpu = torch.diff(
                common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
            )

            seq_lens = common_attn_metadata.seq_lens[:num_decodes]
            block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]

            # Padded CUDA graph requests have block_table entries of -1.
            # Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
            # This is safe because padded requests have seq_lens=0, so the
            # kernel produces no meaningful output for those rows.
            block_table.clamp_(min=0)

            max_decode_len = int(decode_lens_cpu.max().item())
            if max_decode_len > 1:
                # Flatten multi-token decode requests into single-token
                # batch entries, expanding seq_lens and block tables so
                # the kernel always sees next_n=1.

                # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
                # padding) and decode_lens [3, 1, 4, 0] in the below example comments.
                # The context lengths are therefore
                # [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].

                # 3 + 1 + 4 + 0 = 8
                actual_expanded = int(decode_lens_cpu.sum().item())

                # [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
                expanded_base = torch.repeat_interleave(
                    seq_lens - decode_lens, decode_lens
                )

                # [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
                expanded_starts = torch.repeat_interleave(
                    common_attn_metadata.query_start_loc[:num_decodes], decode_lens
                )

                # [0, 1, 2, 0, 0, 1, 2, 3]
                positions_within = (
                    self.arange_buffer[:actual_expanded] - expanded_starts
                )

                # [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
                self.expanded_seq_lens_buffer[:actual_expanded] = (
                    expanded_base + positions_within + 1
                )
                self.expanded_seq_lens_buffer[actual_expanded:] = 0
                seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]

                # Give each of the flattened entries the same block table row as the
                # original request.
                self.expanded_block_table_buffer[:actual_expanded] = (
                    torch.repeat_interleave(block_table, decode_lens, dim=0)
                )
                if actual_expanded < num_decode_tokens:
                    self.expanded_block_table_buffer[
                        actual_expanded:num_decode_tokens, 0
                    ] = 0
                block_table = self.expanded_block_table_buffer[:num_decode_tokens]

                # All reqs now have decode_len=1
                self.decode_lens_buffer[:num_decode_tokens] = 1
                decode_lens = self.decode_lens_buffer[:num_decode_tokens]
                offsets = None
                batch_size = num_decode_tokens
            else:
                next_n = 1 + self.num_speculative_tokens
                if next_n > 1:
                    offsets = torch.arange(
                        next_n, device=self.device, dtype=torch.int32
                    )
                else:
                    offsets = None
                batch_size = num_decodes

            # DeepGEMM is required for the paged MQA logits on CUDA devices
            if current_platform.is_cuda() and is_deep_gemm_supported():
                self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
                    seq_lens,
                    self.kv_cache_spec.block_size,
                    self.num_sms,
                )

            # Decide which top-k kernel to use based on batch size and sequence length
            # Decision logic based on micro-benchmark results:
            # - large_context_topk wins for batch <= 128 and seq_len > 8K
            # - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
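            # e.g. batch_size=64 with max_seq_len=16384 picks large_context_topk;
            # batch_size=256 or max_seq_len=4096 does not.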
            _is_large_context = common_attn_metadata.max_seq_len > 8192
            use_large_context_topk = batch_size <= 128 and _is_large_context

            # decode_metadata = DeepSeekV32IndexerDecodeMetadata(
            #     block_table=block_table,
            #     seq_lens=seq_lens,
            #     decode_lens=decode_lens,
            #     requires_padding=False,
            #     # schedule_metadata=self.scheduler_metadata_buffer,
            #     use_large_context_topk=use_large_context_topk,
            #     offsets=offsets,
            # )
            decode_metadata = self.build_decode_metadata(
                common_attn_metadata,
                num_decodes,
                decode_lens,
                use_large_context_topk,
                offsets,
            )

        attn_metadata = DeepseekV32IndexerMetadata(
            seq_lens=common_attn_metadata.seq_lens,
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            head_dim=128,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

        # if get_tensor_model_parallel_rank() == 0:
        #     logger.info(f"attn_metadata: {attn_metadata}")
        return attn_metadata