Support v0.10.1 (#2584)
### What this PR does / why we need it?
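This patch adds support for vLLM v0.10.1 in addition to v0.10.1.1: the version gates that previously checked only `vllm_version_is("0.10.1.1")` now also accept `vllm_version_is("0.10.1")`, so both releases take the same compatibility code paths.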
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- CI passed
- test 0.10.1: https://github.com/vllm-project/vllm-ascend/pull/2583
- vLLM version: v0.10.1.1
- vLLM main: 321938e9ac
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
This commit is contained in:
@@ -21,7 +21,7 @@ from tests.ut.base import TestBase
|
|||||||
from vllm_ascend.core.scheduler import AscendScheduler
|
from vllm_ascend.core.scheduler import AscendScheduler
|
||||||
from vllm_ascend.utils import vllm_version_is
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
if not vllm_version_is("0.10.1.1"):
|
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
|
||||||
from vllm.v1.outputs import DraftTokenIds
|
from vllm.v1.outputs import DraftTokenIds
|
||||||
else:
|
else:
|
||||||
DraftTokenIds = None
|
DraftTokenIds = None
|
||||||
@@ -78,7 +78,7 @@ def make_output(scheduler):
|
|||||||
}
|
}
|
||||||
sampled_token_ids = [[1000]] * len(scheduler.running)
|
sampled_token_ids = [[1000]] * len(scheduler.running)
|
||||||
logprobs = None
|
logprobs = None
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
modelrunner_output = ModelRunnerOutput(
|
modelrunner_output = ModelRunnerOutput(
|
||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
req_id_to_index=req_id_to_index,
|
req_id_to_index=req_id_to_index,
|
||||||
@@ -297,7 +297,7 @@ class TestAscendScheduler(TestBase):
|
|||||||
scheduler.running.append(req)
|
scheduler.running.append(req)
|
||||||
req.status = RequestStatus.RUNNING
|
req.status = RequestStatus.RUNNING
|
||||||
|
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
scheduler_output = SchedulerOutput(
|
scheduler_output = SchedulerOutput(
|
||||||
scheduled_new_reqs=[],
|
scheduled_new_reqs=[],
|
||||||
scheduled_cached_reqs=[],
|
scheduled_cached_reqs=[],
|
||||||
@@ -384,7 +384,7 @@ class TestAscendScheduler(TestBase):
|
|||||||
scheduler.running.append(req)
|
scheduler.running.append(req)
|
||||||
req.status = RequestStatus.RUNNING
|
req.status = RequestStatus.RUNNING
|
||||||
|
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
scheduler_output = SchedulerOutput(
|
scheduler_output = SchedulerOutput(
|
||||||
scheduled_new_reqs=[],
|
scheduled_new_reqs=[],
|
||||||
scheduled_cached_reqs=[],
|
scheduled_cached_reqs=[],
|
||||||
@@ -468,7 +468,7 @@ class TestAscendScheduler(TestBase):
|
|||||||
scheduler.running.append(req)
|
scheduler.running.append(req)
|
||||||
req.status = RequestStatus.RUNNING
|
req.status = RequestStatus.RUNNING
|
||||||
|
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
scheduler_output = SchedulerOutput(
|
scheduler_output = SchedulerOutput(
|
||||||
scheduled_new_reqs=[],
|
scheduled_new_reqs=[],
|
||||||
scheduled_cached_reqs=[],
|
scheduled_cached_reqs=[],
|
||||||
@@ -549,7 +549,7 @@ class TestAscendScheduler(TestBase):
|
|||||||
scheduler.requests[requests[0].request_id] = requests[0]
|
scheduler.requests[requests[0].request_id] = requests[0]
|
||||||
scheduler.running.append(requests[0])
|
scheduler.running.append(requests[0])
|
||||||
|
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
scheduler_output = SchedulerOutput(
|
scheduler_output = SchedulerOutput(
|
||||||
scheduled_new_reqs=[],
|
scheduled_new_reqs=[],
|
||||||
scheduled_cached_reqs=[],
|
scheduled_cached_reqs=[],
|
||||||
@@ -645,7 +645,7 @@ class TestAscendScheduler(TestBase):
|
|||||||
512)
|
512)
|
||||||
|
|
||||||
# Model output of the first request.
|
# Model output of the first request.
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
model_runner_output = ModelRunnerOutput(
|
model_runner_output = ModelRunnerOutput(
|
||||||
req_ids=[requests[0].request_id],
|
req_ids=[requests[0].request_id],
|
||||||
req_id_to_index={requests[0].request_id: 0},
|
req_id_to_index={requests[0].request_id: 0},
|
||||||
@@ -671,7 +671,7 @@ class TestAscendScheduler(TestBase):
|
|||||||
# request is still running.
|
# request is still running.
|
||||||
scheduler.schedule()
|
scheduler.schedule()
|
||||||
# Model output of the second request.
|
# Model output of the second request.
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
model_runner_output = ModelRunnerOutput(
|
model_runner_output = ModelRunnerOutput(
|
||||||
req_ids=[requests[1].request_id],
|
req_ids=[requests[1].request_id],
|
||||||
req_id_to_index={requests[1].request_id: 0},
|
req_id_to_index={requests[1].request_id: 0},
|
||||||
@@ -739,7 +739,7 @@ class TestAscendScheduler(TestBase):
|
|||||||
req_id = requests[i].request_id
|
req_id = requests[i].request_id
|
||||||
self.assertEqual(output.num_scheduled_tokens[req_id], 1)
|
self.assertEqual(output.num_scheduled_tokens[req_id], 1)
|
||||||
self.assertNotIn(req_id, output.scheduled_spec_decode_tokens)
|
self.assertNotIn(req_id, output.scheduled_spec_decode_tokens)
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
model_runner_output = ModelRunnerOutput(
|
model_runner_output = ModelRunnerOutput(
|
||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
req_id_to_index=req_to_index,
|
req_id_to_index=req_to_index,
|
||||||
@@ -760,7 +760,7 @@ class TestAscendScheduler(TestBase):
|
|||||||
|
|
||||||
engine_core_outputs = scheduler.update_from_output(
|
engine_core_outputs = scheduler.update_from_output(
|
||||||
output, model_runner_output)
|
output, model_runner_output)
|
||||||
if not vllm_version_is("0.10.1.1"):
|
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
|
||||||
scheduler.update_draft_token_ids(draft_token_ids)
|
scheduler.update_draft_token_ids(draft_token_ids)
|
||||||
|
|
||||||
for i in range(len(requests)):
|
for i in range(len(requests)):
|
||||||
@@ -797,7 +797,7 @@ class TestAscendScheduler(TestBase):
|
|||||||
else:
|
else:
|
||||||
self.assertNotIn(req_id,
|
self.assertNotIn(req_id,
|
||||||
output.scheduled_spec_decode_tokens)
|
output.scheduled_spec_decode_tokens)
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
model_runner_output = ModelRunnerOutput(
|
model_runner_output = ModelRunnerOutput(
|
||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
req_id_to_index=req_to_index,
|
req_id_to_index=req_to_index,
|
||||||
|
|||||||
@@ -200,7 +200,7 @@ def create_model_runner_output(
|
|||||||
kv_connector_output = KVConnectorOutput(finished_sending=finished_sending,
|
kv_connector_output = KVConnectorOutput(finished_sending=finished_sending,
|
||||||
finished_recving=finished_recving)
|
finished_recving=finished_recving)
|
||||||
extra_args = {"kv_connector_output": kv_connector_output}
|
extra_args = {"kv_connector_output": kv_connector_output}
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
model_runner_output = ModelRunnerOutput(
|
model_runner_output = ModelRunnerOutput(
|
||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
req_id_to_index=req_id_to_index,
|
req_id_to_index=req_id_to_index,
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from vllm.v1.structured_output import StructuredOutputManager
|
|||||||
|
|
||||||
from vllm_ascend.utils import vllm_version_is
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
else:
|
else:
|
||||||
KVCacheBlocks = None
|
KVCacheBlocks = None
|
||||||
@@ -66,7 +66,7 @@ class AscendScheduler(Scheduler):
|
|||||||
scheduled_running_reqs: list[Request] = []
|
scheduled_running_reqs: list[Request] = []
|
||||||
preempted_reqs: list[Request] = []
|
preempted_reqs: list[Request] = []
|
||||||
|
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
req_to_new_block_ids: dict[str, list[int]] = {}
|
req_to_new_block_ids: dict[str, list[int]] = {}
|
||||||
else:
|
else:
|
||||||
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
|
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
|
||||||
@@ -227,7 +227,7 @@ class AscendScheduler(Scheduler):
|
|||||||
|
|
||||||
if self.lora_config and request.lora_request:
|
if self.lora_config and request.lora_request:
|
||||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
req_to_new_block_ids[request.request_id] = (
|
req_to_new_block_ids[request.request_id] = (
|
||||||
self.kv_cache_manager.get_block_ids(request.request_id))
|
self.kv_cache_manager.get_block_ids(request.request_id))
|
||||||
else:
|
else:
|
||||||
@@ -320,7 +320,7 @@ class AscendScheduler(Scheduler):
|
|||||||
# Schedule the request.
|
# Schedule the request.
|
||||||
scheduled_running_reqs.append(request)
|
scheduled_running_reqs.append(request)
|
||||||
self.scheduled_req_ids.add(request.request_id)
|
self.scheduled_req_ids.add(request.request_id)
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
req_to_new_block_ids[request.request_id] = (
|
req_to_new_block_ids[request.request_id] = (
|
||||||
new_blocks.get_block_ids())
|
new_blocks.get_block_ids())
|
||||||
else:
|
else:
|
||||||
@@ -362,7 +362,7 @@ class AscendScheduler(Scheduler):
|
|||||||
any_request, len(self.running)))
|
any_request, len(self.running)))
|
||||||
|
|
||||||
# Construct the scheduler output.
|
# Construct the scheduler output.
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
new_reqs_data = [
|
new_reqs_data = [
|
||||||
NewRequestData.from_request(
|
NewRequestData.from_request(
|
||||||
req, req_to_new_block_ids[req.request_id])
|
req, req_to_new_block_ids[req.request_id])
|
||||||
@@ -385,7 +385,7 @@ class AscendScheduler(Scheduler):
|
|||||||
req_to_new_blocks)
|
req_to_new_blocks)
|
||||||
scheduled_cached_reqs = cached_reqs_data
|
scheduled_cached_reqs = cached_reqs_data
|
||||||
|
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
scheduler_output = SchedulerOutput(
|
scheduler_output = SchedulerOutput(
|
||||||
scheduled_new_reqs=new_reqs_data,
|
scheduled_new_reqs=new_reqs_data,
|
||||||
scheduled_cached_reqs=scheduled_cached_reqs,
|
scheduled_cached_reqs=scheduled_cached_reqs,
|
||||||
|
|||||||
@@ -254,7 +254,7 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
|
|||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
self.num_redundant_experts = parallel_config.num_redundant_experts
|
self.num_redundant_experts = parallel_config.num_redundant_experts
|
||||||
else:
|
else:
|
||||||
eplb_config = parallel_config.eplb_config
|
eplb_config = parallel_config.eplb_config
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from vllm.v1.sample.sampler import Sampler
|
|||||||
|
|
||||||
from vllm_ascend.utils import is_310p, vllm_version_is
|
from vllm_ascend.utils import is_310p, vllm_version_is
|
||||||
|
|
||||||
if not vllm_version_is("0.10.1.1"):
|
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
|
||||||
from vllm.config import LogprobsMode
|
from vllm.config import LogprobsMode
|
||||||
DEFAULT_LOGPROBS_MODE = LogprobsMode.RAW_LOGPROBS
|
DEFAULT_LOGPROBS_MODE = LogprobsMode.RAW_LOGPROBS
|
||||||
else:
|
else:
|
||||||
@@ -68,7 +68,7 @@ class AscendTopKTopPSampler(TopKTopPSampler):
|
|||||||
def forward_native(self, logits, generators, k, p):
|
def forward_native(self, logits, generators, k, p):
|
||||||
"""Override pytorch native implementation to torch_npu"""
|
"""Override pytorch native implementation to torch_npu"""
|
||||||
logits = self._apply_top_k_top_p(logits, k, p)
|
logits = self._apply_top_k_top_p(logits, k, p)
|
||||||
if not vllm_version_is("0.10.1.1"):
|
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
|
||||||
|
|
||||||
logits_to_return = None
|
logits_to_return = None
|
||||||
if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
|
if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
|
||||||
@@ -79,7 +79,7 @@ class AscendTopKTopPSampler(TopKTopPSampler):
|
|||||||
|
|
||||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||||
output = None
|
output = None
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
output = random_sample(probs, generators)
|
output = random_sample(probs, generators)
|
||||||
else:
|
else:
|
||||||
output = (random_sample(probs, generators), logits_to_return)
|
output = (random_sample(probs, generators), logits_to_return)
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
|
|||||||
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
|
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
|
||||||
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
if not vllm_version_is("0.10.1.1"):
|
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
|
||||||
from vllm.v1.outputs import DraftTokenIds
|
from vllm.v1.outputs import DraftTokenIds
|
||||||
else:
|
else:
|
||||||
DraftTokenIds = None
|
DraftTokenIds = None
|
||||||
@@ -384,7 +384,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# Remove finished requests from the cached states.
|
# Remove finished requests from the cached states.
|
||||||
for req_id in scheduler_output.finished_req_ids:
|
for req_id in scheduler_output.finished_req_ids:
|
||||||
self.requests.pop(req_id, None)
|
self.requests.pop(req_id, None)
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
self.encoder_cache.pop(req_id, None)
|
self.encoder_cache.pop(req_id, None)
|
||||||
# Remove the finished requests from the persistent batch.
|
# Remove the finished requests from the persistent batch.
|
||||||
# NOTE(woosuk): There could be an edge case where finished_req_ids and
|
# NOTE(woosuk): There could be an edge case where finished_req_ids and
|
||||||
@@ -394,7 +394,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# and handling the second as a new request.
|
# and handling the second as a new request.
|
||||||
for req_id in scheduler_output.finished_req_ids:
|
for req_id in scheduler_output.finished_req_ids:
|
||||||
self.input_batch.remove_request(req_id)
|
self.input_batch.remove_request(req_id)
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
# Free the cached encoder outputs.
|
# Free the cached encoder outputs.
|
||||||
for req_id, input_id in scheduler_output.free_encoder_input_ids:
|
for req_id, input_id in scheduler_output.free_encoder_input_ids:
|
||||||
encoder_outputs = self.encoder_cache.get(req_id)
|
encoder_outputs = self.encoder_cache.get(req_id)
|
||||||
@@ -455,9 +455,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
lora_request=new_req_data.lora_request,
|
lora_request=new_req_data.lora_request,
|
||||||
**({
|
**({
|
||||||
"mm_hashes": new_req_data.mm_hashes
|
"mm_hashes": new_req_data.mm_hashes
|
||||||
} if not vllm_version_is("0.10.1.1") else {
|
} if not (vllm_version_is("0.10.1.1")
|
||||||
"mm_hashes": None
|
or vllm_version_is("0.10.1")) else {
|
||||||
}),
|
"mm_hashes": None
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||||
@@ -893,13 +894,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
# Batch the multi-modal inputs.
|
# Batch the multi-modal inputs.
|
||||||
mm_kwargs = list[MultiModalKwargsItem]()
|
mm_kwargs = list[MultiModalKwargsItem]()
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
|
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
|
||||||
else:
|
else:
|
||||||
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
|
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
|
||||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
for mm_input_id in encoder_input_ids:
|
for mm_input_id in encoder_input_ids:
|
||||||
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
|
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
|
||||||
req_ids_pos.append((req_id, mm_input_id,
|
req_ids_pos.append((req_id, mm_input_id,
|
||||||
@@ -942,7 +943,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
for output in curr_group_outputs:
|
for output in curr_group_outputs:
|
||||||
encoder_outputs.append(output)
|
encoder_outputs.append(output)
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
# Cache the encoder outputs.
|
# Cache the encoder outputs.
|
||||||
for (req_id, input_id, pos_info), output in zip(
|
for (req_id, input_id, pos_info), output in zip(
|
||||||
req_ids_pos,
|
req_ids_pos,
|
||||||
@@ -974,7 +975,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
num_computed_tokens = req_state.num_computed_tokens
|
num_computed_tokens = req_state.num_computed_tokens
|
||||||
mm_positions = req_state.mm_positions
|
mm_positions = req_state.mm_positions
|
||||||
if not vllm_version_is("0.10.1.1"):
|
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
|
||||||
mm_hashes = req_state.mm_hashes
|
mm_hashes = req_state.mm_hashes
|
||||||
for i, pos_info in enumerate(mm_positions):
|
for i, pos_info in enumerate(mm_positions):
|
||||||
start_pos = pos_info.offset
|
start_pos = pos_info.offset
|
||||||
@@ -993,7 +994,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
start_idx = max(num_computed_tokens - start_pos, 0)
|
start_idx = max(num_computed_tokens - start_pos, 0)
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
end_idx = min(
|
end_idx = min(
|
||||||
num_computed_tokens - start_pos + num_scheduled_tokens,
|
num_computed_tokens - start_pos + num_scheduled_tokens,
|
||||||
num_encoder_tokens)
|
num_encoder_tokens)
|
||||||
@@ -1719,7 +1720,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
logits = None
|
logits = None
|
||||||
else:
|
else:
|
||||||
if self.input_batch.pooling_params:
|
if self.input_batch.pooling_params:
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is(
|
||||||
|
"0.10.1"):
|
||||||
return self._pool_v010(
|
return self._pool_v010(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
scheduler_output.total_num_scheduled_tokens,
|
scheduler_output.total_num_scheduled_tokens,
|
||||||
@@ -1867,7 +1869,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
extra_args = ({"kv_connector_output": kv_connector_output})
|
extra_args = ({"kv_connector_output": kv_connector_output})
|
||||||
|
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
model_runner_output = ModelRunnerOutput(
|
model_runner_output = ModelRunnerOutput(
|
||||||
req_ids=self.input_batch.req_ids,
|
req_ids=self.input_batch.req_ids,
|
||||||
req_id_to_index=self.input_batch.req_id_to_index,
|
req_id_to_index=self.input_batch.req_id_to_index,
|
||||||
@@ -2191,7 +2193,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dummy_pooling_params = PoolingParams(task=task)
|
dummy_pooling_params = PoolingParams(task=task)
|
||||||
to_update = model.pooler.get_pooling_updates(task)
|
to_update = model.pooler.get_pooling_updates(task)
|
||||||
to_update.apply(dummy_pooling_params)
|
to_update.apply(dummy_pooling_params)
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
dummy_prompt_lens = torch.tensor(
|
dummy_prompt_lens = torch.tensor(
|
||||||
[h.shape[0] for h in hidden_states_list],
|
[h.shape[0] for h in hidden_states_list],
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
|||||||
@@ -726,7 +726,7 @@ class InputBatch:
|
|||||||
pooling_params = [
|
pooling_params = [
|
||||||
self.pooling_params[req_id] for req_id in self.req_ids
|
self.pooling_params[req_id] for req_id in self.req_ids
|
||||||
]
|
]
|
||||||
if vllm_version_is("0.10.1.1"):
|
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||||
return PoolingMetadata(
|
return PoolingMetadata(
|
||||||
prompt_lens=torch.from_numpy(
|
prompt_lens=torch.from_numpy(
|
||||||
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
|
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ from vllm_ascend.utils import (init_ascend_soc_version,
|
|||||||
try_register_lib, vllm_version_is)
|
try_register_lib, vllm_version_is)
|
||||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||||
|
|
||||||
if not vllm_version_is("0.10.1.1"):
|
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
|
||||||
from vllm.v1.outputs import DraftTokenIds
|
from vllm.v1.outputs import DraftTokenIds
|
||||||
else:
|
else:
|
||||||
DraftTokenIds = None
|
DraftTokenIds = None
|
||||||
|
|||||||
Reference in New Issue
Block a user