[ModelRunner] apply_grammar uses vllm function (#4974)
### What this PR does / why we need it?
This PR removes apply_grammar_bitmask from npu_model_runner. We move the logits to
CPU and then do the same thing as gpu_model_runner, reusing vLLM's shared helper.
This may affect performance; we will revisit it once torch.compile is
supported with the NPU inductor.
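
The flow this PR switches to is mirrored in the diff below. A minimal sketch of that pattern, assuming the runner attributes shown in the diff (`self.device`, `self.input_batch`); the wrapper name `_apply_bitmask_on_cpu` is illustrative, not part of the change:

```python
import torch
from vllm.v1.structured_output.utils import apply_grammar_bitmask


def _apply_bitmask_on_cpu(self, scheduler_output, grammar_output,
                          logits: torch.Tensor) -> torch.Tensor:
    # XGrammar bitmask application only supports CPU/GPU and needs float32
    # logits on CPU, so round-trip the logits instead of running on the NPU.
    logits_dtype = logits.dtype
    cpu_logits = logits.to("cpu").float()
    # Reuse vLLM's shared helper (it modifies the logits in place).
    apply_grammar_bitmask(scheduler_output, grammar_output,
                          self.input_batch, cpu_logits)
    # Move the masked logits back to the NPU in the original dtype.
    return cpu_logits.to(self.device).to(logits_dtype)
```

Compared with the removed NPU-local apply_grammar_bitmask, the bitmask reordering now lives entirely in vLLM's helper; the NPU runner only handles the device and dtype round trip.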
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
@@ -77,6 +77,7 @@ from vllm.v1.sample.rejection_sampler import RejectionSampler
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
+from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from vllm.v1.worker.gpu_model_runner import (AsyncGPUModelRunnerOutput,
                                              GPUModelRunner)
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
@@ -1626,70 +1627,6 @@ class NPUModelRunner(GPUModelRunner):
         )
         return None
 
-    def apply_grammar_bitmask(
-        self,
-        scheduler_output: "SchedulerOutput",
-        grammar_output: "GrammarOutput",
-        logits: torch.Tensor,
-    ) -> torch.Tensor:
-        grammar_bitmask = grammar_output.grammar_bitmask
-
-        # We receive the structured output bitmask from the scheduler,
-        # compacted to contain bitmasks only for structured output requests.
-        # The order of the requests in the bitmask is not guaranteed to be the
-        # same as the order of the requests in the gpu runner's batch. We need
-        # to sort the bitmask to match the order of the requests used here.
-
-        # Get the batch indices of the structured output requests.
-        # Keep track of the number of speculative tokens scheduled for every
-        # request in the batch, as the logit indices are offset by this amount.
-        struct_out_req_batch_indices: dict[str, int] = {}
-        cumulative_offset = 0
-        seq = sorted(self.input_batch.req_id_to_index.items(),
-                     key=lambda x: x[1])
-        for req_id, batch_index in seq:
-            logit_index = batch_index + cumulative_offset
-            cumulative_offset += len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            if req_id in grammar_output.structured_output_request_ids:
-                struct_out_req_batch_indices[req_id] = logit_index
-
-        out_indices = []
-
-        # Reorder the bitmask to match the order of the requests in the batch.
-        sorted_bitmask = np.zeros_like(grammar_bitmask,
-                                       shape=(logits.shape[0],
-                                              grammar_bitmask.shape[1]))
-        cumulative_index = 0
-        for req_id in grammar_output.structured_output_request_ids:
-            num_spec_tokens = len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            if req_id in struct_out_req_batch_indices:
-                logit_index = struct_out_req_batch_indices[req_id]
-                for i in range(1 + num_spec_tokens):
-                    sorted_bitmask[logit_index +
-                                   i] = grammar_bitmask[cumulative_index + i]
-                    out_indices.append(logit_index + i)
-            cumulative_index += 1 + num_spec_tokens
-        grammar_bitmask = sorted_bitmask
-
-        # Serialization of np.ndarray is much more efficient than a tensor,
-        # so we receive it in that format.
-        grammar_bitmask = torch.from_numpy(grammar_bitmask)
-
-        # NOTE:
-        # 1. XGrammar bitmask applying only supports CPU and GPU.
-        # 2. The logits and bitmask should be on the same device.
-        # 3. XGrammar logits on CPU only supports float32 dtype.
-        logits_dtype = logits.dtype
-        logits = logits.to("cpu").float()
-        xgr.apply_token_bitmask_inplace(
-            logits,
-            grammar_bitmask,
-            indices=out_indices,
-        )
-        return logits.to(self.device).to(logits_dtype)
-
     @torch.inference_mode
     def sample_tokens(
         self, grammar_output: "GrammarOutput | None"
@@ -1715,8 +1652,13 @@ class NPUModelRunner(GPUModelRunner):
 
         # Apply structured output bitmasks if present.
         if grammar_output is not None:
-            logits = self.apply_grammar_bitmask(scheduler_output,
-                                                grammar_output, logits)
+            # Here we differ from gpu_model_runner: its apply_grammar_bitmask
+            # uses torch.compile to optimize this step, which Ascend does not support yet.
+            logits_dtype = logits.dtype
+            logits = logits.to("cpu").float()
+            apply_grammar_bitmask(scheduler_output, grammar_output,
+                                  self.input_batch, logits)
+            logits = logits.to(self.device).to(logits_dtype)
 
         with ProfileExecuteDuration().capture_async("Sample"):
             sampler_output = self._sample(logits, spec_decode_metadata)