diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 2eae492..7c04feb 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -72,7 +72,7 @@ class EagleProposer(Proposer):
                                    dtype=torch.int32)
         attn_mask_len = self.vllm_config.model_config.max_model_len
         self.attn_mask_builder = AttentionMaskBuilder(
-            attn_mask_len, self.vllm_config.model_config.dtype)
+            attn_mask_len, self.vllm_config.model_config.dtype, device=device)
 
     def load_model(self, model: nn.Module) -> None:
         target_attn_layer_names = set(
@@ -424,9 +424,7 @@ class EagleProposer(Proposer):
         query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
         max_query_len = query_lens.max().item()
 
-        attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
-            seq_lens, target_positions, self.vllm_config.model_config.dtype,
-            self.device)
+        attn_mask = self.runner.attn_mask
 
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=cu_num_tokens.to(device),
@@ -506,9 +504,21 @@ class EagleProposer(Proposer):
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
+        attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[
+            1:].tolist()
+        # Each of the batch_size requests is counted as a one-token prefill.
+        attn_metadata.num_decodes = 0
+        attn_metadata.num_prefills = batch_size
+        attn_metadata.num_decode_tokens = 0
+        attn_metadata.num_prefill_tokens = batch_size
+        attn_metadata.num_actual_tokens_pcp_padded = (
+            attn_metadata.num_decode_tokens +
+            attn_metadata.num_prefill_tokens)
         query_lens.fill_(1)
         attn_metadata.query_lens = query_lens
 
+        attn_metadata.actual_seq_lengths_q = list(range(1, batch_size + 1))
+        attn_metadata.seq_lens_list = seq_lens.tolist()
         attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
         for now_speculative in range(
                 self.vllm_config.speculative_config.num_speculative_tokens -
@@ -535,6 +545,10 @@ class EagleProposer(Proposer):
 
             # TODO: Increment the sequence lengths.
             attn_metadata.seq_lens += 1
+            # Bump the list form by one as well, mirroring the `+= 1` above.
+            attn_metadata.seq_lens_list = [
+                seq_len + 1 for seq_len in attn_metadata.seq_lens_list
+            ]
             # TODO: Consider max model length.
             # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
             #                                 self.max_model_len)
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index d19453e..0d128eb 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -61,6 +61,7 @@ _IS_VL_MODEL = None
 _ENABLE_SP = None
 _HAS_LAYER_IDX = None
 _ENABLE_NZ = None
+_IS_EAGLE_MODE = None
 
 
 def is_310p():
@@ -73,12 +74,18 @@ def is_310p():
 
 def is_enable_nz(dtype: Optional[torch.dtype] = torch.int8,
                  vllm_config: Optional[VllmConfig] = None) -> bool:
-    global _ENABLE_NZ
+    global _ENABLE_NZ, _IS_EAGLE_MODE
     if _ENABLE_NZ is None:
         if not vllm_config:
             raise ValueError(
                 "vllm_config must be provided when _ENABLE_NZ is None")
         _ENABLE_NZ = envs_ascend.VLLM_ASCEND_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next"
+        # Cached together with _ENABLE_NZ: for float16/bfloat16 the function
+        # only returns _ENABLE_NZ when the speculative method is eagle/eagle3.
+        spec_config = vllm_config.speculative_config
+        _IS_EAGLE_MODE = (spec_config is not None
+                          and getattr(spec_config, "method", None)
+                          in ("eagle", "eagle3"))
     if dtype in [torch.float16, torch.bfloat16]:
-        return False
+        return _ENABLE_NZ if _IS_EAGLE_MODE else False
     return _ENABLE_NZ