[CI] Upgrade vLLM to 20250920 (c60e613) and address config break (#3067)

### What this PR does / why we need it?
Bump main to
c60e6137f0

- Updated imports in `vllm.config` to
`vllm.config.model`(aed16879a9)
https://github.com/vllm-project/vllm/pull/25252

- Refactored `vllm_ascend/sample/sampler.py` to use string values for
`logprobs_mode` instead of the `LogprobsMode` enum, simplifying logprobs
mode handling and improving compatibility with recent vLLM changes
(aed16879a9)
https://github.com/vllm-project/vllm/pull/25252

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed


- vLLM version: v0.10.2
- vLLM main:
6d8246aaff

---------

Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
This commit is contained in:
Yikun Jiang
2025-09-21 09:49:17 +08:00
committed by GitHub
parent 12bcbd02bb
commit b8b68b3dfe
5 changed files with 32 additions and 14 deletions

View File

@@ -1,12 +1,15 @@
import torch
import torch_npu
from vllm.config import LogprobsMode
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
from vllm.v1.sample.sampler import Sampler
from vllm_ascend.utils import is_310p
from vllm_ascend.utils import is_310p, vllm_version_is
DEFAULT_LOGPROBS_MODE = LogprobsMode.RAW_LOGPROBS
if vllm_version_is("0.10.2"):
from vllm.config import LogprobsMode
DEFAULT_LOGPROBS_MODE = LogprobsMode.RAW_LOGPROBS
else:
DEFAULT_LOGPROBS_MODE = "raw_logprobs"
class AscendSampler(Sampler):
@@ -65,10 +68,18 @@ class AscendTopKTopPSampler(TopKTopPSampler):
"""Override pytorch native implementation to torch_npu"""
logits = self._apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
logits_to_return = logits
elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
if vllm_version_is("0.10.2"):
if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
logits_to_return = logits
elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
logits_to_return = logits.log_softmax(dim=-1,
dtype=torch.float32)
else:
if self.logprobs_mode == "processed_logits":
logits_to_return = logits
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1,
dtype=torch.float32)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators), logits_to_return