diff --git a/vllm_ascend/xlite/xlite.py b/vllm_ascend/xlite/xlite.py
index cd7423c9..ac3b1f9c 100644
--- a/vllm_ascend/xlite/xlite.py
+++ b/vllm_ascend/xlite/xlite.py
@@ -25,9 +25,9 @@ from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 from vllm.sequence import IntermediateTensors
 from xlite._C import (  # type: ignore[attr-defined]
+    AttnMeta,
     AttnMHA,
     Model,
-    ModelAttnMeta,
     ModelConfig,
     Runtime,
     ScoringFuncSoftmax,
@@ -89,6 +89,12 @@ class LlamaXliteModel(XliteModel):
         config.max_batch_size = max_batch_size
         config.max_seq_len = max_seq_len
         config.block_size = vllm_config.cache_config.block_size
+
+        vision_config = getattr(vllm_config.model_config.hf_config, "vision_config", None)
+        rope_parameters = getattr(hf_config, "rope_parameters", {})
+        config.deepstack_num_level = len(getattr(vision_config, "deepstack_visual_indexes", []))
+        config.mrope_section = rope_parameters.get("mrope_section", [])
+        config.mrope_interleaved = rope_parameters.get("mrope_interleaved", False)
         return config
 
     def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig, config: ModelConfig) -> Model:
@@ -297,15 +303,19 @@ class XliteWrapper:
             query_lens = query_lens[:batch]
             cached_lens = seq_lens - query_lens
 
-            xlite_attn_metadata = ModelAttnMeta()
+            num_tokens = forward_context.batch_descriptor.num_tokens
+            num_actual_tokens = attn_metadata.num_actual_tokens
+            xlite_attn_metadata = AttnMeta()
             xlite_attn_metadata.lens = query_lens.tolist()
             xlite_attn_metadata.cached_lens = cached_lens.tolist()
             xlite_attn_metadata.is_prefills = [False] * num_decodes + [True] * num_prefills
-            xlite_attn_metadata.block_tables = attn_metadata.block_tables.cpu().tolist()
+            xlite_attn_metadata.block_tables_cpu = attn_metadata.block_tables.cpu().tolist()
+            if positions.ndim == 2:
+                xlite_attn_metadata.positions = positions[:, : attn_metadata.num_actual_tokens].contiguous()
+            else:
+                xlite_attn_metadata.positions = positions
 
             # Compatibility between DP and Non-DP scenarios
-            num_tokens = forward_context.batch_descriptor.num_tokens
-            num_actual_tokens = attn_metadata.num_actual_tokens
             h = self.hidden_states[:num_tokens]
             stream = torch.npu.current_stream().npu_stream
             if inputs_embeds is None:
@@ -313,9 +323,22 @@ class XliteWrapper:
                     self.xlite_rt, input_ids, xlite_attn_metadata, self.kv_caches, self.freq_cis, h, stream
                 )
             else:
+                deepstack_input_embeds = getattr(self.runnable, "deepstack_input_embeds", [])
+                xlite_deepstack_input_embeds = [
+                    deepstack_input[: inputs_embeds.size(0)] for deepstack_input in deepstack_input_embeds
+                ]
                 self.xlite_model.forward_with_inputs_embeds(
-                    self.xlite_rt, inputs_embeds, xlite_attn_metadata, self.kv_caches, self.freq_cis, h, stream
+                    self.xlite_rt,
+                    inputs_embeds,
+                    xlite_attn_metadata,
+                    self.kv_caches,
+                    self.freq_cis,
+                    h,
+                    stream,
+                    xlite_deepstack_input_embeds,
                 )
+                if xlite_deepstack_input_embeds and hasattr(self.runnable, "_clear_deepstack_input_embeds"):
+                    self.runnable._clear_deepstack_input_embeds(inputs_embeds.size(0))
                 return h[:num_actual_tokens]
         else:
             return self.runnable(input_ids, positions, intermediate_tensors, inputs_embeds)
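
For context on the new `positions` handling in the hunk above, here is a minimal, self-contained sketch. It assumes the 2-D case is the M-RoPE layout `(3, num_padded_tokens)` with one row per temporal/height/width section; `SketchAttnMeta` is an illustrative stand-in, not xlite's real `AttnMeta` interface.

```python
import torch


class SketchAttnMeta:
    """Illustrative stand-in for xlite's AttnMeta; only the field exercised here."""

    positions: torch.Tensor


def set_positions(meta: SketchAttnMeta, positions: torch.Tensor, num_actual_tokens: int) -> None:
    if positions.ndim == 2:
        # Assumed M-RoPE layout: (3, num_padded_tokens). Drop the padded columns
        # and make the slice contiguous before handing it to the backend.
        meta.positions = positions[:, :num_actual_tokens].contiguous()
    else:
        # Plain RoPE: a 1-D per-token position index, passed through unchanged.
        meta.positions = positions


meta = SketchAttnMeta()
mrope_positions = torch.arange(3 * 8).reshape(3, 8)  # 8 padded token slots
set_positions(meta, mrope_positions, num_actual_tokens=5)
assert meta.positions.shape == (3, 5)
```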