from dataclasses import dataclass
from typing import (TYPE_CHECKING, ClassVar, Dict, Generic, List, Optional,
                    Type, TypeVar)

import torch

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata,
                                              AttentionType)
from vllm.attention.backends.utils import get_mla_dims
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata,
                                              get_per_layer_parameters,
                                              infer_global_hyperparameters,
                                              split_decodes_and_prefills)
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonImpl,
                                                   MLACommonMetadata)

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

# Needed at runtime by MLAVaccMetadataBuilder.build, so it must live outside
# the TYPE_CHECKING guard.
from vllm_vacc.vllm.model_executor.models.vars import \
    BLOCK_GROUP_SIZE as env_blk_grp_size

logger = init_logger(__name__)

M = TypeVar("M", bound=MLACommonMetadata)


def _gather_paged_tokens(cache: torch.Tensor, blocks: torch.Tensor,
                         num_tokens) -> torch.Tensor:
    """Gather the first ``num_tokens`` token slots stored in ``blocks``.

    ``cache`` is indexed as 4-D (``cache[blocks, :, :, :]``); the selected
    blocks are concatenated in table order and flattened into a token axis
    of shape ``(-1, cache.shape[2], cache.shape[3])``.
    """
    slices = cache[blocks, :, :, :]
    # NOTE(review): the cat over per-block slices appears equivalent to
    # ``slices.reshape(-1, cache.shape[2], cache.shape[3])``; kept as-is in
    # case the VACC backend relies on this exact op sequence.
    flat = torch.cat(
        [slices[blk, :, :, :].unsqueeze(1) for blk in range(len(blocks))],
        dim=0)
    return flat.view(-1, cache.shape[2], cache.shape[3])[:num_tokens]


def vacc_paged_attention_naive(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_table: torch.Tensor,
    # seq_lens: torch.Tensor,
    seq_lens: List[int],
    out: Optional[torch.Tensor] = None,
    sm_scale: float = -1,
) -> torch.Tensor:
    """Decode attention over a paged KV cache using the VACC SDPA kernel.

    Args:
        query: Queries; row ``i`` is used as ``query[i:i+1, :, :]`` for
            sequence ``i`` when there is more than one sequence.
        key_cache: Paged key cache, indexed as 4-D (dims 2 and 3 are kept
            when blocks are flattened into a token axis).
        value_cache: Paged value cache, laid out like ``key_cache``.
        block_table: Row ``i`` lists the cache block ids of sequence ``i``.
        seq_lens: Number of valid KV tokens per sequence.
        out: Unused; kept for interface compatibility.
        sm_scale: Softmax scale forwarded to the kernel.

    Returns:
        Attention output; with several sequences the per-sequence outputs
        are concatenated along dim 0.
    """
    if len(seq_lens) == 1:
        # Fast path for batch == 1 (performance).
        # NOTE(review): this path ignores ``block_table`` and reads the
        # first ``seq_lens[0]`` token slots of the flattened cache, i.e. it
        # assumes the single sequence occupies the cache from block 0 on -
        # confirm against the VACC block allocator.
        k = key_cache.view(-1, key_cache.shape[2],
                           key_cache.shape[3])[:seq_lens[0]]
        v = value_cache.view(-1, value_cache.shape[2],
                             value_cache.shape[3])[:seq_lens[0]]
        return torch.vacc.scaled_dot_product_attention(
            query=query, key=k, value=v, attn_mask=None, dropout_p=0,
            is_causal=False, is_train=False, recompute=False,
            flash_attention=False, sm_scale=sm_scale)

    attn_outs = []
    for i, seq_len in enumerate(seq_lens):
        k = _gather_paged_tokens(key_cache, block_table[i], seq_len)
        v = _gather_paged_tokens(value_cache, block_table[i], seq_len)
        attn_outs.append(torch.vacc.scaled_dot_product_attention(
            query=query[i:i + 1, :, :], key=k, value=v, attn_mask=None,
            dropout_p=0, is_causal=False, is_train=False, recompute=False,
            flash_attention=False, sm_scale=sm_scale))
    return torch.cat(attn_outs, dim=0)


@dataclass
class MLACommonPrefillMetadata:
    """ Prefill Specific Metadata """

    @dataclass
    class ChunkedContextMetadata:
        # New for MLA (compared to FlashAttention)
        # For handling chunked prefill
        cu_seq_lens: torch.Tensor
        starts: torch.Tensor
        seq_tot: list[int]
        max_seq_lens: list[int]
        workspace: torch.Tensor

    # Named `block_tables` (not `block_table`) for compatibility with v0.
    block_tables: torch.Tensor
    query_start_loc: torch.Tensor
    # max_query_len: int
    seq_lens: list[int]
    chunked_context: Optional[ChunkedContextMetadata] = None


