################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

from dataclasses import dataclass

import torch
from vllm.v1.attention.backends.utils import CommonAttentionMetadata

from vllm_br.config.compilation import SUPAGraphMode


@dataclass
class SUPACommonAttentionMetadata(CommonAttentionMetadata):
    """SUPA extension of the per-batch attention metadata.

    Holds the attention metadata attributes that can be shared by layers
    belonging to different KV cache groups, which therefore have different
    block tables. Every field introduced by this subclass is optional and
    defaults to ``None``.
    """

    # NOTE: `query_start_loc` and `seq_lens` are declared again here without
    # defaults; the declarations (names, order, annotations) must be kept as
    # they are because they feed the dataclass-generated `__init__`.
    query_start_loc: torch.Tensor
    """Shape (batch_size + 1,): start offset of each request inside the
    query tensor."""

    seq_lens: torch.Tensor
    """Shape (batch_size,): length of each request, counting both the
    already computed tokens and the newly scheduled tokens."""

    num_actual_reqs: torch.Tensor | None = None
    """Shape (1,): number of actual requests in the batch."""

    supagraph_runtime_mode: SUPAGraphMode | None = None
    """SUPA graph runtime mode attached to this batch, if any.
    NOTE(review): exact semantics are defined by `SUPAGraphMode` - confirm
    against its definition."""

    context_lens: torch.Tensor | None = None
    """Shape (batch_size,): length of each request, counting the already
    computed tokens only."""

    max_decode_seq_len: int | None = None
    """Maximum length of the decoded sequence in the batch."""

    seq_start_loc: torch.Tensor | None = None
    """Shape (batch_size + 1,): start offset of each request inside the
    sequence tensor. It is used to compute the sequence length of each
    request; when not provided, it will be computed from `seq_lens`."""