[CI] Fix FusedMoEConfig and input batch failure to recover CI (#1602)

Make CI happy

1.
c1909e7e8c
changed moeConfig init way
2.
48fb076cbc
changed input batch logic.

This PR address these change to vllm-ascend.

Closes: https://github.com/vllm-project/vllm-ascend/issues/1600

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-07-03 18:36:17 +08:00
committed by GitHub
parent d96da1f00c
commit a45dfde283
11 changed files with 173 additions and 134 deletions

View File

@@ -33,6 +33,10 @@ from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm_ascend.pool.metadata import PoolingMetadata
from vllm_ascend.utils import vllm_version_is
if not vllm_version_is("0.9.1"):
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
_SAMPLING_EPS = 1e-5
@@ -83,7 +87,9 @@ class InputBatch:
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logits_processing_needs_token_ids: bool = False,
is_spec_decode: bool = False,
):
self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
@@ -161,6 +167,9 @@ class InputBatch:
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: set[str] = set()
# IDs of requests which do not support spec decoding
self.spec_decode_unsupported_reqs: set[str] = set()
self.min_p = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
@@ -244,6 +253,18 @@ class InputBatch:
self.req_output_token_ids: list[Optional[list[int]]] = []
if not vllm_version_is("0.9.1"):
from vllm.v1.sample.logits_processor import \
init_builtin_logitsprocs
# Define logits processors.
# TODO(andy): logits processor list should be extensible via engine
# constructor argument; for now the list is fixed.
self.logitsprocs = init_builtin_logitsprocs(
pin_memory_available=pin_memory,
max_num_reqs=max_num_reqs + 1,
device=device)
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
@@ -293,6 +314,9 @@ class InputBatch:
self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params:
if (self.is_spec_decode
and is_spec_decode_unsupported(sampling_params)):
self.spec_decode_unsupported_reqs.add(req_id)
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
@@ -401,6 +425,7 @@ class InputBatch:
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.spec_decode_unsupported_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
@@ -616,26 +641,48 @@ class InputBatch:
self.allowed_token_ids_mask, num_reqs)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
return SamplingMetadata(
temperature=temperature,
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs],
top_k=None if self.no_top_k else self.top_k[:num_reqs],
min_p=None if self.no_min_p else self.min_p[:num_reqs],
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids,
frequency_penalties=self.frequency_penalties[:num_reqs],
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
min_tokens=self.min_tokens,
no_penalties=self.no_penalties,
logit_bias=self.logit_bias[:num_reqs],
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,
)
if vllm_version_is("0.9.1"):
return SamplingMetadata(
temperature=temperature,
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs],
top_k=None if self.no_top_k else self.top_k[:num_reqs],
min_p=None if self.no_min_p else self.min_p[:num_reqs],
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids,
frequency_penalties=self.frequency_penalties[:num_reqs],
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(list[list[int]],
self.req_output_token_ids),
min_tokens=self.min_tokens,
no_penalties=self.no_penalties,
logit_bias=self.logit_bias[:num_reqs],
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,
)
else:
return SamplingMetadata(
temperature=temperature,
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs],
top_k=None if self.no_top_k else self.top_k[:num_reqs],
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids,
frequency_penalties=self.frequency_penalties[:num_reqs],
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(list[list[int]],
self.req_output_token_ids),
no_penalties=self.no_penalties,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,
logitsprocs=self.logitsprocs,
)
@property
def pooling_metadata(self) -> PoolingMetadata: