[Lint] Style: Convert vllm-ascend/ to ruff format (Batch #3) (#5978)

### What this PR does / why we need it?
**Scope of Changes**: this batch (#3) converts the following files to the ruff formatter and removes them from the lint exclusion list (an example ruff invocation follows the table).
| File Path |
| :--- |
| `vllm_ascend/attention/mla_v1.py` |
| `vllm_ascend/attention/sfa_v1.py` |
| `vllm_ascend/core/recompute_scheduler.py` |
| `vllm_ascend/core/scheduler_dynamic_batch.py` |
| `vllm_ascend/distributed/device_communicators/npu_communicator.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py` |
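
For reference, a conversion like this can be reproduced locally roughly as follows. This is a sketch only: it assumes `ruff` is installed in the environment, and the repository's actual entrypoint (for example a format script or pre-commit hook) may differ.

```python
# Hypothetical helper, not part of this PR: run ruff over the batch-3 files.
import subprocess

BATCH_3_FILES = [
    "vllm_ascend/attention/mla_v1.py",
    "vllm_ascend/attention/sfa_v1.py",
    "vllm_ascend/core/recompute_scheduler.py",
    "vllm_ascend/core/scheduler_dynamic_batch.py",
    "vllm_ascend/distributed/device_communicators/npu_communicator.py",
    "vllm_ascend/distributed/device_communicators/pyhccl.py",
    "vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py",
]

# Apply the ruff formatter, then auto-fix whatever lint findings it can handle.
subprocess.run(["ruff", "format", *BATCH_3_FILES], check=True)
subprocess.run(["ruff", "check", "--fix", *BATCH_3_FILES], check=True)
```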

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
Co-authored-by: Soren <user@SorendeMac-mini.local>
Author: SILONG ZENG
Date: 2026-01-24 22:10:18 +08:00
Committed by: GitHub
Parent commit: 4e53c1d900
Commit: 7faa6878a6
9 changed files with 953 additions and 1148 deletions

Lint configuration (formatter exclusion list):

@@ -51,11 +51,6 @@ line-length = 120
 # Folder to be modified
 exclude = [
     "tests/**",
-    # (3)
-    "vllm_ascend/attention/*.py",
-    "vllm_ascend/core/*.py",
-    "vllm_ascend/distributed/device_communicators/**",
-    "vllm_ascend/distributed/utils.py",
     # (5)
     "vllm_ascend/distributed/kv_transfer/kv_pool/**",
     "vllm_ascend/distributed/kv_transfer/utils/**",

vllm_ascend/attention/mla_v1.py:

@@ -394,7 +394,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
         prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
         prefill_kv_no_split = kv_no_split[:num_actual_tokens]
         kv_c, k_pe = prefill_kv_no_split.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())  # type: ignore[misc]
         assert len(kv_cache) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
         kv_c_normed = kv_c_normed.view([num_actual_tokens, self.num_kv_heads, -1])
         k_pe = k_pe.unsqueeze(1)

File diff suppressed because it is too large.

File diff suppressed because it is too large.

vllm_ascend/core/recompute_scheduler.py:

@@ -21,26 +21,21 @@ from __future__ import annotations
 import time
 from collections import defaultdict
 from dataclasses import dataclass, fields
-from typing import Type, Union
 
 from vllm._bc_linter import bc_linter_include
 from vllm.config import SchedulerConfig, VllmConfig
 from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
 from vllm.distributed.kv_events import KVEventBatch
-from vllm.distributed.kv_transfer.kv_connector.v1.base import \
-    KVConnectorMetadata
-from vllm.distributed.kv_transfer.kv_connector.v1.metrics import \
-    KVConnectorStats
+from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
+from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
 from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.async_scheduler import AsyncScheduler
 from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
-from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
-                                              create_request_queue)
+from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.core.sched.utils import check_stop, remove_all
-from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
-                            EngineCoreOutputs, FinishReason)
+from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
@@ -51,26 +46,22 @@ logger = init_logger(__name__)
 @dataclass
 class RecomputeSchedulerConfig(SchedulerConfig):
-    scheduler_cls: Union[str, Type[object]] = (
-        "vllm_ascend.core.recompute_scheduler.RecomputeScheduler")
+    scheduler_cls: str | type[object] = "vllm_ascend.core.recompute_scheduler.RecomputeScheduler"
 
     @classmethod
     def initialize_from_config(cls, vllm_config: VllmConfig):
         vllm_scheduler_config = vllm_config.scheduler_config
         scheduler_config = {
             field.name: getattr(vllm_scheduler_config, field.name)
-            for field in fields(vllm_scheduler_config) if field.init
+            for field in fields(vllm_scheduler_config)
+            if field.init
         }
         if vllm_scheduler_config.async_scheduling:
-            scheduler_config["scheduler_cls"] = (
-                "vllm_ascend.core.recompute_scheduler.AsyncRecomputeScheduler")
+            scheduler_config["scheduler_cls"] = "vllm_ascend.core.recompute_scheduler.AsyncRecomputeScheduler"
         else:
-            scheduler_config["scheduler_cls"] = (
-                "vllm_ascend.core.recompute_scheduler.RecomputeScheduler")
-        scheduler_config[
-            "max_model_len"] = vllm_config.model_config.max_model_len
-        scheduler_config[
-            "is_encoder_decoder"] = vllm_config.model_config.is_encoder_decoder
+            scheduler_config["scheduler_cls"] = "vllm_ascend.core.recompute_scheduler.RecomputeScheduler"
+        scheduler_config["max_model_len"] = vllm_config.model_config.max_model_len
+        scheduler_config["is_encoder_decoder"] = vllm_config.model_config.is_encoder_decoder
         return cls(**scheduler_config)
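
A pattern that repeats throughout this PR is the typing cleanup that accompanies the ruff conversion: `typing.Union`, `Optional`, `List`, and `Type` are replaced with PEP 604 unions and builtin generics. A minimal illustration of the before/after style, not taken verbatim from the diff (the `__future__` import keeps the annotations valid on older Python versions):

```python
from __future__ import annotations

# Old style, as removed by this PR:
#   from typing import List, Optional, Type, Union
#   scheduler_cls: Union[str, Type[object]] = "..."
#   def example(sizes: Optional[List[int]] = None) -> None: ...

# New style, as introduced by this PR:
scheduler_cls: str | type[object] = "vllm_ascend.core.recompute_scheduler.RecomputeScheduler"


def example(sizes: list[int] | None = None) -> None:
    """list[int] | None replaces Optional[List[int]]; type[...] replaces Type[...]."""
```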
@@ -125,33 +116,32 @@ class RecomputeScheduler(Scheduler):
while req_index < len(self.running) and token_budget > 0: while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index] request = self.running[req_index]
if (request.num_output_placeholders > 0 if (
# This is (num_computed_tokens + 1) - (num_output_placeholders - 1). request.num_output_placeholders > 0
# Since output placeholders are also included in the computed tokens # This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
# count, we subtract (num_output_placeholders - 1) to remove any draft # Since output placeholders are also included in the computed tokens
# tokens, so that we can be sure no further steps are needed even if # count, we subtract (num_output_placeholders - 1) to remove any draft
# they are all rejected. # tokens, so that we can be sure no further steps are needed even if
and request.num_computed_tokens + 2 - # they are all rejected.
request.num_output_placeholders and request.num_computed_tokens + 2 - request.num_output_placeholders
>= request.num_prompt_tokens + request.max_tokens): >= request.num_prompt_tokens + request.max_tokens
):
# Async scheduling: Avoid scheduling an extra step when we are sure that # Async scheduling: Avoid scheduling an extra step when we are sure that
# the previous step has reached request.max_tokens. We don't schedule # the previous step has reached request.max_tokens. We don't schedule
# partial draft tokens since this prevents uniform decode optimizations. # partial draft tokens since this prevents uniform decode optimizations.
req_index += 1 req_index += 1
continue continue
num_new_tokens = (request.num_tokens_with_spec + num_new_tokens = (
request.num_output_placeholders - request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens
request.num_computed_tokens) )
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
# Make sure the input position does not exceed the max model len. # Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding. # This is necessary when using spec decoding.
num_new_tokens = min( num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens)
num_new_tokens,
self.max_model_len - 1 - request.num_computed_tokens)
# Schedule encoder inputs. # Schedule encoder inputs.
encoder_inputs_to_schedule = None encoder_inputs_to_schedule = None
@@ -209,9 +199,10 @@ class RecomputeScheduler(Scheduler):
recomputed_req = self.running.pop() recomputed_req = self.running.pop()
self.kv_cache_manager.free(recomputed_req) self.kv_cache_manager.free(recomputed_req)
recomputed_reqs.append( recomputed_reqs.append(
RecomputeReqInfo(recomputed_req.request_id, RecomputeReqInfo(
recomputed_req.output_token_ids, recomputed_req.request_id, recomputed_req.output_token_ids, recomputed_req.client_index
recomputed_req.client_index)) )
)
if recomputed_req == request: if recomputed_req == request:
break break
else: else:
@@ -223,28 +214,23 @@ class RecomputeScheduler(Scheduler):
self.running.remove(preempted_req) self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs: if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req) scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens[ token_budget += num_scheduled_tokens[preempted_req.request_id]
preempted_req.request_id]
req_to_new_blocks.pop(preempted_req.request_id) req_to_new_blocks.pop(preempted_req.request_id)
num_scheduled_tokens.pop( num_scheduled_tokens.pop(preempted_req.request_id)
preempted_req.request_id) scheduled_spec_decode_tokens.pop(preempted_req.request_id, None)
scheduled_spec_decode_tokens.pop( preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None)
preempted_req.request_id, None)
preempted_encoder_inputs = scheduled_encoder_inputs.pop(
preempted_req.request_id, None)
if preempted_encoder_inputs: if preempted_encoder_inputs:
# Restore encoder compute budget if the preempted # Restore encoder compute budget if the preempted
# request had encoder inputs scheduled in this step. # request had encoder inputs scheduled in this step.
num_embeds_to_restore = sum( num_embeds_to_restore = sum(
preempted_req.get_num_encoder_embeds(i) preempted_req.get_num_encoder_embeds(i) for i in preempted_encoder_inputs
for i in preempted_encoder_inputs) )
encoder_compute_budget += num_embeds_to_restore encoder_compute_budget += num_embeds_to_restore
req_index -= 1 req_index -= 1
else: else:
preempted_req = self.running.pop() preempted_req = self.running.pop()
self._preempt_request(preempted_req, self._preempt_request(preempted_req, scheduled_timestamp)
scheduled_timestamp)
preempted_reqs.append(preempted_req) preempted_reqs.append(preempted_req)
if preempted_req == request: if preempted_req == request:
# No more request to preempt. Cannot schedule this request. # No more request to preempt. Cannot schedule this request.
@@ -263,23 +249,20 @@ class RecomputeScheduler(Scheduler):
# Speculative decode related. # Speculative decode related.
if request.spec_token_ids: if request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens + num_scheduled_spec_tokens = (
request.num_computed_tokens - num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders
request.num_tokens - )
request.num_output_placeholders)
if num_scheduled_spec_tokens > 0: if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens. # Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:] del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = ( scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
request.spec_token_ids)
# New spec tokens will be set in `update_draft_token_ids` before the # New spec tokens will be set in `update_draft_token_ids` before the
# next step when applicable. # next step when applicable.
request.spec_token_ids = [] request.spec_token_ids = []
# Encoder-related. # Encoder-related.
if encoder_inputs_to_schedule: if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = ( scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
encoder_inputs_to_schedule)
# Allocate the encoder cache. # Allocate the encoder cache.
for i in encoder_inputs_to_schedule: for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
@@ -294,8 +277,10 @@ class RecomputeScheduler(Scheduler):
scheduled_loras: set[int] = set() scheduled_loras: set[int] = set()
if self.lora_config: if self.lora_config:
scheduled_loras = set( scheduled_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs req.lora_request.lora_int_id
if req.lora_request and req.lora_request.lora_int_id > 0) for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0
)
assert len(scheduled_loras) <= self.lora_config.max_loras assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary RequestQueue to collect requests that need to be # Use a temporary RequestQueue to collect requests that need to be
@@ -337,9 +322,14 @@ class RecomputeScheduler(Scheduler):
# Check that adding the request still respects the max_loras # Check that adding the request still respects the max_loras
# constraint. # constraint.
if (self.lora_config and request.lora_request and if (
(len(scheduled_loras) == self.lora_config.max_loras and self.lora_config
request.lora_request.lora_int_id not in scheduled_loras)): and request.lora_request
and (
len(scheduled_loras) == self.lora_config.max_loras
and request.lora_request.lora_int_id not in scheduled_loras
)
):
# Scheduling would exceed max_loras, skip. # Scheduling would exceed max_loras, skip.
self.waiting.pop_request() self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
@@ -351,14 +341,15 @@ class RecomputeScheduler(Scheduler):
# Get already-cached tokens. # Get already-cached tokens.
if request.num_computed_tokens == 0: if request.num_computed_tokens == 0:
# Get locally-cached tokens. # Get locally-cached tokens.
new_computed_blocks, num_new_local_computed_tokens = ( new_computed_blocks, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_blocks(
self.kv_cache_manager.get_computed_blocks(request)) request
)
# Get externally-cached tokens if using a KVConnector. # Get externally-cached tokens if using a KVConnector.
if self.connector is not None: if self.connector is not None:
ext_tokens, load_kv_async = ( ext_tokens, load_kv_async = self.connector.get_num_new_matched_tokens(
self.connector.get_num_new_matched_tokens( request, num_new_local_computed_tokens
request, num_new_local_computed_tokens)) )
if ext_tokens is None: if ext_tokens is None:
# The request cannot be scheduled because # The request cannot be scheduled because
@@ -372,8 +363,7 @@ class RecomputeScheduler(Scheduler):
num_external_computed_tokens = ext_tokens num_external_computed_tokens = ext_tokens
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens + num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
num_external_computed_tokens)
else: else:
# KVTransfer: WAITING reqs have num_computed_tokens > 0 # KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed. # after async KV recvs are completed.
@@ -401,8 +391,7 @@ class RecomputeScheduler(Scheduler):
# chunked prefill has to be enabled explicitly to allow # chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked # pooling requests to be chunked
if (not self.scheduler_config.enable_chunked_prefill if not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget:
and num_new_tokens > token_budget):
# If chunked_prefill is disabled, # If chunked_prefill is disabled,
# we can stop the scheduling here. # we can stop the scheduling here.
break break
@@ -433,9 +422,7 @@ class RecomputeScheduler(Scheduler):
# extra block gets allocated which # extra block gets allocated which
# creates a mismatch between the number # creates a mismatch between the number
# of local and remote blocks. # of local and remote blocks.
effective_lookahead_tokens = (0 if request.num_computed_tokens effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
== 0 else
self.num_lookahead_tokens)
# Determine if we need to allocate cross-attention blocks. # Determine if we need to allocate cross-attention blocks.
if self.is_encoder_decoder and request.has_encoder_inputs: if self.is_encoder_decoder and request.has_encoder_inputs:
@@ -443,8 +430,7 @@ class RecomputeScheduler(Scheduler):
# always padded to the maximum length. If we support other # always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we # encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed. # want to only allocate what is needed.
num_encoder_tokens = ( num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
self.scheduler_config.max_num_encoder_input_tokens)
else: else:
num_encoder_tokens = 0 num_encoder_tokens = 0
@@ -488,20 +474,17 @@ class RecomputeScheduler(Scheduler):
req_index += 1 req_index += 1
self.running.append(request) self.running.append(request)
if self.log_stats: if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED, request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
scheduled_timestamp)
if request.status == RequestStatus.WAITING: if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request) scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED: elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request) scheduled_resumed_reqs.append(request)
else: else:
raise RuntimeError( raise RuntimeError(f"Invalid request status: {request.status}")
f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id) scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_blocks[request.request_id] = ( req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
self.kv_cache_manager.get_blocks(request.request_id))
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
@@ -511,8 +494,7 @@ class RecomputeScheduler(Scheduler):
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
# Encoder-related. # Encoder-related.
if encoder_inputs_to_schedule: if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = ( scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
encoder_inputs_to_schedule)
# Allocate the encoder cache. # Allocate the encoder cache.
for i in encoder_inputs_to_schedule: for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
@@ -522,8 +504,7 @@ class RecomputeScheduler(Scheduler):
for i in external_load_encoder_input: for i in external_load_encoder_input:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None: if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc( self.ec_connector.update_state_after_alloc(request, i)
request, i)
# Put back any skipped requests at the head of the waiting queue # Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests: if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests) self.waiting.prepend_requests(skipped_waiting_requests)
@@ -537,20 +518,15 @@ class RecomputeScheduler(Scheduler):
# Since some requests in the RUNNING queue may not be scheduled in # Since some requests in the RUNNING queue may not be scheduled in
# this step, the total number of scheduled requests can be smaller than # this step, the total number of scheduled requests can be smaller than
# len(self.running). # len(self.running).
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running)
scheduled_running_reqs) <= len(self.running)
# Get the longest common prefix among all requests in the running queue. # Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention. # This can be potentially used for cascade attention.
num_common_prefix_blocks = [0] * len( num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
self.kv_cache_config.kv_cache_groups) with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
with record_function_or_nullcontext(
"schedule: get_num_common_prefix_blocks"):
if self.running: if self.running:
any_request = self.running[0] any_request = self.running[0]
num_common_prefix_blocks = ( num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request.request_id))
# Construct the scheduler output. # Construct the scheduler output.
if self.use_v2_model_runner: if self.use_v2_model_runner:
@@ -561,17 +537,16 @@ class RecomputeScheduler(Scheduler):
req, req,
req_to_new_blocks[req.request_id].get_block_ids(), req_to_new_blocks[req.request_id].get_block_ids(),
req._all_token_ids, req._all_token_ids,
) for req in scheduled_new_reqs )
for req in scheduled_new_reqs
] ]
else: else:
new_reqs_data = [ new_reqs_data = [
NewRequestData.from_request( NewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids())
req, req_to_new_blocks[req.request_id].get_block_ids())
for req in scheduled_new_reqs for req in scheduled_new_reqs
] ]
with record_function_or_nullcontext( with record_function_or_nullcontext("schedule: make_cached_request_data"):
"schedule: make_cached_request_data"):
cached_reqs_data = self._make_cached_request_data( cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs, scheduled_running_reqs,
scheduled_resumed_reqs, scheduled_resumed_reqs,
@@ -592,15 +567,13 @@ class RecomputeScheduler(Scheduler):
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs, scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks, num_common_prefix_blocks=num_common_prefix_blocks,
preempted_req_ids={req.request_id preempted_req_ids={req.request_id for req in preempted_reqs},
for req in preempted_reqs},
# finished_req_ids is an existing state in the scheduler, # finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step. # instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between # It contains the request IDs that are finished in between
# the previous and the current steps. # the previous and the current steps.
finished_req_ids=self.finished_req_ids, finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager. free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
get_freed_mm_hashes(),
recomputed_reqs=recomputed_reqs, recomputed_reqs=recomputed_reqs,
) )
@@ -609,14 +582,12 @@ class RecomputeScheduler(Scheduler):
# 2. Wrap up all the KV cache load / save ops into an opaque object # 2. Wrap up all the KV cache load / save ops into an opaque object
# 3. Clear the internal states of the connector # 3. Clear the internal states of the connector
if self.connector is not None: if self.connector is not None:
meta: KVConnectorMetadata = self.connector.build_connector_meta( meta: KVConnectorMetadata = self.connector.build_connector_meta(scheduler_output)
scheduler_output)
scheduler_output.kv_connector_metadata = meta scheduler_output.kv_connector_metadata = meta
# Build the connector meta for ECConnector # Build the connector meta for ECConnector
if self.ec_connector is not None: if self.ec_connector is not None:
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta( ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(scheduler_output)
scheduler_output)
scheduler_output.ec_connector_metadata = ec_meta scheduler_output.ec_connector_metadata = ec_meta
with record_function_or_nullcontext("schedule: update_after_schedule"): with record_function_or_nullcontext("schedule: update_after_schedule"):
@@ -639,8 +610,8 @@ class RecomputeScheduler(Scheduler):
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: SpecDecodingStats | None = None spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats: KVConnectorStats | None = ( kv_connector_stats: KVConnectorStats | None = (
kv_connector_output.kv_connector_stats kv_connector_output.kv_connector_stats if kv_connector_output else None
if kv_connector_output else None) )
if kv_connector_stats and self.connector: if kv_connector_stats and self.connector:
kv_stats = self.connector.get_kv_connector_stats() kv_stats = self.connector.get_kv_connector_stats()
if kv_stats: if kv_stats:
@@ -651,8 +622,7 @@ class RecomputeScheduler(Scheduler):
# These blocks contain externally computed tokens that failed to # These blocks contain externally computed tokens that failed to
# load. Identify affected requests and adjust their computed token # load. Identify affected requests and adjust their computed token
# count to trigger recomputation of the invalid blocks. # count to trigger recomputation of the invalid blocks.
failed_kv_load_req_ids = self._handle_invalid_blocks( failed_kv_load_req_ids = self._handle_invalid_blocks(kv_connector_output.invalid_block_ids)
kv_connector_output.invalid_block_ids)
# return recomputed requests as EngineCoreOutput # return recomputed requests as EngineCoreOutput
if scheduler_output.recomputed_reqs is not None: if scheduler_output.recomputed_reqs is not None:
@@ -663,7 +633,8 @@ class RecomputeScheduler(Scheduler):
finish_reason=FinishReason.STOP, finish_reason=FinishReason.STOP,
new_token_ids=[req_info.output_token_ids[-1]], new_token_ids=[req_info.output_token_ids[-1]],
stop_reason="recomputed", stop_reason="recomputed",
)) )
)
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best # the below loop can be a performance bottleneck. We should do our best
@@ -683,11 +654,9 @@ class RecomputeScheduler(Scheduler):
continue continue
req_index = model_runner_output.req_id_to_index[req_id] req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = (sampled_token_ids[req_index] generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else []
if sampled_token_ids else [])
scheduled_spec_token_ids = ( scheduled_spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id)
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
if scheduled_spec_token_ids: if scheduled_spec_token_ids:
num_draft_tokens = len(scheduled_spec_token_ids) num_draft_tokens = len(scheduled_spec_token_ids)
num_accepted = len(generated_token_ids) - 1 num_accepted = len(generated_token_ids) - 1
@@ -717,15 +686,13 @@ class RecomputeScheduler(Scheduler):
# Check for stop and update request status. # Check for stop and update request status.
if new_token_ids: if new_token_ids:
new_token_ids, stopped = self._update_request_with_output( new_token_ids, stopped = self._update_request_with_output(request, new_token_ids)
request, new_token_ids)
# Stop checking for pooler models. # Stop checking for pooler models.
pooler_output = None pooler_output = None
if pooler_outputs: if pooler_outputs:
pooler_output = pooler_outputs[req_index] pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, self.max_model_len, stopped = check_stop(request, self.max_model_len, pooler_output)
pooler_output)
if stopped: if stopped:
kv_transfer_params = self._free_request(request) kv_transfer_params = self._free_request(request)
@@ -735,19 +702,14 @@ class RecomputeScheduler(Scheduler):
stopped_preempted_reqs.add(request) stopped_preempted_reqs.add(request)
# Extract sample logprobs if needed. # Extract sample logprobs if needed.
if (request.sampling_params is not None if request.sampling_params is not None and request.sampling_params.logprobs is not None and logprobs:
and request.sampling_params.logprobs is not None new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))
and logprobs):
new_logprobs = logprobs.slice_request(req_index,
len(new_token_ids))
if new_token_ids and self.structured_output_manager.should_advance( if new_token_ids and self.structured_output_manager.should_advance(request):
request):
struct_output_request = request.structured_output_request struct_output_request = request.structured_output_request
assert struct_output_request is not None assert struct_output_request is not None
assert struct_output_request.grammar is not None assert struct_output_request.grammar is not None
struct_output_request.grammar.accept_tokens( struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
req_id, new_token_ids)
if num_nans_in_logits is not None and req_id in num_nans_in_logits: if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id] request.num_nans_in_logits = num_nans_in_logits[req_id]
@@ -770,7 +732,8 @@ class RecomputeScheduler(Scheduler):
trace_headers=request.trace_headers, trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens, num_cached_tokens=request.num_cached_tokens,
num_nans_in_logits=request.num_nans_in_logits, num_nans_in_logits=request.num_nans_in_logits,
)) )
)
else: else:
# Invariant: EngineCore returns no partial prefill outputs. # Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors assert not prompt_logprobs_tensors
@@ -805,10 +768,7 @@ class RecomputeScheduler(Scheduler):
# Create EngineCoreOutputs for all clients that have requests with # Create EngineCoreOutputs for all clients that have requests with
# outputs in this step. # outputs in this step.
engine_core_outputs = { engine_core_outputs = {client_index: EngineCoreOutputs(outputs=outs) for client_index, outs in outputs.items()}
client_index: EngineCoreOutputs(outputs=outs)
for client_index, outs in outputs.items()
}
finished_req_ids = self.finished_req_ids_dict finished_req_ids = self.finished_req_ids_dict
if finished_req_ids: if finished_req_ids:
@@ -819,12 +779,10 @@ class RecomputeScheduler(Scheduler):
if (eco := engine_core_outputs.get(client_index)) is not None: if (eco := engine_core_outputs.get(client_index)) is not None:
eco.finished_requests = finished_set eco.finished_requests = finished_set
else: else:
engine_core_outputs[client_index] = EngineCoreOutputs( engine_core_outputs[client_index] = EngineCoreOutputs(finished_requests=finished_set)
finished_requests=finished_set)
finished_req_ids.clear() finished_req_ids.clear()
if (stats := self.make_stats(spec_decoding_stats, if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats)) is not None:
kv_connector_stats)) is not None:
# Return stats to only one of the front-ends. # Return stats to only one of the front-ends.
if (eco := next(iter(engine_core_outputs.values()), None)) is None: if (eco := next(iter(engine_core_outputs.values()), None)) is None:
# We must return the stats even if there are no request # We must return the stats even if there are no request
@@ -836,6 +794,5 @@ class RecomputeScheduler(Scheduler):
class AsyncRecomputeScheduler(AsyncScheduler, RecomputeScheduler): class AsyncRecomputeScheduler(AsyncScheduler, RecomputeScheduler):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)

vllm_ascend/core/scheduler_dynamic_batch.py:

@@ -16,7 +16,6 @@
 #
 import os
 import time
-from typing import Optional
 
 import pandas as pd
 from vllm.config import VllmConfig
@@ -25,8 +24,7 @@ from vllm.logger import logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
-from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
-                                              create_request_queue)
+from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.engine import EngineCoreEventType
 from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -43,8 +41,9 @@ class BudgetRefiner:
         if not self.enabled:
             return
         logger.info(
-            "Dynamic batch is enabled with SLO limit: {}, and chunked prefill is forced to be activated because dynamic batch relies on it"
-            .format(str(slo_limit)))
+            "Dynamic batch is enabled with SLO limit: {}, and chunked prefill is "
+            "forced to be activated because dynamic batch relies on it".format(str(slo_limit))
+        )
         self.lookup: dict[tuple[int, int], int] = {}
         self.context_keys: set[int] = set()
         self.dnum_keys: set[int] = set()
@@ -61,19 +60,20 @@ class BudgetRefiner:
                 "The dynamic batching feature requires the lookup table "
                 "'profile_table.csv', but it was not found at '%s'. "
                 "Please download the corresponding table file.",
-                table_file_path)
+                table_file_path,
+            )
             self.enabled = False
             return
         else:
             df = pd.read_csv(table_file_path)
-            grouped = df.groupby(['ctx_len', 'd_num'])
+            grouped = df.groupby(["ctx_len", "d_num"])
             for (ctx_len, d_num), group in grouped:
-                valid = group[group['cost'] <= slo_limit]
+                valid = group[group["cost"] <= slo_limit]
                 if not valid.empty:
-                    max_row = valid.loc[valid['chunk_size'].idxmax()]
+                    max_row = valid.loc[valid["chunk_size"].idxmax()]
                     assert isinstance(ctx_len, int), "ctx_len must be an integer"
                     assert isinstance(d_num, int), "d_num must be an integer"
-                    self.lookup[(ctx_len, d_num)] = int(max_row['chunk_size'])
+                    self.lookup[(ctx_len, d_num)] = int(max_row["chunk_size"])
                     self.context_keys.add(ctx_len)
                     self.dnum_keys.add(d_num)
             self.context_keys = set(sorted(self.context_keys))
@@ -97,7 +97,10 @@ class BudgetRefiner:
             logger.warn(f"Table miss for ctx,dnum{aligned_ctx, aligned_dnum}")
             budget = self.default_budget
         # For debug.
-        # logger.info(f"budget {budget}, ctx,dnum {aligned_ctx, aligned_dnum}, raw ctx,dnum {num_deocde_tokens, num_decode}")
+        # logger.info(
+        #     f"budget {budget}, ctx,dnum {aligned_ctx, aligned_dnum}, "
+        #     f"raw ctx,dnum {num_deocde_tokens, num_decode}"
+        # )
         return budget
 
     def refine_budget(self, running_request, budget):
@@ -106,9 +109,8 @@ class BudgetRefiner:
             return budget
         # assume all running request will be scheduled.
         num_decode_token_lst = [
-            req.num_tokens_with_spec \
-            for req in running_request \
-            if req.num_computed_tokens >= req.num_prompt_tokens ]
+            req.num_tokens_with_spec for req in running_request if req.num_computed_tokens >= req.num_prompt_tokens
+        ]
         num_decode = len(num_decode_token_lst)
         if num_decode <= 0:
             return budget
@@ -125,18 +127,25 @@ class SchedulerDynamicBatch(Scheduler):
         vllm_config: VllmConfig,
         kv_cache_config: KVCacheConfig,
         structured_output_manager: StructuredOutputManager,
-        block_size: Optional[int] = None,
+        block_size: int | None = None,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         include_finished_set: bool = False,
         log_stats: bool = False,
     ) -> None:
-        super().__init__(vllm_config, kv_cache_config,
-                         structured_output_manager, block_size, mm_registry,
-                         include_finished_set, log_stats)
+        super().__init__(
+            vllm_config,
+            kv_cache_config,
+            structured_output_manager,
+            block_size,
+            mm_registry,
+            include_finished_set,
+            log_stats,
+        )
         self.running: list[Request] = []
         self.budget_refiner = BudgetRefiner(
             default_budget=self.scheduler_config.max_num_batched_tokens,
-            slo_limit=self.scheduler_config.SLO_limits_for_dynamic_batch)
+            slo_limit=self.scheduler_config.SLO_limits_for_dynamic_batch,
+        )
 
     def schedule(self) -> SchedulerOutput:
         # NOTE: This scheduling algorithm is developed based on the "super.schedule()"
@@ -159,20 +168,13 @@ class SchedulerDynamicBatch(Scheduler):
req_to_new_blocks: dict[str, KVCacheBlocks] = {} req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {} num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens token_budget = self.max_num_scheduled_tokens
token_budget = self.budget_refiner.refine_budget( token_budget = self.budget_refiner.refine_budget(self.running, token_budget)
self.running, token_budget)
# NOTE: We move the prefill requests to the end of the self.running # NOTE: We move the prefill requests to the end of the self.running
# list and keep the relative order unchanged. This rearrangement makes this # list and keep the relative order unchanged. This rearrangement makes this
# scheduling algorithm a strict decode-first chunked prefills. # scheduling algorithm a strict decode-first chunked prefills.
d_lst = [ d_lst = [req for req in self.running if req.num_computed_tokens >= req.num_prompt_tokens]
req for req in self.running p_lst = [req for req in self.running if req.num_computed_tokens < req.num_prompt_tokens]
if req.num_computed_tokens >= req.num_prompt_tokens
]
p_lst = [
req for req in self.running
if req.num_computed_tokens < req.num_prompt_tokens
]
self.running = d_lst + p_lst self.running = d_lst + p_lst
# Encoder-related. # Encoder-related.
@@ -189,30 +191,26 @@ class SchedulerDynamicBatch(Scheduler):
while req_index < len(self.running) and token_budget > 0: while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index] request = self.running[req_index]
num_new_tokens = (request.num_tokens_with_spec + num_new_tokens = (
request.num_output_placeholders - request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens
request.num_computed_tokens) )
if (0 < self.scheduler_config.long_prefill_token_threshold < if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
num_new_tokens): num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
# Make sure the input position does not exceed the max model len. # Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding. # This is necessary when using spec decoding.
num_new_tokens = min( num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens)
num_new_tokens,
self.max_model_len - 1 - request.num_computed_tokens)
# Schedule encoder inputs. # Schedule encoder inputs.
encoder_inputs_to_schedule = None encoder_inputs_to_schedule = None
new_encoder_compute_budget = encoder_compute_budget new_encoder_compute_budget = encoder_compute_budget
if request.has_encoder_inputs: if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens, (encoder_inputs_to_schedule, num_new_tokens, new_encoder_compute_budget) = (
new_encoder_compute_budget self._try_schedule_encoder_inputs(
) = self._try_schedule_encoder_inputs( request, request.num_computed_tokens, num_new_tokens, encoder_compute_budget
request, request.num_computed_tokens, num_new_tokens, )
encoder_compute_budget) )
if num_new_tokens == 0: if num_new_tokens == 0:
# The request cannot be scheduled because one of the following # The request cannot be scheduled because one of the following
@@ -231,9 +229,8 @@ class SchedulerDynamicBatch(Scheduler):
while True: while True:
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request, num_new_tokens, num_lookahead_tokens=self.num_lookahead_tokens
num_new_tokens, )
num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None: if new_blocks is None:
# The request cannot be scheduled. # The request cannot be scheduled.
# Preempt the lowest-priority request. # Preempt the lowest-priority request.
@@ -253,8 +250,7 @@ class SchedulerDynamicBatch(Scheduler):
preempted_req.status = RequestStatus.PREEMPTED preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0 preempted_req.num_computed_tokens = 0
if self.log_stats: if self.log_stats:
preempted_req.record_event( preempted_req.record_event(EngineCoreEventType.PREEMPTED, scheduled_timestamp)
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
self.waiting.prepend_request(preempted_req) self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req) preempted_reqs.append(preempted_req)
@@ -279,19 +275,15 @@ class SchedulerDynamicBatch(Scheduler):
# Speculative decode related. # Speculative decode related.
if request.spec_token_ids: if request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens + num_scheduled_spec_tokens = num_new_tokens + request.num_computed_tokens - request.num_tokens
request.num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0: if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens. # Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:] del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = ( scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
request.spec_token_ids)
# Encoder-related. # Encoder-related.
if encoder_inputs_to_schedule: if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = ( scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
encoder_inputs_to_schedule)
# Allocate the encoder cache. # Allocate the encoder cache.
for i in encoder_inputs_to_schedule: for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
@@ -301,8 +293,10 @@ class SchedulerDynamicBatch(Scheduler):
scheduled_loras: set[int] = set() scheduled_loras: set[int] = set()
if self.lora_config: if self.lora_config:
scheduled_loras = set( scheduled_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs req.lora_request.lora_int_id
if req.lora_request and req.lora_request.lora_int_id > 0) for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0
)
assert len(scheduled_loras) <= self.lora_config.max_loras assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary RequestQueue to collect requests that need to be # Use a temporary RequestQueue to collect requests that need to be
@@ -323,9 +317,7 @@ class SchedulerDynamicBatch(Scheduler):
if is_ready: if is_ready:
request.status = RequestStatus.WAITING request.status = RequestStatus.WAITING
else: else:
logger.debug( logger.debug("%s is still in WAITING_FOR_REMOTE_KVS state.", request.request_id)
"%s is still in WAITING_FOR_REMOTE_KVS state.",
request.request_id)
self.waiting.pop_request() self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
continue continue
@@ -343,9 +335,14 @@ class SchedulerDynamicBatch(Scheduler):
# Check that adding the request still respects the max_loras # Check that adding the request still respects the max_loras
# constraint. # constraint.
if (self.lora_config and request.lora_request and if (
(len(scheduled_loras) == self.lora_config.max_loras and self.lora_config
request.lora_request.lora_int_id not in scheduled_loras)): and request.lora_request
and (
len(scheduled_loras) == self.lora_config.max_loras
and request.lora_request.lora_int_id not in scheduled_loras
)
):
# Scheduling would exceed max_loras, skip. # Scheduling would exceed max_loras, skip.
self.waiting.pop_request() self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
@@ -357,15 +354,15 @@ class SchedulerDynamicBatch(Scheduler):
# Get already-cached tokens. # Get already-cached tokens.
if request.num_computed_tokens == 0: if request.num_computed_tokens == 0:
# Get locally-cached tokens. # Get locally-cached tokens.
new_computed_blocks, num_new_local_computed_tokens = \ new_computed_blocks, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_blocks(
self.kv_cache_manager.get_computed_blocks( request
request) )
# Get externally-cached tokens if using a KVConnector. # Get externally-cached tokens if using a KVConnector.
if self.connector is not None: if self.connector is not None:
num_external_computed_tokens, load_kv_async = ( num_external_computed_tokens, load_kv_async = self.connector.get_num_new_matched_tokens(
self.connector.get_num_new_matched_tokens( request, num_new_local_computed_tokens
request, num_new_local_computed_tokens)) )
if num_external_computed_tokens is None: if num_external_computed_tokens is None:
# The request cannot be scheduled because # The request cannot be scheduled because
@@ -376,13 +373,11 @@ class SchedulerDynamicBatch(Scheduler):
continue continue
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens + num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
num_external_computed_tokens)
# KVTransfer: WAITING reqs have num_computed_tokens > 0 # KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed. # after async KV recvs are completed.
else: else:
new_computed_blocks = ( new_computed_blocks = self.kv_cache_manager.create_empty_block_list()
self.kv_cache_manager.create_empty_block_list())
num_new_local_computed_tokens = 0 num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens num_computed_tokens = request.num_computed_tokens
@@ -399,15 +394,12 @@ class SchedulerDynamicBatch(Scheduler):
# `request.num_prompt_tokens` to consider the resumed # `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens. # requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
< num_new_tokens): num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
# chunked prefill has to be enabled explicitly to allow # chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked # pooling requests to be chunked
if not self.scheduler_config.enable_chunked_prefill and \ if not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget:
num_new_tokens > token_budget:
self.waiting.pop_request() self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
continue continue
@@ -417,11 +409,11 @@ class SchedulerDynamicBatch(Scheduler):
# Schedule encoder inputs. # Schedule encoder inputs.
if request.has_encoder_inputs: if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens, (encoder_inputs_to_schedule, num_new_tokens, new_encoder_compute_budget, _) = (
new_encoder_compute_budget, self._try_schedule_encoder_inputs(
_) = self._try_schedule_encoder_inputs( request, num_computed_tokens, num_new_tokens, encoder_compute_budget
request, num_computed_tokens, num_new_tokens, )
encoder_compute_budget) )
if num_new_tokens == 0: if num_new_tokens == 0:
# The request cannot be scheduled. # The request cannot be scheduled.
break break
@@ -431,9 +423,7 @@ class SchedulerDynamicBatch(Scheduler):
# extra block gets allocated which # extra block gets allocated which
# creates a mismatch between the number # creates a mismatch between the number
# of local and remote blocks. # of local and remote blocks.
effective_lookahead_tokens = (0 if request.num_computed_tokens effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
== 0 else
self.num_lookahead_tokens)
# Determine if we need to allocate cross-attention blocks. # Determine if we need to allocate cross-attention blocks.
if self.is_encoder_decoder and request.has_encoder_inputs: if self.is_encoder_decoder and request.has_encoder_inputs:
@@ -441,8 +431,7 @@ class SchedulerDynamicBatch(Scheduler):
# always padded to the maximum length. If we support other # always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we # encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed. # want to only allocate what is needed.
num_encoder_tokens =\ num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
self.scheduler_config.max_num_encoder_input_tokens
else: else:
num_encoder_tokens = 0 num_encoder_tokens = 0
@@ -484,20 +473,17 @@ class SchedulerDynamicBatch(Scheduler):
req_index += 1 req_index += 1
self.running.append(request) self.running.append(request)
if self.log_stats: if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED, request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
scheduled_timestamp)
if request.status == RequestStatus.WAITING: if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request) scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED: elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request) scheduled_resumed_reqs.append(request)
else: else:
raise RuntimeError( raise RuntimeError(f"Invalid request status: {request.status}")
f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id) scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_blocks[request.request_id] = ( req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
self.kv_cache_manager.get_blocks(request.request_id))
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
@@ -507,8 +493,7 @@ class SchedulerDynamicBatch(Scheduler):
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
# Encoder-related. # Encoder-related.
if encoder_inputs_to_schedule: if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = ( scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
encoder_inputs_to_schedule)
# Allocate the encoder cache. # Allocate the encoder cache.
for i in encoder_inputs_to_schedule: for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
@@ -526,22 +511,17 @@ class SchedulerDynamicBatch(Scheduler):
# Since some requests in the RUNNING queue may not be scheduled in # Since some requests in the RUNNING queue may not be scheduled in
# this step, the total number of scheduled requests can be smaller than # this step, the total number of scheduled requests can be smaller than
# len(self.running). # len(self.running).
assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running)
len(scheduled_running_reqs) <= len(self.running))
# Get the longest common prefix among all requests in the running queue. # Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention. # This can be potentially used for cascade attention.
num_common_prefix_blocks = [0] * len( num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
self.kv_cache_config.kv_cache_groups)
if self.running: if self.running:
any_request = self.running[0] any_request = self.running[0]
num_common_prefix_blocks = ( num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request.request_id))
# Construct the scheduler output. # Construct the scheduler output.
new_reqs_data = [ new_reqs_data = [
NewRequestData.from_request( NewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids())
req, req_to_new_blocks[req.request_id].get_block_ids())
for req in scheduled_new_reqs for req in scheduled_new_reqs
] ]
cached_reqs_data = self._make_cached_request_data( cached_reqs_data = self._make_cached_request_data(
@@ -564,8 +544,7 @@ class SchedulerDynamicBatch(Scheduler):
# It contains the request IDs that are finished in between # It contains the request IDs that are finished in between
# the previous and the current steps. # the previous and the current steps.
finished_req_ids=self.finished_req_ids, finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager. free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
get_freed_mm_hashes(),
) )
# NOTE(Kuntai): this function is designed for multiple purposes: # NOTE(Kuntai): this function is designed for multiple purposes:
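The scheduler hunks above are mostly mechanical reflows, but the bookkeeping they reformat is easy to miss in the diff noise: each admitted request charges its new tokens against a shared per-step budget, records the count in `num_scheduled_tokens`, and flips to RUNNING. A minimal, self-contained sketch of that pattern, using made-up names rather than the real vLLM scheduler API:

```python
# Simplified illustration of the token-budget bookkeeping seen above.
# Request and schedule_step are hypothetical, not vLLM classes.
from dataclasses import dataclass


@dataclass
class Request:
    request_id: str
    num_new_tokens: int
    status: str = "WAITING"


def schedule_step(requests: list[Request], token_budget: int) -> dict[str, int]:
    """Greedily admit requests until the per-step token budget is exhausted."""
    num_scheduled_tokens: dict[str, int] = {}
    for request in requests:
        if request.num_new_tokens > token_budget:
            break  # budget exhausted; remaining requests wait for the next step
        num_scheduled_tokens[request.request_id] = request.num_new_tokens
        token_budget -= request.num_new_tokens
        request.status = "RUNNING"
    return num_scheduled_tokens


if __name__ == "__main__":
    reqs = [Request("a", 3), Request("b", 5), Request("c", 4)]
    print(schedule_step(reqs, token_budget=8))  # {'a': 3, 'b': 5}
```

The real scheduler also interleaves preemption, LoRA tracking, and encoder-cache allocation with this loop, which the sketch omits.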

View File

@@ -14,61 +14,50 @@
# limitations under the License. # limitations under the License.
# This file is a part of the vllm-ascend project. # This file is a part of the vllm-ascend project.
# #
from typing import List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import \ from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
DeviceCommunicatorBase
class NPUCommunicator(DeviceCommunicatorBase): class NPUCommunicator(DeviceCommunicatorBase):
def __init__(
def __init__(self, self,
cpu_group: dist.ProcessGroup, cpu_group: dist.ProcessGroup,
device: Optional[torch.device] = None, device: torch.device | None = None,
device_group: Optional[dist.ProcessGroup] = None, device_group: dist.ProcessGroup | None = None,
unique_name: str = ""): unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name) super().__init__(cpu_group, device, device_group, unique_name)
# TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator # TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
# init device according to rank # init device according to rank
self.device = torch.npu.current_device() self.device = torch.npu.current_device()
def all_to_all(self, def all_to_all(
input_: torch.Tensor, self,
scatter_dim: int = 0, input_: torch.Tensor,
gather_dim: int = -1, scatter_dim: int = 0,
scatter_sizes: Optional[List[int]] = None, gather_dim: int = -1,
gather_sizes: Optional[List[int]] = None) -> torch.Tensor: scatter_sizes: list[int] | None = None,
gather_sizes: list[int] | None = None,
) -> torch.Tensor:
if scatter_dim < 0: if scatter_dim < 0:
scatter_dim += input_.dim() scatter_dim += input_.dim()
if gather_dim < 0: if gather_dim < 0:
gather_dim += input_.dim() gather_dim += input_.dim()
if scatter_sizes is not None and gather_sizes is not None: if scatter_sizes is not None and gather_sizes is not None:
input_list = [ input_list = [t.contiguous() for t in torch.split(input_, scatter_sizes, scatter_dim)]
t.contiguous()
for t in torch.split(input_, scatter_sizes, scatter_dim)
]
output_list = [] output_list = []
tensor_shape_base = input_list[self.rank].size() tensor_shape_base = input_list[self.rank].size()
for i in range(self.world_size): for i in range(self.world_size):
tensor_shape = list(tensor_shape_base) tensor_shape = list(tensor_shape_base)
tensor_shape[gather_dim] = gather_sizes[i] tensor_shape[gather_dim] = gather_sizes[i]
output_list.append( output_list.append(torch.empty(tensor_shape, dtype=input_.dtype, device=input_.device))
torch.empty(tensor_shape,
dtype=input_.dtype,
device=input_.device))
else: else:
input_list = [ input_list = [t.contiguous() for t in torch.tensor_split(input_, self.world_size, scatter_dim)]
t.contiguous() for t in torch.tensor_split( output_list = [torch.empty_like(input_list[i]) for i in range(self.world_size)]
input_, self.world_size, scatter_dim)
]
output_list = [
torch.empty_like(input_list[i]) for i in range(self.world_size)
]
dist.all_to_all(output_list, input_list, group=self.device_group) dist.all_to_all(output_list, input_list, group=self.device_group)
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous() output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
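The uneven branch of `all_to_all` above spends most of its lines on shape bookkeeping before the collective runs. A single-process sketch of just that plumbing (no process group or NPU; `rank`, `scatter_sizes`, and `gather_sizes` are illustrative values, and `dist.all_to_all` itself is not called):

```python
# Single-process sketch of the tensor plumbing in NPUCommunicator.all_to_all:
# how the input is chunked per rank and how receive buffers are shaped.
import torch

world_size = 4
input_ = torch.randn(6, 8)
scatter_dim, gather_dim = 0, 1
scatter_sizes = [1, 2, 2, 1]   # rows sent to each rank (sums to 6)
gather_sizes = [3, 2, 2, 1]    # columns expected back from each rank (sums to 8)

# Uneven branch: explicit per-rank sizes via torch.split.
input_list = [t.contiguous() for t in torch.split(input_, scatter_sizes, scatter_dim)]

# Receive buffers: this rank's row count, but each sender's column count.
rank = 2
tensor_shape_base = list(input_list[rank].size())
output_list = []
for i in range(world_size):
    shape = list(tensor_shape_base)
    shape[gather_dim] = gather_sizes[i]
    output_list.append(torch.empty(shape, dtype=input_.dtype))

print([tuple(t.shape) for t in input_list])   # [(1, 8), (2, 8), (2, 8), (1, 8)]
print([tuple(t.shape) for t in output_list])  # [(2, 3), (2, 2), (2, 2), (2, 1)]

# Even branch: torch.tensor_split handles counts that do not divide evenly.
even_chunks = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
print([tuple(t.shape) for t in even_chunks])  # [(2, 8), (2, 8), (1, 8), (1, 8)]
```

Each receive buffer keeps this rank's size on the scatter dimension and takes sender `i`'s size on the gather dimension, which is why both size lists are required in the uneven case.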

View File

@@ -15,7 +15,6 @@
# limitations under the License. # limitations under the License.
# #
from typing import Optional, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@@ -24,18 +23,23 @@ from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import logger from vllm.logger import logger
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import ( from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum, HCCLLibrary,
hcclRedOpTypeEnum, hcclUniqueId) aclrtStream_t,
buffer_type,
hcclComm_t,
hcclDataTypeEnum,
hcclRedOpTypeEnum,
hcclUniqueId,
)
from vllm_ascend.utils import current_stream from vllm_ascend.utils import current_stream
class PyHcclCommunicator: class PyHcclCommunicator:
def __init__( def __init__(
self, self,
group: Union[ProcessGroup, StatelessProcessGroup], group: ProcessGroup | StatelessProcessGroup,
device: Union[int, str, torch.device], device: int | str | torch.device,
library_path: Optional[str] = None, library_path: str | None = None,
): ):
""" """
Args: Args:
@@ -52,7 +56,8 @@ class PyHcclCommunicator:
if not isinstance(group, StatelessProcessGroup): if not isinstance(group, StatelessProcessGroup):
assert dist.is_initialized() assert dist.is_initialized()
assert dist.get_backend(group) != dist.Backend.HCCL, ( assert dist.get_backend(group) != dist.Backend.HCCL, (
"PyHcclCommunicator should be attached to a non-HCCL group.") "PyHcclCommunicator should be attached to a non-HCCL group."
)
# note: this rank is the rank in the group # note: this rank is the rank in the group
self.rank = dist.get_rank(group) self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group) self.world_size = dist.get_world_size(group)
@@ -113,8 +118,7 @@ class PyHcclCommunicator:
# `torch.npu.device` is a context manager that changes the # `torch.npu.device` is a context manager that changes the
# current npu device to the specified one # current npu device to the specified one
with torch.npu.device(device): with torch.npu.device(device):
self.comm: hcclComm_t = self.hccl.hcclCommInitRank( self.comm: hcclComm_t = self.hccl.hcclCommInitRank(self.world_size, self.unique_id, self.rank)
self.world_size, self.unique_id, self.rank)
stream = current_stream() stream = current_stream()
# A small all_reduce for warmup. # A small all_reduce for warmup.
@@ -123,43 +127,48 @@ class PyHcclCommunicator:
stream.synchronize() stream.synchronize()
del data del data
def all_reduce(self, def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None) -> torch.Tensor:
in_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None) -> torch.Tensor:
if self.disabled: if self.disabled:
return None return None
# hccl communicator created on a specific device # hccl communicator created on a specific device
# will only work on tensors on the same device # will only work on tensors on the same device
# otherwise it will cause "illegal memory access" # otherwise it will cause "illegal memory access"
assert in_tensor.device == self.device, ( assert in_tensor.device == self.device, (
f"this hccl communicator is created to work on {self.device}, " f"this hccl communicator is created to work on {self.device}, but the input tensor is on {in_tensor.device}"
f"but the input tensor is on {in_tensor.device}") )
out_tensor = torch.empty_like(in_tensor) out_tensor = torch.empty_like(in_tensor)
if stream is None: if stream is None:
stream = current_stream() stream = current_stream()
self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()), self.hccl.hcclAllReduce(
buffer_type(out_tensor.data_ptr()), buffer_type(in_tensor.data_ptr()),
in_tensor.numel(), buffer_type(out_tensor.data_ptr()),
hcclDataTypeEnum.from_torch(in_tensor.dtype), in_tensor.numel(),
hcclRedOpTypeEnum.from_torch(op), self.comm, hcclDataTypeEnum.from_torch(in_tensor.dtype),
aclrtStream_t(stream.npu_stream)) hcclRedOpTypeEnum.from_torch(op),
self.comm,
aclrtStream_t(stream.npu_stream),
)
return out_tensor return out_tensor
def broadcast(self, tensor: torch.Tensor, src: int, stream=None): def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled: if self.disabled:
return return
assert tensor.device == self.device, ( assert tensor.device == self.device, (
f"this hccl communicator is created to work on {self.device}, " f"this hccl communicator is created to work on {self.device}, but the input tensor is on {tensor.device}"
f"but the input tensor is on {tensor.device}") )
if stream is None: if stream is None:
stream = current_stream() stream = current_stream()
if src == self.rank: if src == self.rank:
buffer = buffer_type(tensor.data_ptr()) buffer = buffer_type(tensor.data_ptr())
else: else:
buffer = buffer_type(tensor.data_ptr()) buffer = buffer_type(tensor.data_ptr())
self.hccl.hcclBroadcast(buffer, tensor.numel(), self.hccl.hcclBroadcast(
hcclDataTypeEnum.from_torch(tensor.dtype), src, buffer,
self.comm, aclrtStream_t(stream.npu_stream)) tensor.numel(),
hcclDataTypeEnum.from_torch(tensor.dtype),
src,
self.comm,
aclrtStream_t(stream.npu_stream),
)
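Besides line wrapping, the hunks above drop `typing.Optional`, `Union`, `List`, and `Dict` in favour of built-in generics and `|` unions. The two spellings are interchangeable at runtime on Python 3.10+, which a small standalone check can confirm (the functions here are purely illustrative):

```python
# The import-free spellings that ruff's pyupgrade rules substitute in the
# files above; runnable on Python 3.10+, where X | Y builds a typing union.
import typing


def legacy(x: typing.Optional[int], xs: typing.List[int]) -> typing.Union[int, str]:
    return x if x is not None else str(sum(xs))


def modern(x: int | None, xs: list[int]) -> int | str:
    return x if x is not None else str(sum(xs))


assert legacy(None, [1, 2]) == modern(None, [1, 2]) == "3"
assert typing.Optional[int] == (int | None)           # equal types at runtime (PEP 604)
assert typing.get_type_hints(legacy) == typing.get_type_hints(modern)
print("equivalent annotations")
```

Because the unions compare equal, type checkers and `get_type_hints` callers see no behavioural difference; only the imports and the spelling change.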

View File

@@ -18,7 +18,7 @@
import ctypes import ctypes
import platform import platform
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import Any
import torch import torch
from torch.distributed import ReduceOp from torch.distributed import ReduceOp
@@ -107,69 +107,74 @@ class hcclRedOpTypeEnum:
class Function: class Function:
name: str name: str
restype: Any restype: Any
argtypes: List[Any] argtypes: list[Any]
class HCCLLibrary: class HCCLLibrary:
exported_functions = [ exported_functions = [
# const char* HcclGetErrorString(HcclResult code); # const char* HcclGetErrorString(HcclResult code);
Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]), Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),
# HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo); # HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
Function("HcclGetRootInfo", hcclResult_t, Function("HcclGetRootInfo", hcclResult_t, [ctypes.POINTER(hcclUniqueId)]),
[ctypes.POINTER(hcclUniqueId)]),
# HcclResult HcclCommInitRootInfo( # HcclResult HcclCommInitRootInfo(
# uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm); # uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
# note that HcclComm is a pointer type, so the last argument is a pointer to a pointer # note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
Function("HcclCommInitRootInfo", hcclResult_t, [ Function(
ctypes.c_int, "HcclCommInitRootInfo",
ctypes.POINTER(hcclUniqueId), hcclResult_t,
ctypes.c_int, [
ctypes.POINTER(hcclComm_t), ctypes.c_int,
]), ctypes.POINTER(hcclUniqueId),
ctypes.c_int,
ctypes.POINTER(hcclComm_t),
],
),
# HcclResult HcclAllReduce( # HcclResult HcclAllReduce(
# void *sendBuf, void *recvBuf, uint64_t count, # void *sendBuf, void *recvBuf, uint64_t count,
# HcclDataType dataType, HcclReduceOp op, HcclComm comm, # HcclDataType dataType, HcclReduceOp op, HcclComm comm,
# aclrtStream stream); # aclrtStream stream);
Function("HcclAllReduce", hcclResult_t, [ Function(
buffer_type, "HcclAllReduce",
buffer_type, hcclResult_t,
ctypes.c_size_t, [
hcclDataType_t, buffer_type,
hcclRedOp_t, buffer_type,
hcclComm_t, ctypes.c_size_t,
aclrtStream_t, hcclDataType_t,
]), hcclRedOp_t,
hcclComm_t,
aclrtStream_t,
],
),
# HcclResult HcclBroadcast( # HcclResult HcclBroadcast(
# void *buf, uint64_t count, # void *buf, uint64_t count,
# HcclDataType dataType, uint32_t root, # HcclDataType dataType, uint32_t root,
# HcclComm comm, aclrtStream stream); # HcclComm comm, aclrtStream stream);
Function("HcclBroadcast", hcclResult_t, [ Function(
buffer_type, "HcclBroadcast",
ctypes.c_size_t, hcclResult_t,
hcclDataType_t, [
ctypes.c_int, buffer_type,
hcclComm_t, ctypes.c_size_t,
aclrtStream_t, hcclDataType_t,
]), ctypes.c_int,
hcclComm_t,
aclrtStream_t,
],
),
# HcclResult HcclCommDestroy(HcclComm comm); # HcclResult HcclCommDestroy(HcclComm comm);
Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]), Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
] ]
# class attribute to store the mapping from the path to the library # class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times # to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {} path_to_library_cache: dict[str, Any] = {}
# class attribute to store the mapping from library path # class attribute to store the mapping from library path
# to the corresponding directory # to the corresponding directory
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
def __init__(self, so_file: str | None = None):
so_file = so_file or find_hccl_library() so_file = so_file or find_hccl_library()
try: try:
@@ -185,12 +190,14 @@ class HCCLLibrary:
"or it does not support the current platform %s. " "or it does not support the current platform %s. "
"If you already have the library, please set the " "If you already have the library, please set the "
"environment variable HCCL_SO_PATH" "environment variable HCCL_SO_PATH"
" to point to the correct hccl library path.", so_file, " to point to the correct hccl library path.",
platform.platform()) so_file,
platform.platform(),
)
raise e raise e
if so_file not in HCCLLibrary.path_to_dict_mapping: if so_file not in HCCLLibrary.path_to_dict_mapping:
_funcs: Dict[str, Any] = {} _funcs: dict[str, Any] = {}
for func in HCCLLibrary.exported_functions: for func in HCCLLibrary.exported_functions:
f = getattr(self.lib, func.name) f = getattr(self.lib, func.name)
f.restype = func.restype f.restype = func.restype
@@ -209,34 +216,37 @@ class HCCLLibrary:
def hcclGetUniqueId(self) -> hcclUniqueId: def hcclGetUniqueId(self) -> hcclUniqueId:
unique_id = hcclUniqueId() unique_id = hcclUniqueId()
self.HCCL_CHECK(self._funcs["HcclGetRootInfo"]( self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](ctypes.byref(unique_id)))
ctypes.byref(unique_id)))
return unique_id return unique_id
def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId, def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId, rank: int) -> hcclComm_t:
rank: int) -> hcclComm_t:
comm = hcclComm_t() comm = hcclComm_t()
self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"]( self.HCCL_CHECK(
world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm))) self._funcs["HcclCommInitRootInfo"](world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm))
)
return comm return comm
def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, def hcclAllReduce(
count: int, datatype: int, op: int, comm: hcclComm_t, self,
stream: aclrtStream_t) -> None: sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
op: int,
comm: hcclComm_t,
stream: aclrtStream_t,
) -> None:
# `datatype` actually should be `hcclDataType_t` # `datatype` actually should be `hcclDataType_t`
# and `op` should be `hcclRedOp_t` # and `op` should be `hcclRedOp_t`
# both are aliases of `ctypes.c_int` # both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically # by ctypes automatically
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count, self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count, datatype, op, comm, stream))
datatype, op, comm,
stream))
def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int, def hcclBroadcast(
root: int, comm: hcclComm_t, self, buf: buffer_type, count: int, datatype: int, root: int, comm: hcclComm_t, stream: aclrtStream_t
stream: aclrtStream_t) -> None: ) -> None:
self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype, self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype, root, comm, stream))
root, comm, stream))
def hcclCommDestroy(self, comm: hcclComm_t) -> None: def hcclCommDestroy(self, comm: hcclComm_t) -> None:
self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm)) self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))
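`HCCLLibrary` above follows a common ctypes pattern: a declarative table of `Function(name, restype, argtypes)` entries that is walked once at load time to configure each foreign symbol. A minimal sketch of the same pattern against the C math library (the `MathLibrary` wrapper is hypothetical, not part of vllm-ascend; it assumes Python 3.10+ and a libm that ctypes can locate, as on typical Linux or macOS systems):

```python
# Minimal sketch of the Function-table pattern used by HCCLLibrary,
# applied to libm so it runs without the HCCL shared library.
import ctypes
import ctypes.util
from dataclasses import dataclass
from typing import Any


@dataclass
class Function:
    name: str
    restype: Any
    argtypes: list[Any]


class MathLibrary:
    exported_functions = [
        # double cbrt(double x);
        Function("cbrt", ctypes.c_double, [ctypes.c_double]),
        # double hypot(double x, double y);
        Function("hypot", ctypes.c_double, [ctypes.c_double, ctypes.c_double]),
    ]

    def __init__(self, so_file: str | None = None):
        so_file = so_file or ctypes.util.find_library("m") or "libm.so.6"
        self.lib = ctypes.CDLL(so_file)
        self._funcs: dict[str, Any] = {}
        for func in self.exported_functions:
            f = getattr(self.lib, func.name)
            f.restype = func.restype      # declare the return type ...
            f.argtypes = func.argtypes    # ... and argument types once, up front
            self._funcs[func.name] = f

    def cbrt(self, x: float) -> float:
        return self._funcs["cbrt"](x)

    def hypot(self, x: float, y: float) -> float:
        return self._funcs["hypot"](x, y)


lib = MathLibrary()
print(lib.cbrt(27.0), lib.hypot(3.0, 4.0))  # 3.0 5.0
```

Declaring `restype` and `argtypes` up front is what lets later calls pass plain Python numbers and still get correctly typed results back, just as the HCCL wrapper does with its handle, buffer, and stream types.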