@dataclass
class MLACommonDecodeMetadata:
    """Decode-specific metadata: per-row block tables and sequence lengths."""
    # Named `block_tables` (not `block_table`) for compatibility with v0.
    block_tables: torch.Tensor
    seq_lens: torch.Tensor


class VACCMLAMetadata(MLACommonMetadata):
    """MLA metadata with prefills first, then decodes, plus cached views.

    ``prefill_metadata`` / ``decode_metadata`` slice this object into
    prefill-only and decode-only copies (prefill rows are ``[:num_prefills]``,
    decode rows ``[num_prefills:]``) and cache them.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether or not if cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Maximum query length in the batch.
    max_query_len: Optional[int] = None

    input_positions: Optional[torch.Tensor] = None

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    _cached_prefill_metadata: Optional["VACCMLAMetadata"] = None
    _cached_decode_metadata: Optional["VACCMLAMetadata"] = None

    num_prefill_tokens: int

    num_kv_splits: int = 4  # TODO(lucas) add heuristic
    attn_logits: Optional[torch.Tensor] = None
    req_idx: Optional[torch.Tensor] = None

    # The dimension of the attention heads
    head_dim: Optional[int] = None

    def __post_init__(self):
        """Reject head dims the VACC MLA backend does not support."""
        supported_head_sizes = VACCMLABackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            # The two parts were previously passed as separate ValueError
            # args, which rendered the message as a tuple.
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

    @property
    def prefill_metadata(self) -> Optional["VACCMLAMetadata"]:
        """Prefill-only view (first ``num_prefills`` rows), cached."""
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])
        input_positions = (None if self.input_positions is None else
                           self.input_positions[:self.num_prefill_tokens])

        self._cached_prefill_metadata = VACCMLAMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            input_positions=input_positions,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=None,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            head_dim=self.head_dim)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["VACCMLAMetadata"]:
        """Decode-only view (rows after the prefills), cached."""
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.seq_lens_tensor is not None

        # Compute some attn_metadata fields which default to None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])
        input_positions = (None if self.input_positions is None else
                           self.input_positions[self.num_prefill_tokens:])

        self._cached_decode_metadata = VACCMLAMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=self.seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            # Batch may be composed of prefill|decodes, adjust query start
            # indices to refer to the start of decodes. E.g.
            # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
            query_start_loc=(self.query_start_loc[self.num_prefills:] -
                             self.query_start_loc[self.num_prefills])
            if self.query_start_loc is not None else None,
            seq_start_loc=self.seq_start_loc[self.num_prefills:]
            if self.seq_start_loc is not None else None,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            input_positions=input_positions,
            head_dim=self.head_dim)
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForVACCWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """Update metadata in-place to advance one decode step.

        Increments ``seq_lens`` for the first ``num_queries`` entries and
        then delegates the tensor updates (tokens, positions, seq lens, slot
        mapping) to ``ops.advance_step_flashattn``.
        """
        # NOTE(review): ``ops`` was used here without ever being imported in
        # this module (NameError on first call). Upstream vLLM imports it
        # from ``vllm._custom_ops``; confirm ``advance_step_flashattn``
        # exists in the pinned vLLM version. Imported before any mutation so
        # a failure leaves the metadata untouched.
        from vllm import _custom_ops as ops

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            # self.num_prefill_tokens = 0
            # self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens is not None
            assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        # assert self.max_query_len == 1
        # assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        # self.max_decode_seq_len = None

        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)


class VACCMLAImpl(MLACommonImpl[VACCMLAMetadata]):
    """MLA attention implementation for the VACC device.

    Prefill runs ``torch.vacc.scaled_dot_product_attention`` on up-projected
    K/V (one call per sequence when the batch has several); decode runs MQA
    over the latent KV cache via :func:`vacc_paged_attention_naive`.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            # blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **kwargs) -> None:
        """Store configuration and MLA projections; reject unsupported modes.

        ``MLACommonImpl.__init__`` is intentionally not called; the legacy
        base-class initialization is replicated manually instead.

        Raises:
            NotImplementedError: if alibi slopes, a sliding window or a
                logits soft cap is set, or ``attn_type`` is not DECODER.
        """
        # super().__init__(num_heads, head_size, scale, num_kv_heads,
        #                  alibi_slopes, sliding_window, kv_cache_dtype,
        #                  logits_soft_cap, attn_type,
        #                  kv_sharing_target_layer_name, **kwargs)
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        # Manually replicate the legacy base-class initialization: every MLA
        # argument is optional and defaults to None when not supplied.
        self.q_lora_rank = kwargs.get('q_lora_rank')
        self.kv_lora_rank = kwargs.get('kv_lora_rank')
        self.qk_nope_head_dim = kwargs.get('qk_nope_head_dim')
        self.qk_head_dim = kwargs.get('qk_head_dim')
        self.v_head_dim = kwargs.get('v_head_dim')
        self.rotary_emb = kwargs.get('rotary_emb')
        self.q_proj = kwargs.get('q_proj')
        self.kv_b_proj = kwargs.get('kv_b_proj')
        self.o_proj = kwargs.get('o_proj')

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "VACCMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "VACCMLAImpl")

    def extract_weights(self):
        """Return the processed MLA weight tensors that exist on ``self``.

        Only attributes that have been set are included; keys keep the fixed
        order below.
        """
        weight_names = (
            "W_Q", "W_Q_scales",
            "W_QR", "W_QR_scales",
            "W_Q_QR", "W_Q_QR_scales",
            "W_UK", "W_UK_scales",
            "W_Q_UK_scales",
            "W_UV", "W_UV_scales",
            "W_UV_O", "W_UV_O_scales",
        )
        return {
            name: getattr(self, name)
            for name in weight_names if hasattr(self, name)
        }

    def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: VACCMLAMetadata,
    ) -> torch.Tensor:
        """Causal prefill attention on up-projected K/V, then ``o_proj``.

        ``kv_c_normed`` is up-projected with ``kv_b_proj`` and split into
        ``k_nope`` / ``v``; ``k_pe`` is broadcast and appended to ``k_nope``.
        When ``prefill_metadata.seq_lens`` has several entries, ``q``/``k``/
        ``v`` are sliced into consecutive per-sequence token ranges.
        """
        # The builder in this file produces VACCMLAV1Metadata instances.
        assert isinstance(attn_metadata, VACCMLAV1Metadata)

        kv_nope = self.kv_b_proj(kv_c_normed)[0]\
            .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
        v = v.contiguous()

        # VACC SDPA supports different head dims for v and qk, so v is not
        # zero-padded to the qk head dim (unlike the upstream implementation).
        seq_lens = attn_metadata.prefill_metadata.seq_lens
        if len(seq_lens) == 1:
            attn_output = torch.vacc.scaled_dot_product_attention(
                query=q, key=k, value=v, attn_mask=None, dropout_p=0,
                is_causal=True, is_train=False, recompute=False,
                flash_attention=False, sm_scale=self.scale)
            attn_out = attn_output.view(-1, self.num_heads * v.shape[-1])
        else:
            attn_outs = []
            start = 0
            for seq in seq_lens:
                end = start + seq
                attn_outs.append(torch.vacc.scaled_dot_product_attention(
                    query=q[start:end, :], key=k[start:end, :],
                    value=v[start:end, :], attn_mask=None, dropout_p=0,
                    is_causal=True, is_train=False, recompute=False,
                    flash_attention=False, sm_scale=self.scale))
                start = end
            attn_out = torch.cat(attn_outs, dim=0)\
                .view(-1, self.num_heads * v.shape[-1])

        return self.o_proj(attn_out)[0]

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: VACCMLAMetadata,
    ) -> torch.Tensor:
        """MQA decode over the latent KV cache.

        Keys are the full cached rows (latent + rope part); values are the
        first ``kv_lora_rank`` channels. The result goes through
        ``self._v_up_proj_and_o_proj`` (presumably inherited from
        MLACommonImpl).

        Raises:
            NotImplementedError: for fp8 KV-cache dtypes.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Triton MLA not yet supported")

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None
        B = q_nope.shape[0]

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = torch.zeros(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q.dtype,
                        device=q.device)

        # Add a head dim of 1: the backend's cache shape is
        # (num_blocks, block_size, head_size).
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]

        # Alternative kernel (currently disabled):
        # o = torch.vacc.paged_attention(
        #     query=q, key_cache=kv_c_and_k_pe_cache, value_cache=kv_c_cache,
        #     block_table=decode_meta.block_tables,
        #     seq_len=decode_meta.seq_lens_tensor, out=o, sm_scale=self.scale)

        # Run MQA using SDPA.
        o = vacc_paged_attention_naive(
            q,
            kv_c_and_k_pe_cache,
            kv_c_cache,
            block_table=decode_meta.block_tables,
            # seq_lens = decode_meta.seq_lens_tensor,
            seq_lens=decode_meta.seq_lens,
            out=o,
            sm_scale=self.scale)

        return self._v_up_proj_and_o_proj(o)


# patch from MLACommonBackend
class VACCMLABackend(AttentionBackend):
    """Attention backend wiring the VACC MLA metadata, impl and builder."""

    accept_output_buffer: bool = False

    @staticmethod
    def get_name() -> str:
        return "TRITON_MLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return VACCMLAMetadata

    @staticmethod
    def get_impl_cls() -> Type["VACCMLAImpl"]:
        return VACCMLAImpl

    @staticmethod
    def get_builder_cls() -> type["MLAVaccMetadataBuilder"]:
        return MLAVaccMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the MLA KV-cache shape; ``num_kv_heads`` is not used."""
        return (num_blocks, block_size, head_size)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ValueError if ``head_size`` is not supported."""
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")


D = TypeVar("D", bound=MLACommonDecodeMetadata)


# patch from MLACommonMetadata
@dataclass
class VACCMLAV1Metadata(Generic[D]):
    """Per-step MLA metadata produced by ``MLAVaccMetadataBuilder.build``.

    Holds batch-level token/request counts plus optional prefill and decode
    sub-metadata.
    """
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # The dimension of the attention heads
    head_dim: Optional[int] = None

    decode_metadata: Optional[D] = None
    prefill_metadata: Optional[MLACommonPrefillMetadata] = None

    def __post_init__(self):
        if self.head_dim is not None:
            MLACommonBackend.validate_head_size(self.head_dim)


class MLAVaccMetadataBuilder(AttentionMetadataBuilder[M]):
    """Builds ``VACCMLAV1Metadata`` (or ``metadata_cls``) for each step.

    Chunked prefill is not supported: ``chunked_prefill_enabled`` is always
    False and prefill metadata never carries chunked-context data.
    """
    # TODO: distinguish the prefill and decode thresholds.
    reorder_batch_threshold: ClassVar[int] = 2

    def __init__(self,
                 kv_cache_spec: AttentionSpec,
                 layer_names: list[str],
                 vllm_config: VllmConfig,
                 device: torch.device,
                 metadata_cls: Optional[type[M]] = None):
        self._global_hyperparameters = infer_global_hyperparameters(
            get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl))
        self.metadata_cls = metadata_cls \
            if metadata_cls is not None else VACCMLAV1Metadata
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.compilation_config = vllm_config.compilation_config
        self.device = device

        # Chunked prefill is not supported on VACC, so no chunked-prefill
        # workspace is allocated here.
        self.chunked_prefill_enabled = False
        self.num_heads = self.model_config.get_num_attention_heads(
            parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        self.aot_schedule = current_platform.is_cuda()

        if self.aot_schedule:
            self.page_size = self.kv_cache_spec.block_size

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Move "decode" requests to the front and "prefills" to the back.

        Uses as few swaps as possible ("decode" loosely means memory-bound
        attention, "prefill" compute-bound; TODO(lucas): better naming).
        Also stores the decode/prefill counts on ``self`` for the next
        ``build`` call.

        Returns:
            True if any requests were swapped.
        """
        # NOTE(review): the historical comment said "treat 1 scheduled token
        # as decode", but the implemented rule classifies the WHOLE batch at
        # once: every request is a decode iff the scheduler reports a
        # non-empty ``scheduled_cached_reqs.num_computed_tokens``. The check
        # does not depend on the request, so it is evaluated once. As a
        # consequence one of the two lists is always empty and the swap loop
        # below never runs.
        is_decode_batch = (
            scheduler_output.scheduled_cached_reqs.num_computed_tokens != [])

        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0
        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            if is_decode_batch:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        # We hope this is fairly minimal since decodes stay around for many
        # iterations (new requests are generally appended to the back of the
        # persistent batch). Walk decodes from the back and prefills from the
        # front, swapping any decode located past where the last decode
        # should be with the earliest prefill. `decodes` and `prefills` are
        # already in ascending order from the loop above.
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode is at the "back" of the batch, i, we can swap it
            # with the prefill closest to the front of the batch
            decode_idx = decodes[num_decodes - i]
            if decode_idx < num_decodes:
                break

            input_batch.swap_states(prefills[i - 1], decode_idx)
            modified_batch = True

        # Save for next `build` call
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor):
        return MLACommonDecodeMetadata(
            block_tables=block_table_tensor,
            seq_lens=seq_lens,
        )

    def append_seqlen(self, seq_len: list[int], all_len: int):
        """Expand per-request seq lens to per-token seq lens (multi-token).

        When ``all_len`` is larger than, and an exact multiple of,
        ``len(seq_len)``, each request contributes
        ``mtp_num = all_len // len(seq_len)`` tokens and token ``i``
        (1-based) of a request of length ``L`` gets ``L - mtp_num + i``,
        e.g. ``append_seqlen([40], 2) == [39, 40]``. Otherwise (including an
        empty ``seq_len``) ``seq_len`` is returned unchanged.
        """
        num_reqs = len(seq_len)
        # An empty list previously hit ``all_len % 0`` (ZeroDivisionError).
        if num_reqs == 0 or all_len <= num_reqs or all_len % num_reqs != 0:
            return seq_len
        mtp_num = all_len // num_reqs
        return [
            start_len - mtp_num + i for start_len in seq_len
            for i in range(1, mtp_num + 1)
        ]

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> M:
        """Build the per-step MLA metadata.

        The whole batch is classified at once: if any request schedules
        more than ``reorder_batch_threshold`` tokens, every request is a
        prefill; otherwise every request is a decode.

        Note(simon): be careful about CPU <> GPU memory movement here. GPU ->
        CPU syncs block on all previous kernels, so CPU-side copies
        (``query_start_loc_cpu``) are used for host-side decisions.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens

        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        # Example (decode with 2 tokens per request): query_start_loc_cpu=
        # [0, 2], seq_lens=[40], num_computed_tokens_cpu=[38], num_reqs=1,
        # num_actual_tokens=2, max_query_len=2.

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            0, 0, 0, 0
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        token_nums = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        if token_nums.max().item() > self.reorder_batch_threshold:
            num_prefills = num_reqs
            num_prefill_tokens = num_actual_tokens
        else:
            num_decodes = num_reqs
            num_decode_tokens = num_actual_tokens

        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens

        prefill_metadata = None
        if num_prefills > 0:
            reqs_start = num_decodes  # prefill_start
            prefill_query_start_loc = query_start_loc[
                reqs_start:] - query_start_loc[reqs_start]

            # Chunked prefill is not supported, so there is no chunked
            # context metadata.
            chunked_context_metadata = None

            prefill_metadata = MLACommonPrefillMetadata(
                block_tables=block_table_tensor[
                    reqs_start:reqs_start + num_prefills, ...],
                query_start_loc=prefill_query_start_loc,
                chunked_context=chunked_context_metadata,
                seq_lens=seq_lens[:num_prefills],
            )

        if not isinstance(seq_lens, list):
            # TODO init set list in init: vllm/v1/spec_decode/eagle.py
            seq_lens = seq_lens.tolist()

        decode_metadata = None
        if num_decodes > 0:
            # Keep one block-table column per block group.
            # NOTE(review): 16 is hard-coded here as the cache block size -
            # confirm it matches kv_cache_spec.block_size.
            block_num_per_group = env_blk_grp_size // 16
            block_table_tensor_new = block_table_tensor[
                :num_decodes, ::block_num_per_group].contiguous()
            num_tokens = slot_mapping.shape[-1]
            seq_lens_new = self.append_seqlen(seq_lens[:num_tokens],
                                              num_tokens)
            if num_tokens > num_decodes:
                # Several tokens per request (e.g. MTP): repeat each
                # request's block-table row once per scheduled token. Counts
                # come from the CPU-side ``token_nums`` to avoid one device
                # sync per request.
                mtp_numbers = token_nums.tolist()
                block_table_tensor_list = []
                for bi, mtp_number in enumerate(mtp_numbers):
                    for _ in range(mtp_number):
                        block_table_tensor_list.append(
                            block_table_tensor_new[bi:bi + 1])
                block_table_tensor_new = torch.concatenate(
                    block_table_tensor_list, 0)
            decode_metadata = self._build_decode(
                block_table_tensor=block_table_tensor_new,
                seq_lens=seq_lens_new,
            )

        return self.metadata_cls(
            num_actual_tokens=num_actual_tokens,
            query_start_loc=query_start_loc,
            slot_mapping=slot_mapping,
            head_dim=self.model_config.get_head_size(),
            # MLACommonMetadata Chunk prefill specific
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill_metadata=prefill_metadata,
            decode_metadata=decode_metadata,
        )