diff --git a/vllm_ascend/_310p/model_runner_310p.py b/vllm_ascend/_310p/model_runner_310p.py
index 5c09c0e7..b16c53e0 100644
--- a/vllm_ascend/_310p/model_runner_310p.py
+++ b/vllm_ascend/_310p/model_runner_310p.py
@@ -35,7 +35,9 @@ from vllm.v1.kv_cache_interface import (
     MambaSpec,
     UniformTypeKVCacheSpecs,
 )
+from vllm.v1.sample.rejection_sampler import RejectionSampler
 
+from vllm_ascend._310p.sample.sampler import AscendSampler310
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -52,6 +54,9 @@ class NPUModelRunner310(NPUModelRunner):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._acl_format = ACL_FORMAT_FRACTAL_NZ
+        self.sampler = AscendSampler310()
+        if getattr(self, "rejection_sampler", None) is not None:
+            self.rejection_sampler = RejectionSampler(self.sampler)
         if self.speculative_config is not None and self.speculative_config.method == "ngram":
             # 310P ngram requires decode-only graph shapes to be built with q_len=1.
             # Keep dispatcher's internal query_len in sync to avoid key-init assert.
diff --git a/vllm_ascend/_310p/sample/__init__.py b/vllm_ascend/_310p/sample/__init__.py
new file mode 100644
index 00000000..f3766c05
--- /dev/null
+++ b/vllm_ascend/_310p/sample/__init__.py
@@ -0,0 +1,3 @@
+from vllm_ascend._310p.sample.sampler import AscendSampler310
+
+__all__ = ["AscendSampler310"]
diff --git a/vllm_ascend/_310p/sample/sampler.py b/vllm_ascend/_310p/sample/sampler.py
new file mode 100644
index 00000000..7e741dda
--- /dev/null
+++ b/vllm_ascend/_310p/sample/sampler.py
@@ -0,0 +1,66 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+import torch
+from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+
+from vllm_ascend.sample.sampler import (
+    DEFAULT_LOGPROBS_MODE,
+    AscendSampler,
+    AscendTopKTopPSampler,
+)
+from vllm_ascend.utils import global_stream, npu_stream_switch
+
+
+def _random_sample_310p(
+    probs: torch.Tensor,
+    generators: dict[int, torch.Generator],
+) -> torch.Tensor:
+    """310P-specific random sampling with CPU exponential generation for q."""
+    with npu_stream_switch(global_stream()):
+        q = torch.empty_like(probs)
+        q = q.cpu()
+        if len(generators) != q.shape[0]:
+            q.exponential_()
+        if generators:
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
+        q = q.npu()
+    torch.npu.current_stream().wait_stream(global_stream())
+    return probs.div_(q).argmax(dim=-1).view(-1)
+
+
+class AscendTopKTopPSampler310(AscendTopKTopPSampler):
+    def forward_native(self, logits, generators, k, p):
+        if vllm_is_batch_invariant():
+            return super().forward_native(logits, generators, k, p)
+
+        logits = self.apply_top_k_top_p(logits, k, p)
+        logits_to_return = None
+        if self.logprobs_mode == "processed_logits":
+            logits_to_return = logits
+        elif self.logprobs_mode == "processed_logprobs":
+            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
+
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        return _random_sample_310p(probs, generators), logits_to_return
+
+
+class AscendSampler310(AscendSampler):
+    def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):
+        super().__init__(logprobs_mode=logprobs_mode)
+        self.topk_topp_sampler = AscendTopKTopPSampler310(logprobs_mode=logprobs_mode)