[Refactor]Refactor sampler (#2050)
Refactor Sampler implementation from patch way to inherit from vLLM
Sampler interface.
Next step: Make the op `TopKTopPSampler` in vLLM support custom ops
register mechanism
- vLLM version: v0.10.0
- vLLM main:
61a6905ab0
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
62
vllm_ascend/sample/sampler.py
Normal file
62
vllm_ascend/sample/sampler.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
|
||||
|
||||
class AscendSampler(Sampler):
|
||||
|
||||
def __init__(self, logprobs_mode="raw_logprobs"):
|
||||
# TODO: support logprobs_mode in vllm-ascend
|
||||
super().__init__(logprobs_mode=logprobs_mode)
|
||||
self.topk_topp_sampler = AscendTopKTopPSampler()
|
||||
|
||||
|
||||
class AscendTopKTopPSampler(TopKTopPSampler):
|
||||
|
||||
def _apply_top_k_top_p(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
p: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if p is not None and k is not None:
|
||||
# npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
|
||||
return torch_npu.npu_top_k_top_p(logits, p, k)
|
||||
|
||||
if p is None and k is None:
|
||||
return logits
|
||||
|
||||
probs = logits.softmax(dim=-1)
|
||||
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
||||
|
||||
if k is not None:
|
||||
top_k_count = probs_sort.size(1) - k.to(
|
||||
torch.long) # shape: (batch, )
|
||||
top_k_count = top_k_count.unsqueeze(dim=1)
|
||||
top_k_cutoff = probs_sort.gather(-1, top_k_count)
|
||||
|
||||
# Make sure the no top-k rows are no-op.
|
||||
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
|
||||
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
|
||||
|
||||
elements_to_discard = probs < top_k_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
if p is not None:
|
||||
cumprob = torch.cumsum(probs_sort, dim=-1)
|
||||
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
|
||||
top_p_mask[:, -1] = False # at least one
|
||||
|
||||
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
|
||||
top_p_cutoff = probs_sort.gather(-1, top_p_count)
|
||||
elements_to_discard = probs < top_p_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
return logits
|
||||
|
||||
def forward_native(self, logits, generators, k, p):
|
||||
"""Override pytorch native implementation to torch_npu"""
|
||||
logits = self._apply_top_k_top_p(logits, k, p)
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators)
|
||||
Reference in New Issue
Block a user