diff --git a/vllm_kunlun/ops/__init__.py b/vllm_kunlun/ops/__init__.py index 1776a76..7facd71 100644 --- a/vllm_kunlun/ops/__init__.py +++ b/vllm_kunlun/ops/__init__.py @@ -15,20 +15,22 @@ # This file is a part of the vllm-ascend project. # -# embedding -import vllm_kunlun.ops.rotary_embedding -import vllm_kunlun.ops.vocab_parallel_embedding - -# quantization -import vllm_kunlun.ops.quantization.awq -import vllm_kunlun.ops.quantization.gptq -import vllm_kunlun.ops.quantization.moe_wna16 -import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors -import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe -import vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm -import vllm_kunlun.ops.quantization.kernels.kunlun_exllama_linear +import vllm_kunlun.ops.fused_moe.layer # base layers import vllm_kunlun.ops.layernorm import vllm_kunlun.ops.linear -import vllm_kunlun.ops.fused_moe.layer + +# quantization +import vllm_kunlun.ops.quantization.awq +import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors +import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe +import vllm_kunlun.ops.quantization.gptq +import vllm_kunlun.ops.quantization.kernels.kunlun_exllama_linear +import vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm +import vllm_kunlun.ops.quantization.moe_wna16 + +# embedding +import vllm_kunlun.ops.rotary_embedding +import vllm_kunlun.ops.vocab_parallel_embedding +import vllm_kunlun.v1.sample.spec_decode.eagle diff --git a/vllm_kunlun/v1/attention/backends/mla/indexer.py b/vllm_kunlun/v1/attention/backends/mla/indexer.py index ab2a383..b966bca 100644 --- a/vllm_kunlun/v1/attention/backends/mla/indexer.py +++ b/vllm_kunlun/v1/attention/backends/mla/indexer.py @@ -1,17 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import torch from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import 
Optional + +import torch from vllm.logger import init_logger -from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, - split_decodes_and_prefills) -from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks, - DeepseekV32IndexerMetadataBuilder, - kv_spans_from_batches) +from vllm.v1.attention.backends.mla.indexer import ( + DeepseekV32IndexerMetadataBuilder, + kv_spans_from_batches, + split_prefill_chunks, +) +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + split_decodes_and_prefills, +) logger = init_logger(__name__) + @dataclass class DeepseekV32IndexerPrefillChunkMetadata: block_table: torch.Tensor @@ -32,6 +38,7 @@ class DeepseekV32IndexerPrefillChunkMetadata: class DeepseekV32IndexerPrefillMetadata: chunks: list[DeepseekV32IndexerPrefillChunkMetadata] + @dataclass class DeepSeekV32IndexerDecodeMetadata: block_table: torch.Tensor @@ -70,26 +77,36 @@ class DeepseekV32IndexerMetadata: decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None -def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end, - query_start_loc_cpu, seq_lens_cpu, - block_table): - prefill_query_start_loc = query_start_loc_cpu[ - reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start] + +def kunlun_build_one_prefill_chunk( + self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table +): + prefill_query_start_loc = ( + query_start_loc_cpu[reqs_start : reqs_end + 1] - query_start_loc_cpu[reqs_start] + ) cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches( - prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], - self.device) + prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device + ) token_start = query_start_loc_cpu[reqs_start].item() token_end = query_start_loc_cpu[reqs_end].item() total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum() assert total_seq_lens <= self.max_prefill_buffer_size - cu_seq_lens = torch.cat([ - 
def kunlun_build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> DeepseekV32IndexerMetadata:
    """Build DeepSeek-V3.2 indexer attention metadata for one scheduler step.

    Splits the batch into decode and prefill requests, assembles chunked
    prefill metadata (each chunk bounded by ``self.max_prefill_buffer_size``)
    and per-decode metadata, and combines them into a single
    ``DeepseekV32IndexerMetadata``.

    Args:
        common_prefix_len: Part of the builder interface; unused here.
        common_attn_metadata: Batch layout for this step (query start
            offsets, sequence lengths, block table, slot mapping).
        fast_build: Part of the builder interface; unused here.

    Returns:
        The assembled indexer metadata for this step.
    """
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens

    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        split_decodes_and_prefills(
            common_attn_metadata, decode_threshold=self.reorder_batch_threshold
        )
    )

    # The decode/prefill split must exactly cover the batch.
    assert num_decodes + num_prefills == num_reqs
    assert num_decode_tokens + num_prefill_tokens == num_tokens

    prefill_metadata = None
    if num_prefills > 0:
        chunk_seq_ids = split_prefill_chunks(
            common_attn_metadata.seq_lens_cpu,
            self.max_prefill_buffer_size,
            num_decodes,
        )
        chunks = [
            self.build_one_prefill_chunk(
                reqs_start,
                reqs_end,
                query_start_loc_cpu,
                common_attn_metadata.seq_lens_cpu,
                common_attn_metadata.block_table_tensor,
            )
            for reqs_start, reqs_end in chunk_seq_ids
        ]
        prefill_metadata = DeepseekV32IndexerPrefillMetadata(chunks=chunks)

    decode_metadata = None
    if num_decodes > 0:
        torch.diff(
            common_attn_metadata.query_start_loc[: num_decodes + 1],
            out=self.decode_lens_buffer[:num_decodes],
        )
        decode_lens = self.decode_lens_buffer[:num_decodes]
        decode_lens_cpu = torch.diff(
            common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
        )

        # Use the CPU tensor to avoid a GPU sync that would break async
        # scheduling.
        requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()

        # Previously this slice was computed and then left unused while the
        # same expression was re-sliced twice below; hoist and reuse it.
        seq_lens = common_attn_metadata.seq_lens[:num_decodes]

        decode_metadata = DeepSeekV32IndexerDecodeMetadata(
            block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
            seq_lens=seq_lens,
            seq_lens_cpu=seq_lens.cpu(),
            decode_lens=decode_lens,
            requires_padding=requires_padding,
            schedule_metadata=self.scheduler_metadata_buffer,
        )

    attn_metadata = DeepseekV32IndexerMetadata(
        seq_lens=common_attn_metadata.seq_lens,
        seq_lens_cpu=common_attn_metadata.seq_lens.cpu(),
        num_reqs=common_attn_metadata.num_reqs,
        max_query_len=common_attn_metadata.max_query_len,
        max_seq_len=common_attn_metadata.max_seq_len,
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        query_start_loc=common_attn_metadata.query_start_loc,
        slot_mapping=common_attn_metadata.slot_mapping,
        head_dim=128,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        prefill=prefill_metadata,
        decode=decode_metadata,
    )

    return attn_metadata


DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk = (
    kunlun_build_one_prefill_chunk
)
DeepseekV32IndexerMetadataBuilder.build = kunlun_build

# Monkey patch: upgrade cudagraph_support to UNIFORM_BATCH so this builder is
# compatible with speculative decoding.
from vllm.v1.attention.backends.utils import AttentionCGSupport

DeepseekV32IndexerMetadataBuilder.cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
import init_logger from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM -from vllm.platforms import current_platform -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.rocm_aiter_fa import ( - AiterFlashAttentionMetadata) -from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, - TreeAttentionMetadataBuilder) -from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata +from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.ubatching import dbo_current_ubatch_id logger = init_logger(__name__) @@ -37,47 +29,65 @@ def propose( target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, + last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, mm_embeds: Optional[list[torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + + if last_token_indices is None: + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) - target_hidden_states = self.model.combine_hidden_states( - target_hidden_states) + target_hidden_states = self.model.combine_hidden_states(target_hidden_states) assert target_hidden_states.shape[-1] == self.hidden_size # Shift the input ids by one token. 
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] + self.input_ids[: num_tokens - 1] = target_token_ids[1:] # Replace the last token with the next token. # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] self.input_ids[last_token_indices] = next_token_ids assert self.runner is not None - # FIXME: need to consider multiple kv_cache_groups - attn_metadata = self.runner.attn_groups[0][0].metadata_builder\ - .build_for_drafting(common_attn_metadata=common_attn_metadata, - draft_index=0) - if attn_metadata.decode is not None and attn_metadata.decode.spec_num_seq_len is not None: - attn_metadata.decode.spec_num_seq_len = -1 + ubatch_id = dbo_current_ubatch_id() + attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builders[ubatch_id] + attn_metadata = attn_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, draft_index=0 + ) + if ( + hasattr(attn_metadata, "decode") + and attn_metadata.decode is not None + and hasattr(attn_metadata.decode, "spec_num_seq_len") + and attn_metadata.decode.spec_num_seq_len is not None + ): + attn_metadata.decode.spec_num_seq_len = -1 + + if self.draft_indexer_metadata_builder: + draft_indexer_metadata = self.draft_indexer_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, + draft_index=0, + ) + else: + draft_indexer_metadata = None + # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. 
per_layer_attn_metadata = {} for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + for layer_name in self.indexer_layer_names: + assert draft_indexer_metadata is not None + per_layer_attn_metadata[layer_name] = draft_indexer_metadata + if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens - # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states @@ -94,22 +104,25 @@ def propose( inputs_embeds = None input_ids = self.input_ids[:num_input_tokens] - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + with set_forward_context( + per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens + ): ret_hidden_states = self.model( input_ids=input_ids, positions=self.positions[:num_input_tokens], hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) - if self.method == "deepseek_mtp": + if self.method == "mtp": last_hidden_states = ret_hidden_states hidden_states = self.hidden_states[:num_input_tokens] else: last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states, None) + if self.method == "mtp": + logits = self.model.compute_logits(sample_hidden_states, 0) + else: + logits = self.model.compute_logits(sample_hidden_states, None) positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] @@ -136,21 +149,22 @@ def propose( # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. - # Generate the remaining draft tokens. 
draft_token_ids_list = [draft_token_ids] - if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: + if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.max_query_len = 1 - common_attn_metadata.query_start_loc = self.arange[:batch_size + 1].to(torch.int32) - common_attn_metadata.query_start_loc_cpu = torch.from_numpy( - self.token_arange_np[:batch_size + 1]).clone().to(torch.int32) - for _ in range(self.num_speculative_tokens - 1): + common_attn_metadata.query_start_loc = self.arange[: batch_size + 1].to(torch.int32) + common_attn_metadata.query_start_loc_cpu = ( + torch.from_numpy(self.token_arange_np[: batch_size + 1]).clone().to(torch.int32) + ) + + attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builder + for token_index in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. @@ -166,38 +180,39 @@ def propose( exceeds_max_model_len = positions >= self.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) + clamped_positions = torch.where(exceeds_max_model_len, 0, positions) # Increment the sequence lengths. common_attn_metadata.seq_lens += 1 common_attn_metadata.seq_lens_cpu += 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. 
- common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, - 1) - common_attn_metadata.num_computed_tokens_cpu = \ - common_attn_metadata.seq_lens_cpu - 1 + common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) + common_attn_metadata.num_computed_tokens_cpu = ( + common_attn_metadata.seq_lens_cpu - 1 + ) # Compute the slot mapping. block_numbers = clamped_positions // self.block_size block_ids = common_attn_metadata.block_table_tensor.gather( - dim=1, index=block_numbers.view(-1, 1)) + dim=1, index=block_numbers.view(-1, 1) + ) block_ids = block_ids.view(-1) common_attn_metadata.slot_mapping = ( - block_ids * self.block_size + - clamped_positions % self.block_size) + block_ids * self.block_size + clamped_positions % self.block_size + ) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. common_attn_metadata.slot_mapping.masked_fill_( - exceeds_max_model_len, PADDING_SLOT_ID) + exceeds_max_model_len, PADDING_SLOT_ID + ) - attn_metadata = self.runner.attn_groups[0][0].metadata_builder\ - .build_for_drafting(common_attn_metadata=common_attn_metadata, - draft_index=0) + attn_metadata = attn_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, draft_index=token_index + 1 + ) for layer_name in self.attn_layer_names: - per_layer_attn_metadata[layer_name] = attn_metadata + per_layer_attn_metadata[layer_name] = attn_metadata # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids self.positions[:batch_size] = clamped_positions @@ -212,17 +227,23 @@ def propose( input_ids = self.input_ids[:input_batch_size] # Run the model. 
- with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=input_batch_size): - last_hidden_states = self.model( + with set_forward_context( + per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size + ): + ret_hidden_states = self.model( input_ids=input_ids, positions=self.positions[:input_batch_size], hidden_states=self.hidden_states[:input_batch_size], inputs_embeds=inputs_embeds, ) - logits = self.model.compute_logits(last_hidden_states[:batch_size], - None) + if self.method == "mtp": + last_hidden_states = ret_hidden_states + hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states + + hidden_states = hidden_states[:batch_size] + logits = self.model.compute_logits(last_hidden_states[:batch_size]) draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) @@ -230,83 +251,91 @@ def propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids -def prepare_next_token_ids_padded(self, - common_attn_metadata: CommonAttentionMetadata, - sampled_token_ids: torch.Tensor, - requests: dict[str, CachedRequestState], - gpu_input_batch: InputBatch, - discard_request_indices: torch.Tensor, - num_discarded_requests: int) -> \ - tuple[torch.Tensor, torch.Tensor]: - """ - This function is used to prepare the inputs for speculative decoding. - It calculates the next token ids and the number of valid sampled tokens - for each request, considering the "discarded" requests whose next token - is not sampled and comes from `request.get_token_id()` instead. - It also accounts for the rejected tokens in `sampled_token_ids`. - This function must use device functions to operate on the inputs, and - should not introduce any blocking CPU-GPU synchronization. 
- """ - # TODO(Ben): Combine this into a custom fused kernel - # Precompute get_token_id for when there is no valid next token - num_reqs = gpu_input_batch.num_reqs - self.backup_next_token_ids.np[:num_reqs] = np.array([ +def prepare_next_token_ids_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: torch.Tensor, + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + discard_request_indices: torch.Tensor, + num_discarded_requests: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids and the number of valid sampled tokens + for each request, considering the "discarded" requests whose next token + is not sampled and comes from `request.get_token_id()` instead. + It also accounts for the rejected tokens in `sampled_token_ids`. + This function must use device functions to operate on the inputs, and + should not introduce any blocking CPU-GPU synchronization. + """ + # TODO(Ben): Combine this into a custom fused kernel + + # Precompute get_token_id for when there is no valid next token + num_reqs = gpu_input_batch.num_reqs + self.backup_next_token_ids.np[:num_reqs] = np.array( + [ requests[gpu_input_batch.req_ids[i]].get_token_id( - common_attn_metadata.seq_lens_cpu[i].item()) + common_attn_metadata.seq_lens_cpu[i].item() + ) for i in range(num_reqs) - ]) - self.backup_next_token_ids.copy_to_gpu(num_reqs) + ] + ) + self.backup_next_token_ids.copy_to_gpu(num_reqs) - # Mask out the sampled tokens indices that should not be sampled. - discard_sampled_tokens_req_indices = \ - discard_request_indices[:num_discarded_requests] + # Mask out the sampled tokens indices that should not be sampled. 
+ discard_sampled_tokens_req_indices = discard_request_indices[ + :num_discarded_requests + ] - valid_sampled_token_ids_gpu = sampled_token_ids.clone() - # valid_sampled_token_ids_gpu.index_fill_( - # 0, discard_sampled_tokens_req_indices, -1) - # ---- FIX START ---- - # XPU/XMLIR index_fill_ does NOT accept empty index tensor. - if num_discarded_requests > 0: - # make sure index is on same device and is int64 - idx = discard_sampled_tokens_req_indices - if idx.device != valid_sampled_token_ids_gpu.device: - idx = idx.to(valid_sampled_token_ids_gpu.device, non_blocking=True) - if idx.dtype != torch.long: - idx = idx.to(torch.long) - if idx.numel() > 0: - valid_sampled_token_ids_gpu.index_fill_(0, idx, -1) - # ---- FIX END ---- - # Generate a mask for all valid tokens within those requests - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, - dtype=torch.bool) - else: - valid_mask = ( - (valid_sampled_token_ids_gpu != -1) & - (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)) + valid_sampled_token_ids_gpu = sampled_token_ids.clone() + # valid_sampled_token_ids_gpu.index_fill_( + # 0, discard_sampled_tokens_req_indices, -1) + # ---- FIX START ---- + # XPU/XMLIR index_fill_ does NOT accept empty index tensor. 
+ if num_discarded_requests > 0: + # make sure index is on same device and is int64 + idx = discard_sampled_tokens_req_indices + if idx.device != valid_sampled_token_ids_gpu.device: + idx = idx.to(valid_sampled_token_ids_gpu.device, non_blocking=True) + if idx.dtype != torch.long: + idx = idx.to(torch.long) + if idx.numel() > 0: + valid_sampled_token_ids_gpu.index_fill_(0, idx, -1) + # ---- FIX END ---- + # Generate a mask for all valid tokens within those requests + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool) + else: + valid_mask = (valid_sampled_token_ids_gpu != -1) & ( + valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size + ) - # Count the number of valid tokens in each request - valid_sampled_tokens_count = valid_mask.sum(dim=1) + # Count the number of valid tokens in each request + valid_sampled_tokens_count = valid_mask.sum(dim=1) - # Get the rightmost valid index per row - last_valid_indices = valid_sampled_tokens_count - 1 - last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) + # Get the rightmost valid index per row + last_valid_indices = valid_sampled_tokens_count - 1 + last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) - # Get last valid token from each row - # (assume undefined state where there is no valid token) - selected_tokens = torch.gather( - valid_sampled_token_ids_gpu, 1, - last_valid_indices_safe.unsqueeze(1)).squeeze(1) + # Get last valid token from each row + # (assume undefined state where there is no valid token) + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1) + ).squeeze(1) - # Use last token if valid, pre-computed backup if not - batch_size = valid_sampled_token_ids_gpu.shape[0] - next_token_ids = torch.where( - last_valid_indices != -1, selected_tokens, - self.backup_next_token_ids.gpu[:batch_size]) + # Use last token if valid, pre-computed backup if 
not + batch_size = valid_sampled_token_ids_gpu.shape[0] + next_token_ids = torch.where( + last_valid_indices != -1, + selected_tokens, + self.backup_next_token_ids.gpu[:batch_size], + ) + + return next_token_ids, valid_sampled_tokens_count - return next_token_ids, valid_sampled_tokens_count EagleProposer.propose = propose -EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded \ No newline at end of file +EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded