[xlite][Bugfix] Support mrope and deepstack features in xlite backend (#7295)
### What this PR does / why we need it?
This PR fixes a bug in the xlite backend (https://atomgit.com/openeuler/GVirt/issues/3).
It adds support for the mrope (Multimodal RoPE) and deepstack features in the xlite backend; these features are required to run multimodal models, such as Qwen3-VL, that rely on them.
The main changes include:
- Updating `_build_model_config` to parse the mrope and deepstack
configurations from the model's `hf_config` (a short sketch of this parsing follows the list).
- Modifying `XliteWrapper.__call__` to handle `deepstack_input_embeds`
and mrope positions during the model forward pass.
- Replacing `ModelAttnMeta` with the newer `AttnMeta` to accommodate the
new metadata fields required by these features.
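For context, here is a minimal sketch of that config parsing, assuming an `hf_config` shaped like Qwen3-VL's. The helper name `parse_mm_rope_config` and the example values are hypothetical; the attribute names mirror the diff at the end of this page.
```python
# Hypothetical sketch, not the PR's code: derive deepstack/mrope settings from hf_config.
from types import SimpleNamespace

def parse_mm_rope_config(hf_config):
    """Return the deepstack/mrope fields that _build_model_config copies into the xlite ModelConfig."""
    vision_config = getattr(hf_config, "vision_config", None)
    rope_parameters = getattr(hf_config, "rope_parameters", {}) or {}
    return {
        # One deepstack level per vision layer whose features are injected into the LLM.
        "deepstack_num_level": len(getattr(vision_config, "deepstack_visual_indexes", [])),
        # Per-axis split of rotary dimensions used by multimodal RoPE.
        "mrope_section": rope_parameters.get("mrope_section", []),
        "mrope_interleaved": rope_parameters.get("mrope_interleaved", False),
    }

# Illustrative values only (not the real Qwen3-VL numbers).
hf_config = SimpleNamespace(
    vision_config=SimpleNamespace(deepstack_visual_indexes=[8, 16, 24]),
    rope_parameters={"mrope_section": [24, 20, 20], "mrope_interleaved": True},
)
print(parse_mm_rope_config(hf_config))
# -> {'deepstack_num_level': 3, 'mrope_section': [24, 20, 20], 'mrope_interleaved': True}
```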
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
Online server config:
```
python -m vllm.entrypoints.openai.api_server \
--model /mnt/nvme0n1/models/checkpoint-8200 \
--additional-config='{"xlite_graph_config": {"enabled": true}}' \
--tensor-parallel-size 4 \
--gpu-memory-utilization 0.9 \
--max-num-batched-tokens 8192 \
--max-num-seqs=20 \
--block-size 128 \
--max-model-len 8192 \
--trust-remote-code \
--served-model-name Qwen3-VL-8B \
--host localhost \
--generation-config vllm \
--port 6777
```
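For a quick sanity check of the server launched above, a minimal request can be sent with the `openai` Python client. This is a hypothetical example, not part of the PR: the model name and port are taken from the command above, and the prompt is arbitrary.
```python
# Assumes the `openai` package is installed and the server is reachable on localhost:6777.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:6777/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="Qwen3-VL-8B",  # matches --served-model-name above
    messages=[{"role": "user", "content": "Describe what mrope is in one sentence."}],
    max_tokens=64,
    temperature=0.6,
)
print(resp.choices[0].message.content)
```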
Test config:
```
vllm bench serve \
--max-concurrency ${maxconcurrency} \
--num-prompts ${num_prompts} \
--host ${HOST} \
--port ${PORT} \
--model ${MODEL_NAME} \
--dataset-name random \
--backend openai-chat \
--random-input-len 512 \
--random-output-len 512 \
--random-range-ratio 0.2 \
--temperature 0.6 \
--metric-percentiles "50,90,99" \
--tokenizer ${TOKENIZER_PATH} \
--endpoint /v1/chat/completions \
--ignore-eos
```
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
Signed-off-by: LVYANGGUO <lvyangguo@huawei.com>
Co-authored-by: LVYANGGUO <lvyangguo@huawei.com>
Diff:
```
@@ -25,9 +25,9 @@ from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 from vllm.sequence import IntermediateTensors
 from xlite._C import (  # type: ignore[attr-defined]
+    AttnMeta,
     AttnMHA,
     Model,
-    ModelAttnMeta,
     ModelConfig,
     Runtime,
     ScoringFuncSoftmax,
@@ -89,6 +89,12 @@ class LlamaXliteModel(XliteModel):
         config.max_batch_size = max_batch_size
         config.max_seq_len = max_seq_len
         config.block_size = vllm_config.cache_config.block_size
+
+        vision_config = getattr(vllm_config.model_config.hf_config, "vision_config", None)
+        rope_parameters = getattr(hf_config, "rope_parameters", {})
+        config.deepstack_num_level = len(getattr(vision_config, "deepstack_visual_indexes", []))
+        config.mrope_section = rope_parameters.get("mrope_section", [])
+        config.mrope_interleaved = rope_parameters.get("mrope_interleaved", False)
         return config

     def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig, config: ModelConfig) -> Model:
@@ -297,15 +303,19 @@ class XliteWrapper:
             query_lens = query_lens[:batch]
             cached_lens = seq_lens - query_lens

-            xlite_attn_metadata = ModelAttnMeta()
+            num_tokens = forward_context.batch_descriptor.num_tokens
+            num_actual_tokens = attn_metadata.num_actual_tokens
+            xlite_attn_metadata = AttnMeta()
             xlite_attn_metadata.lens = query_lens.tolist()
             xlite_attn_metadata.cached_lens = cached_lens.tolist()
             xlite_attn_metadata.is_prefills = [False] * num_decodes + [True] * num_prefills
-            xlite_attn_metadata.block_tables = attn_metadata.block_tables.cpu().tolist()
+            xlite_attn_metadata.block_tables_cpu = attn_metadata.block_tables.cpu().tolist()
+            if positions.ndim == 2:
+                xlite_attn_metadata.positions = positions[:, : attn_metadata.num_actual_tokens].contiguous()
+            else:
+                xlite_attn_metadata.positions = positions

             # Compatibility between DP and Non-DP scenarios
-            num_tokens = forward_context.batch_descriptor.num_tokens
-            num_actual_tokens = attn_metadata.num_actual_tokens
             h = self.hidden_states[:num_tokens]
             stream = torch.npu.current_stream().npu_stream
             if inputs_embeds is None:
@@ -313,9 +323,22 @@ class XliteWrapper:
                     self.xlite_rt, input_ids, xlite_attn_metadata, self.kv_caches, self.freq_cis, h, stream
                 )
             else:
+                deepstack_input_embeds = getattr(self.runnable, "deepstack_input_embeds", [])
+                xlite_deepstack_input_embeds = [
+                    deepstack_input[: inputs_embeds.size(0)] for deepstack_input in deepstack_input_embeds
+                ]
                 self.xlite_model.forward_with_inputs_embeds(
-                    self.xlite_rt, inputs_embeds, xlite_attn_metadata, self.kv_caches, self.freq_cis, h, stream
+                    self.xlite_rt,
+                    inputs_embeds,
+                    xlite_attn_metadata,
+                    self.kv_caches,
+                    self.freq_cis,
+                    h,
+                    stream,
+                    xlite_deepstack_input_embeds,
                 )
+                if xlite_deepstack_input_embeds and hasattr(self.runnable, "_clear_deepstack_input_embeds"):
+                    self.runnable._clear_deepstack_input_embeds(inputs_embeds.size(0))
             return h[:num_actual_tokens]
         else:
             return self.runnable(input_ids, positions, intermediate_tensors, inputs_embeds)
```
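One note on the `positions.ndim == 2` branch in the diff above: with mrope, positions arrive as a 2D tensor (one row per position component) rather than the usual 1D tensor, so only the token dimension is sliced down to the actual token count. Below is a small, hedged illustration in plain PyTorch; the shapes are made up, and the real path runs on NPU inside the wrapper.
```python
# Illustrative only: mirrors the positions handling in XliteWrapper.__call__ above.
import torch

num_tokens, num_actual_tokens = 8, 5                    # padded vs. actual token counts (made up)
positions_1d = torch.arange(num_tokens)                  # classic RoPE: [num_tokens]
positions_2d = torch.arange(num_tokens).repeat(3, 1)     # mrope: [3, num_tokens], one row per component

for positions in (positions_1d, positions_2d):
    if positions.ndim == 2:
        # mrope: keep all components, slice the token dimension, make it contiguous.
        out = positions[:, :num_actual_tokens].contiguous()
    else:
        # 1D positions are passed through unchanged, as in the diff.
        out = positions
    print(tuple(positions.shape), "->", tuple(out.shape))
# (8,) -> (8,)
# (3, 8) -> (3, 5)
```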