[CI] Fix broken CI (#6599)
Revert 4fb3d5e1b2: it breaks E2E Test.

- vLLM version: v0.15.0
- vLLM main: d7e17aaacd
@@ -1,11 +1,13 @@
 import torch
 from vllm.config import CUDAGraphMode
-from vllm.v1.spec_decode.ngram_proposer import NgramProposer as VllmNgramProposer
+from vllm.v1.spec_decode.ngram_proposer import \
+    NgramProposer as VllmNgramProposer

 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType


 class NgramProposer(VllmNgramProposer, Proposer):
+
     def __init__(self, vllm_config, device, runner):
         super().__init__(vllm_config)
         self.name = SpecDcodeType.NGRAM
@@ -17,31 +19,27 @@ class NgramProposer(VllmNgramProposer, Proposer):
         pass

     @torch.inference_mode()
-    def dummy_run(
-        self,
-        num_tokens,
-        with_prefill=None,
-        in_graph_capturing=None,
-        num_reqs=None,
-        num_tokens_across_dp=None,
-        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor=None,
-        dummy_compute_logits=lambda hidden_states: None,
-        is_profile=False,
-    ):
+    def dummy_run(self,
+                  num_tokens,
+                  with_prefill=None,
+                  in_graph_capturing=None,
+                  num_reqs=None,
+                  num_tokens_across_dp=None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None,
+                  is_profile=False):
         pass

-    def generate_token_ids(
-        self,
-        valid_sampled_token_ids,
-        sampling_metadata=None,
-        scheduler_output=None,
-        spec_decode_metadata=None,
-        positions=None,
-        num_scheduled_tokens=None,
-        hidden_states=None,
-        aux_hidden_states=None,
-    ) -> list[list[int]]:
+    def generate_token_ids(self,
+                           valid_sampled_token_ids,
+                           sampling_metadata=None,
+                           scheduler_output=None,
+                           spec_decode_metadata=None,
+                           positions=None,
+                           num_scheduled_tokens=None,
+                           hidden_states=None,
+                           aux_hidden_states=None) -> list[list[int]]:
         valid_ngram_requests = []
         for i, sampled_ids in enumerate(valid_sampled_token_ids):
             num_sampled_ids = len(sampled_ids)
@@ -59,7 +57,8 @@ class NgramProposer(VllmNgramProposer, Proposer):

             start_idx = self.runner.input_batch.num_tokens_no_spec[i]
             end_idx = start_idx + num_sampled_ids
-            self.runner.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
+            self.runner.input_batch.token_ids_cpu[
+                i, start_idx:end_idx] = sampled_ids

             valid_ngram_requests.append(i)
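For reference, the class touched by this revert wraps vLLM's n-gram proposer for speculative decoding: it matches the most recent tokens of a request against earlier context and proposes the tokens that followed a previous occurrence as draft tokens. Below is a minimal, standalone sketch of that matching idea only; it is not the vllm or vllm-ascend implementation, and the function name and parameters (propose_ngram_draft, n, k) are hypothetical.

# Illustrative sketch of n-gram draft proposal; not the vllm-ascend code.
# All names here (propose_ngram_draft, history, n, k) are hypothetical.
def propose_ngram_draft(history: list[int], n: int = 4, k: int = 2) -> list[int]:
    """If the last n tokens already occurred earlier in `history`,
    propose the k tokens that followed that occurrence as draft tokens."""
    if len(history) <= n:
        return []
    suffix = history[-n:]
    # Scan right to left so the most recent earlier match wins.
    for start in range(len(history) - n - 1, -1, -1):
        if history[start:start + n] == suffix:
            return history[start + n:start + n + k]
    return []


# The 4-gram [1, 2, 3, 4] occurred earlier and was followed by [9, 9],
# so [9, 9] is proposed as the draft continuation.
print(propose_ngram_draft([1, 2, 3, 4, 9, 9, 1, 2, 3, 4]))  # -> [9, 9]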