[Lint]Style: Convert vllm-ascend/ to ruff format(Batch #10) (#6173)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
|`vllm_ascend/ops/layer_shard_linear.py`|
|`vllm_ascend/ops/linear.py`|
|`vllm_ascend/ops/linear_op.py`|
|`vllm_ascend/worker/worker.py`|
| ` vllm_ascend/patch/worker/patch_bert.py` |
| ` vllm_ascend/patch/worker/patch_deepseek.py` |
| ` vllm_ascend/patch/worker/patch_distributed.py` |
| ` vllm_ascend/patch/worker/patch_module.py` |
| ` vllm_ascend/patch/worker/patch_multimodal_merge.py` |
| ` vllm_ascend/patch/worker/patch_qwen3_next.py` |
| ` vllm_ascend/patch/worker/patch_qwen3_next_mtp.py` |
| ` vllm_ascend/patch/worker/patch_rejection_sampler.py` |
| ` vllm_ascend/patch/worker/patch_rope.py` |
| ` vllm_ascend/patch/worker/patch_triton.py` |
| ` vllm_ascend/patch/worker/patch_unquantized_gemm.py` |
| ` vllm_ascend/patch/worker/patch_v2_egale.py` |
|` vllm_ascend/worker/npu_input_batch.py`|
|` vllm_ascend/worker/v2/aclgraph_utils.py`|
|` vllm_ascend/worker/v2/attn_utils.py`|
|` vllm_ascend/worker/v2/model_runner.py`|
|` vllm_ascend/worker/v2/sample/gumbel.py`|
|` vllm_ascend/worker/v2/sample/penalties.py`|
|` vllm_ascend/worker/v2/sample/sampler.py`|
|` vllm_ascend/worker/v2/spec_decode/__init__.py`|
|` vllm_ascend/worker/v2/spec_decode/eagle.py`|
|` vllm_ascend/worker/v2/states.py`|
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
d68209402d

Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
SILONG ZENG
2026-02-06 15:35:06 +08:00
committed by GitHub
parent 65b7f716e6
commit 19b5d44ea8
33 changed files with 938 additions and 1243 deletions

View File

@@ -51,17 +51,6 @@ line-length = 120
# Folder to be modified # Folder to be modified
exclude = [ exclude = [
"tests/**", "tests/**",
# (10)
"vllm_ascend/ops/*linear*.py",
"vllm_ascend/worker/worker.py",
"vllm_ascend/distributed/parallel_state.py",
"vllm_ascend/distributed/utils.py",
"vllm_ascend/xlite/*.py",
"vllm_ascend/patch/worker/patch_*.py",
"vllm_ascend/worker/v2/**",
"vllm_ascend/worker/npu_input_batch.py",
"vllm_ascend/ops/rotary_embedding.py",
] ]
[tool.ruff.lint] [tool.ruff.lint]

View File

@@ -1,35 +1,33 @@
from typing import Optional
import torch import torch
from vllm.config import ParallelConfig, get_current_vllm_config from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (GroupCoordinator, get_tp_group, from vllm.distributed.parallel_state import GroupCoordinator, get_tp_group, get_world_group, init_model_parallel_group
get_world_group,
init_model_parallel_group)
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import enable_dsa_cp_with_layer_shard, flashcomm2_enable from vllm_ascend.utils import enable_dsa_cp_with_layer_shard, flashcomm2_enable
# Currently, mc2 op need their own group coordinator. # Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None _MC2: GroupCoordinator | None = None
# Module specific tensor parallel groups # Module specific tensor parallel groups
_MLP_TP: Optional[GroupCoordinator] = None _MLP_TP: GroupCoordinator | None = None
_OTP: Optional[GroupCoordinator] = None _OTP: GroupCoordinator | None = None
_LMTP: Optional[GroupCoordinator] = None _LMTP: GroupCoordinator | None = None
_EMBED_TP: Optional[GroupCoordinator] = None _EMBED_TP: GroupCoordinator | None = None
# flashcomm specific groups # flashcomm specific groups
_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None _FLASHCOMM2_OTP: GroupCoordinator | None = None
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None _FLASHCOMM2_ODP: GroupCoordinator | None = None
_FC3_QUANT_X: Optional[GroupCoordinator] = None _FC3_QUANT_X: GroupCoordinator | None = None
# shard_weight across rank groups # shard_weight across rank groups
_SHARD_WEIGHT: Optional[GroupCoordinator] = None _SHARD_WEIGHT: GroupCoordinator | None = None
_P_TP: Optional[GroupCoordinator] = None _P_TP: GroupCoordinator | None = None
def init_ascend_model_parallel(parallel_config: ParallelConfig, ): def init_ascend_model_parallel(
parallel_config: ParallelConfig,
):
if model_parallel_initialized(): if model_parallel_initialized():
return return
assert torch.distributed.is_initialized() assert torch.distributed.is_initialized()
@@ -43,8 +41,8 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
# ExternalDP is the data parallel group that is not part of the model, # ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration). # every dp rank can generate independently (in verl integration).
all_ranks = torch.arange(world_size).reshape( all_ranks = torch.arange(world_size).reshape(
-1, global_dp_size * parallel_config.prefill_context_parallel_size * -1, global_dp_size * parallel_config.prefill_context_parallel_size * global_tp_size
global_tp_size) )
# TODO: all_ranks should be the same as vllm_all_ranks, all_ranks needs to be removed in the future. # TODO: all_ranks should be the same as vllm_all_ranks, all_ranks needs to be removed in the future.
vllm_all_ranks = torch.arange(world_size).reshape( vllm_all_ranks = torch.arange(world_size).reshape(
-1, -1,
@@ -57,49 +55,35 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
pd_tp_ratio = get_ascend_config().pd_tp_ratio pd_tp_ratio = get_ascend_config().pd_tp_ratio
pd_head_ratio = get_ascend_config().pd_head_ratio pd_head_ratio = get_ascend_config().pd_head_ratio
global _P_TP global _P_TP
assert _P_TP is None, ( assert _P_TP is None, "distributed prefill tensor parallel group is already initialized"
"distributed prefill tensor parallel group is already initialized")
prefill_tensor_model_parallel_size = pd_tp_ratio prefill_tensor_model_parallel_size = pd_tp_ratio
# divide alltoall groups # divide alltoall groups
if pd_head_ratio > 1 and get_current_vllm_config( if pd_head_ratio > 1 and get_current_vllm_config().kv_transfer_config.is_kv_producer:
).kv_transfer_config.is_kv_producer:
num_head_replica = get_ascend_config().num_head_replica num_head_replica = get_ascend_config().num_head_replica
remote_tp_size = global_tp_size // pd_tp_ratio remote_tp_size = global_tp_size // pd_tp_ratio
if num_head_replica <= 1: if num_head_replica <= 1:
group_ranks = all_ranks.view( group_ranks = all_ranks.view(-1, prefill_tensor_model_parallel_size).unbind(0)
-1, prefill_tensor_model_parallel_size).unbind(0)
else: else:
group_ranks = all_ranks.clone().view( group_ranks = all_ranks.clone().view(
global_dp_size, -1, global_dp_size, -1, num_head_replica
num_head_replica) # [DP_size, num_head, num_head_replica] ) # [DP_size, num_head, num_head_replica]
group_ranks = group_ranks.permute(0, 2, 1) group_ranks = group_ranks.permute(0, 2, 1)
group_ranks = group_ranks.reshape( group_ranks = group_ranks.reshape(-1, group_ranks.size(-1)) # [DP_size * num_head_replica, num_head]
-1,
group_ranks.size(-1)) # [DP_size * num_head_replica, num_head]
alltoall_group_size = group_ranks.size(-1) // remote_tp_size alltoall_group_size = group_ranks.size(-1) // remote_tp_size
group_ranks = group_ranks.unsqueeze(-1).view( group_ranks = group_ranks.unsqueeze(-1).view(
global_dp_size, num_head_replica, -1, alltoall_group_size global_dp_size, num_head_replica, -1, alltoall_group_size
) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size] ) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
group_ranks = group_ranks.reshape(-1, group_ranks = group_ranks.reshape(-1, alltoall_group_size).unbind(0)
alltoall_group_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
local_rank = get_world_group().local_rank local_rank = get_world_group().local_rank
num = next( num = next((i for i, ranks in enumerate(group_ranks) if local_rank in ranks), None)
(i for i, ranks in enumerate(group_ranks) if local_rank in ranks), _P_TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name=f"p_tp_{num}")
None)
_P_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name=f"p_tp_{num}")
global _MC2 global _MC2
group_ranks = all_ranks.unbind(0) group_ranks = all_ranks.unbind(0)
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
_MC2 = init_model_parallel_group(group_ranks, _MC2 = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name="mc2")
get_world_group().local_rank,
backend,
group_name="mc2")
# Initialize fine-grained TP process groups on Ascend for four components: # Initialize fine-grained TP process groups on Ascend for four components:
# 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`) # 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`)
@@ -108,39 +92,28 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
# 4. MLP: feed-forward network in transformer blocks (`mlp_tensor_parallel_size`) # 4. MLP: feed-forward network in transformer blocks (`mlp_tensor_parallel_size`)
_group_cache = {} _group_cache = {}
def _create_or_get_group(group_size: int, def _create_or_get_group(group_size: int, group_name: str) -> GroupCoordinator:
group_name: str) -> GroupCoordinator:
if group_size is None: if group_size is None:
return None return None
if group_size not in _group_cache: if group_size not in _group_cache:
rank_grid = torch.arange(world_size).reshape(global_pp_size, global_dp_size, global_tp_size)
rank_grid = torch.arange(world_size).reshape(
global_pp_size, global_dp_size, global_tp_size)
num_chunks = global_dp_size // group_size num_chunks = global_dp_size // group_size
group_ranks = [] group_ranks = []
for pp_idx in range(global_pp_size): for pp_idx in range(global_pp_size):
stage_ranks = rank_grid[pp_idx] # (dp, tp) stage_ranks = rank_grid[pp_idx] # (dp, tp)
for chunk in range(num_chunks): for chunk in range(num_chunks):
for tp_idx in range(global_tp_size): for tp_idx in range(global_tp_size):
group = stage_ranks[chunk * group_size:(chunk + 1) * group = stage_ranks[chunk * group_size : (chunk + 1) * group_size, tp_idx].tolist()
group_size, tp_idx].tolist()
group_ranks.append(group) group_ranks.append(group)
pg = init_model_parallel_group(group_ranks, pg = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name=group_name)
get_world_group().local_rank,
backend,
group_name=group_name)
_group_cache[group_size] = pg _group_cache[group_size] = pg
return _group_cache[group_size] return _group_cache[group_size]
otp_size = get_ascend_config( otp_size = get_ascend_config().finegrained_tp_config.oproj_tensor_parallel_size
).finegrained_tp_config.oproj_tensor_parallel_size lmhead_tp_size = get_ascend_config().finegrained_tp_config.lmhead_tensor_parallel_size
lmhead_tp_size = get_ascend_config( embedding_tp_size = get_ascend_config().finegrained_tp_config.embedding_tensor_parallel_size
).finegrained_tp_config.lmhead_tensor_parallel_size mlp_tp_size = get_ascend_config().finegrained_tp_config.mlp_tensor_parallel_size
embedding_tp_size = get_ascend_config(
).finegrained_tp_config.embedding_tensor_parallel_size
mlp_tp_size = get_ascend_config(
).finegrained_tp_config.mlp_tensor_parallel_size
global _OTP, _LMTP, _EMBED_TP, _MLP_TP global _OTP, _LMTP, _EMBED_TP, _MLP_TP
@@ -156,10 +129,8 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
# TODO: Extract and unify the logic across different communication group. # TODO: Extract and unify the logic across different communication group.
flashcomm2_otp_group_ranks = [] flashcomm2_otp_group_ranks = []
if flashcomm2_enable(): if flashcomm2_enable():
flashcomm2_otp_size = get_ascend_config( flashcomm2_otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
).flashcomm2_oproj_tensor_parallel_size num_fc2_oproj_tensor_parallel_groups: int = global_tp_size // flashcomm2_otp_size
num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
flashcomm2_otp_size)
global _FLASHCOMM2_OTP global _FLASHCOMM2_OTP
global _FLASHCOMM2_ODP global _FLASHCOMM2_ODP
@@ -168,8 +139,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
if flashcomm2_otp_size > 1: if flashcomm2_otp_size > 1:
odp_group_ranks: list[list[int]] = [ odp_group_ranks: list[list[int]] = [
[] for _ in range(flashcomm2_otp_size * global_dp_size * [] for _ in range(flashcomm2_otp_size * global_dp_size * global_pp_size)
global_pp_size)
] ]
for dp_group_index in range(global_dp_size): for dp_group_index in range(global_dp_size):
for pp_group_index in range(global_pp_size): for pp_group_index in range(global_pp_size):
@@ -186,31 +156,24 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
ranks.append(global_rank) ranks.append(global_rank)
odp_group_index = odp_base_index + j odp_group_index = odp_base_index + j
odp_group_ranks[odp_group_index].append( odp_group_ranks[odp_group_index].append(global_rank)
global_rank)
flashcomm2_otp_group_ranks.append(ranks) flashcomm2_otp_group_ranks.append(ranks)
_FLASHCOMM2_OTP = init_model_parallel_group( _FLASHCOMM2_OTP = init_model_parallel_group(
flashcomm2_otp_group_ranks, flashcomm2_otp_group_ranks, get_world_group().local_rank, backend, group_name="flashcomm2_otp"
get_world_group().local_rank, )
backend,
group_name="flashcomm2_otp")
_FLASHCOMM2_ODP = init_model_parallel_group( _FLASHCOMM2_ODP = init_model_parallel_group(
odp_group_ranks, odp_group_ranks, get_world_group().local_rank, backend, group_name="flashcomm2_odp"
get_world_group().local_rank, )
backend,
group_name="flashcomm2_odp")
def create_shard_weight_group( def create_shard_weight_group(module_tp_group_ranks: None) -> GroupCoordinator:
module_tp_group_ranks: None) -> GroupCoordinator:
# Argument module_tp_group_ranks: The module specific tensor parallel group. # Argument module_tp_group_ranks: The module specific tensor parallel group.
# There are three situations. # There are three situations.
# 1. If it is None, then the TP_size of the specific module is 1 and is replicated linear layer. # 1. If it is None, then the TP_size of the specific module is 1 and is replicated linear layer.
# 2. If it is not None, and the module tp_group is same as the global tp_group. # 2. If it is not None, and the module tp_group is same as the global tp_group.
# 3. If it is not None, and the module tp_group is different from the global tp_group.(eg. flashcomm2_otp) # 3. If it is not None, and the module tp_group is different from the global tp_group.(eg. flashcomm2_otp)
group_ranks = [] group_ranks = []
pp_group_ranks = vllm_all_ranks.transpose(2, 4).reshape( pp_group_ranks = vllm_all_ranks.transpose(2, 4).reshape(-1, global_pp_size)
-1, global_pp_size)
if module_tp_group_ranks is None: if module_tp_group_ranks is None:
# If it is None, then the TP_size of this shard weight is 1. # If it is None, then the TP_size of this shard weight is 1.
shard_weight_group_ranks = pp_group_ranks.transpose(0, 1).unbind(0) shard_weight_group_ranks = pp_group_ranks.transpose(0, 1).unbind(0)
@@ -219,14 +182,9 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
# combine standard tp group and non-standard tp group to build shard_weight comm_group # combine standard tp group and non-standard tp group to build shard_weight comm_group
module_tp_tanspose_ranks = module_tp_group_ranks.transpose(0, 1) module_tp_tanspose_ranks = module_tp_group_ranks.transpose(0, 1)
G = world_size // (global_pp_size * module_tp_group_ranks.size(1)) G = world_size // (global_pp_size * module_tp_group_ranks.size(1))
shard_weight_group_ranks = torch.stack( shard_weight_group_ranks = torch.stack([t.view(global_pp_size, G) for t in module_tp_tanspose_ranks], dim=1)
[t.view(global_pp_size, G) for t in module_tp_tanspose_ranks],
dim=1)
group_ranks = shard_weight_group_ranks.view(-1, G).tolist() group_ranks = shard_weight_group_ranks.view(-1, G).tolist()
return init_model_parallel_group(group_ranks, return init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name="shard_weight")
get_world_group().local_rank,
backend,
group_name="shard_weight")
# Create shard weight group if enabled # Create shard weight group if enabled
if get_ascend_config().layer_sharding is not None: if get_ascend_config().layer_sharding is not None:
@@ -235,8 +193,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
if len(flashcomm2_otp_group_ranks) == 0: if len(flashcomm2_otp_group_ranks) == 0:
FC2_group_ranks = None FC2_group_ranks = None
else: else:
FC2_group_ranks = torch.tensor( FC2_group_ranks = torch.tensor(flashcomm2_otp_group_ranks).squeeze(0)
flashcomm2_otp_group_ranks).squeeze(0)
_SHARD_WEIGHT = create_shard_weight_group(FC2_group_ranks) _SHARD_WEIGHT = create_shard_weight_group(FC2_group_ranks)
elif enable_dsa_cp_with_layer_shard(): elif enable_dsa_cp_with_layer_shard():
# For dsa_cp, all shard layers are replicated. # For dsa_cp, all shard layers are replicated.
@@ -250,40 +207,37 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
global _FC3_QUANT_X global _FC3_QUANT_X
group_ranks = all_ranks.unbind(0) group_ranks = all_ranks.unbind(0)
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
_FC3_QUANT_X = init_model_parallel_group(group_ranks, _FC3_QUANT_X = init_model_parallel_group(
get_world_group().local_rank, group_ranks, get_world_group().local_rank, backend, group_name="fc3_quant_x"
backend, )
group_name="fc3_quant_x")
def model_parallel_initialized(): def model_parallel_initialized():
return (_MC2 is not None) return _MC2 is not None
def get_mc2_group() -> GroupCoordinator: def get_mc2_group() -> GroupCoordinator:
assert _MC2 is not None, ("mc2 group is not initialized") assert _MC2 is not None, "mc2 group is not initialized"
return _MC2 return _MC2
def get_mlp_tp_group() -> GroupCoordinator: def get_mlp_tp_group() -> GroupCoordinator:
assert _MLP_TP is not None, ("mlp group is not initialized") assert _MLP_TP is not None, "mlp group is not initialized"
return _MLP_TP return _MLP_TP
def get_otp_group() -> GroupCoordinator: def get_otp_group() -> GroupCoordinator:
assert _OTP is not None, ( assert _OTP is not None, "output tensor parallel group is not initialized"
"output tensor parallel group is not initialized")
return _OTP return _OTP
def get_lmhead_tp_group() -> GroupCoordinator: def get_lmhead_tp_group() -> GroupCoordinator:
assert _LMTP is not None, ( assert _LMTP is not None, "lm head tensor parallel group is not initialized"
"lm head tensor parallel group is not initialized")
return _LMTP return _LMTP
def get_embed_tp_group() -> GroupCoordinator: def get_embed_tp_group() -> GroupCoordinator:
assert _EMBED_TP is not None, ("emtp group is not initialized") assert _EMBED_TP is not None, "emtp group is not initialized"
return _EMBED_TP return _EMBED_TP
@@ -292,25 +246,22 @@ def get_flashcomm2_otp_group() -> GroupCoordinator:
def get_flashcomm2_odp_group() -> GroupCoordinator: def get_flashcomm2_odp_group() -> GroupCoordinator:
assert _FLASHCOMM2_ODP is not None, ( assert _FLASHCOMM2_ODP is not None, "output data parallel group for flashcomm2 is not initialized"
"output data parallel group for flashcomm2 is not initialized")
return _FLASHCOMM2_ODP return _FLASHCOMM2_ODP
def get_shard_weight_group() -> GroupCoordinator: def get_shard_weight_group() -> GroupCoordinator:
assert _SHARD_WEIGHT is not None, ( assert _SHARD_WEIGHT is not None, "output shard weight parallel group for flashcomm2 is not initialized"
"output shard weight parallel group for flashcomm2 is not initialized")
return _SHARD_WEIGHT return _SHARD_WEIGHT
def get_p_tp_group() -> GroupCoordinator: def get_p_tp_group() -> GroupCoordinator:
assert _P_TP is not None, ( assert _P_TP is not None, "distributed prefill tensor parallel group is not initialized"
"distributed prefill tensor parallel group is not initialized")
return _P_TP return _P_TP
def get_fc3_quant_x_group() -> GroupCoordinator: def get_fc3_quant_x_group() -> GroupCoordinator:
assert _FC3_QUANT_X is not None, ("fc3 quant x group is not initialized") assert _FC3_QUANT_X is not None, "fc3 quant x group is not initialized"
return _FC3_QUANT_X return _FC3_QUANT_X
@@ -346,14 +297,12 @@ def destroy_ascend_model_parallel():
_P_TP = None _P_TP = None
global _FLASHCOMM2_OTP global _FLASHCOMM2_OTP
if _FLASHCOMM2_OTP and get_ascend_config( if _FLASHCOMM2_OTP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
).flashcomm2_oproj_tensor_parallel_size != 1:
_FLASHCOMM2_OTP.destroy() _FLASHCOMM2_OTP.destroy()
_FLASHCOMM2_OTP = None _FLASHCOMM2_OTP = None
global _FLASHCOMM2_ODP global _FLASHCOMM2_ODP
if _FLASHCOMM2_ODP and get_ascend_config( if _FLASHCOMM2_ODP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
).flashcomm2_oproj_tensor_parallel_size != 1:
_FLASHCOMM2_ODP.destroy() _FLASHCOMM2_ODP.destroy()
_FLASHCOMM2_ODP = None _FLASHCOMM2_ODP = None

View File

@@ -1,5 +1,3 @@
from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group
@@ -8,7 +6,9 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.distributed.parallel_state import get_fc3_quant_x_group from vllm_ascend.distributed.parallel_state import get_fc3_quant_x_group
def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor: def fc3_all_gather_and_maybe_unpad_impl(
x: torch.Tensor,
) -> torch.Tensor:
try: try:
forward_context = get_forward_context() forward_context = get_forward_context()
except AssertionError: except AssertionError:
@@ -22,9 +22,7 @@ def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
else: else:
# unpad # unpad
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype)
device=x.device,
dtype=x.dtype)
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
x = x.view(dp_size, forward_context.padded_length, *x.shape[1:]) x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
offset = 0 offset = 0
@@ -37,19 +35,13 @@ def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
return x return x
def all_gather_async(input: torch.Tensor, def all_gather_async(
group: GroupCoordinator, input: torch.Tensor, group: GroupCoordinator, output: torch.Tensor | None = None, async_op: bool = True
output: Optional[torch.Tensor] = None, ):
async_op: bool = True):
if group.world_size == 1: if group.world_size == 1:
return input, None return input, None
if output is None: if output is None:
input_size = input.size() input_size = input.size()
output_size = (input_size[0] * group.world_size,) + input_size[1:] output_size = (input_size[0] * group.world_size,) + input_size[1:]
output = torch.empty(output_size, output = torch.empty(output_size, dtype=input.dtype, device=input.device)
dtype=input.dtype, return output, dist.all_gather_into_tensor(output, input, group=group.device_group, async_op=async_op)
device=input.device)
return output, dist.all_gather_into_tensor(output,
input,
group=group.device_group,
async_op=async_op)

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from typing import Callable, List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@@ -17,39 +17,38 @@ def dispose_tensor(x: torch.Tensor):
@dataclass @dataclass
class LayerMetadata: class LayerMetadata:
"""Metadata for a layer. """Metadata for a layer."""
"""
layer_idx: int # The index of the layer. layer_idx: int # The index of the layer.
layer: LinearBase # The layer object. layer: LinearBase # The layer object.
post_method: Callable[[ post_method: Callable[[torch.nn.Module], None] # The `process_weights_after_loading` method from the quant method.
torch.nn.Module
], None] # The `process_weights_after_loading` method from the quant method.
weight: torch.Tensor # The weight tensor. weight: torch.Tensor # The weight tensor.
window_idx: int # The index of the window. window_idx: int # The index of the window.
@dataclass @dataclass
class ShardWindowMetadata: class ShardWindowMetadata:
"""Metadata for a shard window. """Metadata for a shard window."""
"""
weight: torch.Tensor # The weight tensor to be shard by layers. weight: torch.Tensor # The weight tensor to be shard by layers.
data_layer_idx: int # The index of the layer this window's weight is equal to. data_layer_idx: int # The index of the layer this window's weight is equal to.
work: Optional[torch.distributed.Work] # The asynchronous broadcast work. work: torch.distributed.Work | None # The asynchronous broadcast work.
@dataclass @dataclass
class SeriesMetadata: class SeriesMetadata:
"""Metadata for a weight shard series. """Metadata for a weight shard series."""
"""
group: GroupCoordinator group: GroupCoordinator
start_layer: int start_layer: int
end_layer: int end_layer: int
num_layers: int num_layers: int
prefetch_step: int prefetch_step: int
dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor. dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix.
# All the layers in the series share the same dummy weight tensor.
layers: list[LayerMetadata] layers: list[LayerMetadata]
shard_windows: list[ shard_windows: list[ShardWindowMetadata] # Shard windows for prefetching. The window size is (`prefetch_step` + 1),
ShardWindowMetadata] # Shard windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored. # as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
window_offset: int # The index of the window for the next coming layer. window_offset: int # The index of the window for the next coming layer.
def is_source(self, layer_idx) -> bool: def is_source(self, layer_idx) -> bool:
@@ -63,9 +62,9 @@ class SeriesMetadata:
self.layers.sort(key=lambda x: x.layer_idx) self.layers.sort(key=lambda x: x.layer_idx)
self.num_layers = len(self.layers) self.num_layers = len(self.layers)
assert self.num_layers > 0, "No layers in the series" assert self.num_layers > 0, "No layers in the series"
assert self.prefetch_step >= 0 and self.prefetch_step <= max( assert self.prefetch_step >= 0 and self.prefetch_step <= max(0, self.num_layers - 2), (
0, self.num_layers - "prefetch_step must be in [0, num_layers - 2]"
2), "prefetch_step must be in [0, num_layers - 2]" )
self.start_layer = self.layers[0].layer_idx self.start_layer = self.layers[0].layer_idx
self.end_layer = self.layers[-1].layer_idx + 1 self.end_layer = self.layers[-1].layer_idx + 1
@@ -73,25 +72,27 @@ class SeriesMetadata:
layer = self.layers[layer_idx - self.start_layer] layer = self.layers[layer_idx - self.start_layer]
assert layer.layer_idx == layer_idx, "layer_idx must be consecutive" assert layer.layer_idx == layer_idx, "layer_idx must be consecutive"
is_source = self.is_source(layer_idx) is_source = self.is_source(layer_idx)
# If the weight uses dummy weight, make a copy temporary such that the post method call won't affect other layers which also uses dummy weight. # If the weight uses dummy weight, make a copy temporary such that the post method call
# won't affect other layers which also uses dummy weight.
if not is_source: if not is_source:
layer.weight.set_(torch.empty_like(self.dummy_weight)) layer.weight.set_(torch.empty_like(self.dummy_weight))
# Broadcast to get the true weight. # Broadcast to get the true weight.
dist.broadcast(layer.weight, dist.broadcast(
src=self.group.ranks[layer_idx % layer.weight, src=self.group.ranks[layer_idx % self.group.world_size], group=self.group.device_group
self.group.world_size], )
group=self.group.device_group)
# Call `process_weights_after_loading` from the quant method. # Call `process_weights_after_loading` from the quant method.
layer.post_method(layer.layer) layer.post_method(layer.layer)
step = layer_idx - self.start_layer step = layer_idx - self.start_layer
if step < self.prefetch_step: if step < self.prefetch_step:
# Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights. # Build the windows for the first `prefetch_step` layers. The weights can be used
# for the first `prefetch_step` layers in `forward()`, so also clone the weights.
self.shard_windows.append( self.shard_windows.append(
ShardWindowMetadata( ShardWindowMetadata(
weight=layer.weight.clone().detach(), weight=layer.weight.clone().detach(),
data_layer_idx=layer_idx, data_layer_idx=layer_idx,
work=None, work=None,
)) )
)
layer.window_idx = step layer.window_idx = step
# When the layer not intended to be stored in this device, link to the corresponding window's tensor. # When the layer not intended to be stored in this device, link to the corresponding window's tensor.
if not is_source: if not is_source:
@@ -104,7 +105,8 @@ class SeriesMetadata:
weight=torch.empty_like(layer.weight), weight=torch.empty_like(layer.weight),
data_layer_idx=-1, data_layer_idx=-1,
work=None, work=None,
)) )
)
# When the layer not intended to be stored in this device, dispose the tensor. # When the layer not intended to be stored in this device, dispose the tensor.
if not is_source: if not is_source:
dispose_tensor(layer.weight) dispose_tensor(layer.weight)
@@ -113,8 +115,7 @@ class SeriesMetadata:
def reach_layer(self, layer_idx: int): def reach_layer(self, layer_idx: int):
# The index of the layer to be prefetched. # The index of the layer to be prefetched.
next_layer_idx = (layer_idx + self.prefetch_step next_layer_idx = (layer_idx + self.prefetch_step) % self.num_layers + self.start_layer
) % self.num_layers + self.start_layer
next_layer = self.layers[next_layer_idx - self.start_layer] next_layer = self.layers[next_layer_idx - self.start_layer]
# The index of the window to store the weight for the coming layer. # The index of the window to store the weight for the coming layer.
next_layer.window_idx = self.window_offset next_layer.window_idx = self.window_offset
@@ -123,8 +124,7 @@ class SeriesMetadata:
if not self.is_source(next_layer_idx): if not self.is_source(next_layer_idx):
next_layer.weight.set_(window.weight) next_layer.weight.set_(window.weight)
# Update `window_offset` by rolling one step. # Update `window_offset` by rolling one step.
self.window_offset = (self.window_offset + 1) % (self.prefetch_step + self.window_offset = (self.window_offset + 1) % (self.prefetch_step + 1)
1)
assert window.data_layer_idx != next_layer_idx assert window.data_layer_idx != next_layer_idx
window.data_layer_idx = next_layer_idx window.data_layer_idx = next_layer_idx
# Start asynchronous broadcast work. # Start asynchronous broadcast work.
@@ -132,13 +132,13 @@ class SeriesMetadata:
next_layer.weight, next_layer.weight,
src=self.group.ranks[next_layer_idx % self.group.world_size], src=self.group.ranks[next_layer_idx % self.group.world_size],
group=self.group.device_group, group=self.group.device_group,
async_op=True) async_op=True,
)
def wait_weight(self, layer_idx: int): def wait_weight(self, layer_idx: int):
# Find the asynchronous broadcast work and wait for it. # Find the asynchronous broadcast work and wait for it.
assert self.shard_windows assert self.shard_windows
window = self.shard_windows[self.layers[layer_idx - window = self.shard_windows[self.layers[layer_idx - self.start_layer].window_idx]
self.start_layer].window_idx]
# Make sure the data in the corresponding shard window is for the current layer. # Make sure the data in the corresponding shard window is for the current layer.
assert window.data_layer_idx == layer_idx assert window.data_layer_idx == layer_idx
if window.work is not None: if window.work is not None:
@@ -148,8 +148,8 @@ class SeriesMetadata:
@dataclass @dataclass
class LayerExternalMetadata: class LayerExternalMetadata:
"""External metadata for a layer. """External metadata for a layer."""
"""
series: SeriesMetadata series: SeriesMetadata
layer_idx: int layer_idx: int
@@ -159,9 +159,7 @@ _series_dict: dict[str, SeriesMetadata] = {}
_layer_external_dict: dict[int, LayerExternalMetadata] = {} _layer_external_dict: dict[int, LayerExternalMetadata] = {}
def _create_forward_wrapper(forward: Callable, series: SeriesMetadata, def _create_forward_wrapper(forward: Callable, series: SeriesMetadata, layer_idx: int) -> Callable:
layer_idx: int) -> Callable:
def wrapped_forward(*args, **kwargs): def wrapped_forward(*args, **kwargs):
# Wait for the weight. # Wait for the weight.
series.wait_weight(layer_idx) series.wait_weight(layer_idx)
@@ -173,23 +171,32 @@ def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
""" """
Register linear layers into a shard storage series. Register linear layers into a shard storage series.
In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices. In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series.
All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer
is stored on device (i % n), where n is the number of devices.
After loading the model, you must call `post_process_after_loading_for_shard_weight_series(layer)` on any layer of this series to complete the initialization. After loading the model, you must call `post_process_after_loading_for_shard_weight_series(layer)`
on any layer of this series to complete the initialization.
During execution, each time a new layer is reached, you must call `reach_layer_for_shard_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shard_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series. During execution, each time a new layer is reached, you must call `reach_layer_for_shard_weight_series(layer)`
for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages
asynchronous weight prefetching. Each call to `reach_layer_for_shard_weight_series(current_layer)` method will
trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula: Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
- start_layer is the index of the first layer in the series (inclusive). - start_layer is the index of the first layer in the series (inclusive).
- end_layer is the index of the last layer in the series (exclusive). Thus, the series includes all layers with indices in the range [start_layer, end_layer). - end_layer is the index of the last layer in the series (exclusive). Thus, the series includes all layers with
indices in the range [start_layer, end_layer).
- total_layers = end_layer - start_layer - total_layers = end_layer - start_layer
- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer - prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer
To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shard tensor buffers will be created for this series. To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shard tensor buffers
will be created for this series.
Arguments: Arguments:
series_name: This name identifies which series this layer belongs to. series_name: This name identifies which series this layer belongs to.
group: The group coordinator for handling asynchronous communications. It is recommended to create a new group coordinator for each new series. group: The group coordinator for handling asynchronous communications. It is recommended to create a new group
coordinator for each new series.
layer: The linear layer object to register. layer: The linear layer object to register.
prefetch_step: An integer that manages asynchronous weight prefetching. Setting it to 0 or 1 can cover most cases. prefetch_step: An integer that manages asynchronous weight prefetching. Setting it to 0 or 1 can cover most cases.
""" """
@@ -224,7 +231,8 @@ def register_layer_to_shard_weight_series(
post_method=layer.quant_method.process_weights_after_loading, post_method=layer.quant_method.process_weights_after_loading,
weight=layer.weight, weight=layer.weight,
window_idx=-1, window_idx=-1,
)) )
)
# Discard the original `process_weights_after_loading` method such that it won't be called by others. # Discard the original `process_weights_after_loading` method such that it won't be called by others.
layer.quant_method.process_weights_after_loading = lambda layer: None layer.quant_method.process_weights_after_loading = lambda layer: None
# When the layer not intended to be stored in this device, dispose the tensor and skip weight loading. # When the layer not intended to be stored in this device, dispose the tensor and skip weight loading.
@@ -257,6 +265,7 @@ def wait_layer_for_shard_weight_series(layer: LinearBase):
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def get_current_model_num_hidden_layers() -> int: def get_current_model_num_hidden_layers() -> int:
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
return vllm_config.model_config.get_total_num_hidden_layers() return vllm_config.model_config.get_total_num_hidden_layers()
@@ -268,10 +277,11 @@ def is_hidden_layer(layer: LinearBase) -> bool:
def register_all_layers_to_shard_weight_series( def register_all_layers_to_shard_weight_series(
layer_sharding: List[LinearBase], ): layer_sharding: list[LinearBase],
for curr_layer in (layer_sharding or []): ):
for curr_layer in layer_sharding or []:
if is_hidden_layer(curr_layer): if is_hidden_layer(curr_layer):
layer_name = curr_layer.prefix.split('.')[-1] layer_name = curr_layer.prefix.split(".")[-1]
register_layer_to_shard_weight_series( register_layer_to_shard_weight_series(
series_name=layer_name, series_name=layer_name,
group=get_shard_weight_group(), group=get_shard_weight_group(),

View File

@@ -20,19 +20,23 @@ AscendMergedColumnParallelLinear, AscendMergedColumnParallelLinear,
AscendRowParallelLinear and AscendColumnParallelLinear. AscendRowParallelLinear and AscendColumnParallelLinear.
""" """
from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.distributed import divide from vllm.distributed import divide
from vllm.model_executor.layers.linear import ( # noqa from vllm.model_executor.layers.linear import ( # noqa
WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase, WEIGHT_LOADER_V2_SUPPORTED,
MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase, ColumnParallelLinear,
ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod) LinearBase,
from vllm.model_executor.layers.quantization.base_config import \ MergedColumnParallelLinear,
QuantizationConfig QKVParallelLinear,
QuantizeMethodBase,
ReplicatedLinear,
RowParallelLinear,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
@@ -50,14 +54,13 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group # TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
class AscendLinearBase(LinearBase): class AscendLinearBase(LinearBase):
def __init__( def __init__(
self, self,
input_size: int, input_size: int,
output_size: int, output_size: int,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: torch.dtype | None = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
@@ -75,11 +78,9 @@ class AscendLinearBase(LinearBase):
self.quant_config = quant_config self.quant_config = quant_config
self.prefix = prefix self.prefix = prefix
if quant_config is None: if quant_config is None:
self.quant_method: Optional[ self.quant_method: QuantizeMethodBase | None = AscendUnquantizedLinearMethod()
QuantizeMethodBase] = AscendUnquantizedLinearMethod()
else: else:
self.quant_method = quant_config.get_quant_method(self, self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
prefix=prefix)
self.return_bias = return_bias self.return_bias = return_bias
self.disable_tp = disable_tp self.disable_tp = disable_tp
@@ -100,11 +101,11 @@ class AscendQKVParallelLinear(QKVParallelLinear):
hidden_size: int, hidden_size: int,
head_size: int, head_size: int,
total_num_heads: int, total_num_heads: int,
total_num_kv_heads: Optional[int] = None, total_num_kv_heads: int | None = None,
bias: bool = True, bias: bool = True,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: torch.dtype | None = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
@@ -112,9 +113,9 @@ class AscendQKVParallelLinear(QKVParallelLinear):
v_head_size: int | None = None, v_head_size: int | None = None,
): ):
self.v_head_size = v_head_size if v_head_size is not None else head_size self.v_head_size = v_head_size if v_head_size is not None else head_size
self.custom_op, _, tp_size = get_parallel_op(disable_tp, prefix, self, self.custom_op, _, tp_size = get_parallel_op(disable_tp, prefix, self, "column")
"column") # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group # linear of vllm supports custom comm group
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.head_size = head_size self.head_size = head_size
self.total_num_heads = total_num_heads self.total_num_heads = total_num_heads
@@ -125,20 +126,19 @@ class AscendQKVParallelLinear(QKVParallelLinear):
self.num_heads = divide(self.total_num_heads, tp_size) self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads: if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1 self.num_kv_heads = 1
self.num_kv_head_replicas = divide(tp_size, self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
self.total_num_kv_heads)
else: else:
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
self.num_kv_head_replicas = 1 self.num_kv_head_replicas = 1
input_size = self.hidden_size input_size = self.hidden_size
output_size = (self.num_heads + output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
2 * self.num_kv_heads) * tp_size * self.head_size
self.output_sizes = [ self.output_sizes = [
self.num_heads * self.head_size * tp_size, # q_proj self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * self.head_size * tp_size, # k_proj self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj self.num_kv_heads * self.head_size * tp_size, # v_proj
] ]
AscendColumnParallelLinear.__init__(self, AscendColumnParallelLinear.__init__(
self,
input_size=input_size, input_size=input_size,
output_size=output_size, output_size=output_size,
bias=bias, bias=bias,
@@ -148,12 +148,13 @@ class AscendQKVParallelLinear(QKVParallelLinear):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias, return_bias=return_bias,
disable_tp=disable_tp) disable_tp=disable_tp,
)
def forward( def forward(
self, self,
input_, input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.custom_op is not None: if self.custom_op is not None:
return self.custom_op.apply(input_) return self.custom_op.apply(input_)
@@ -178,20 +179,20 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
bias: bool = True, bias: bool = True,
gather_output: bool = False, gather_output: bool = False,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: torch.dtype | None = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
disable_tp: bool = False, disable_tp: bool = False,
): ):
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op( self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(disable_tp, prefix, self, "column")
disable_tp, prefix, self, "column") # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group # linear of vllm supports custom comm group
self.output_sizes = output_sizes self.output_sizes = output_sizes
assert all(output_size % self.tp_size == 0 assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
for output_size in output_sizes) AscendColumnParallelLinear.__init__(
AscendColumnParallelLinear.__init__(self, self,
input_size=input_size, input_size=input_size,
output_size=sum(output_sizes), output_size=sum(output_sizes),
bias=bias, bias=bias,
@@ -201,12 +202,13 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias, return_bias=return_bias,
disable_tp=disable_tp) disable_tp=disable_tp,
)
def forward( def forward(
self, self,
input_, input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.custom_op is not None: if self.custom_op is not None:
return self.custom_op.apply(input_) return self.custom_op.apply(input_)
@@ -229,9 +231,9 @@ class AscendRowParallelLinear(RowParallelLinear):
bias: bool = True, bias: bool = True,
input_is_parallel: bool = True, input_is_parallel: bool = True,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: torch.dtype | None = None,
reduce_results: bool = True, reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
@@ -247,15 +249,16 @@ class AscendRowParallelLinear(RowParallelLinear):
self.unique_prefix = unique_prefix self.unique_prefix = unique_prefix
compilation_config.static_forward_context[unique_prefix] = self compilation_config.static_forward_context[unique_prefix] = self
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op( self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(disable_tp, prefix, self, "row")
disable_tp, prefix, self, "row") # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group # linear of vllm supports custom comm group
# Divide the weight matrix along the first dimension. # Divide the weight matrix along the first dimension.
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size] self.output_partition_sizes = [output_size]
AscendLinearBase.__init__(self, AscendLinearBase.__init__(
self,
input_size, input_size,
output_size, output_size,
skip_bias_add, skip_bias_add,
@@ -263,7 +266,8 @@ class AscendRowParallelLinear(RowParallelLinear):
quant_config, quant_config,
prefix, prefix,
return_bias=return_bias, return_bias=return_bias,
disable_tp=disable_tp) disable_tp=disable_tp,
)
self.input_is_parallel = input_is_parallel self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results self.reduce_results = reduce_results
@@ -277,19 +281,23 @@ class AscendRowParallelLinear(RowParallelLinear):
output_size=self.output_size, output_size=self.output_size,
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
weight_loader=( weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__ self.weight_loader_v2
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
else self.weight_loader
),
)
if not reduce_results and (bias and not skip_bias_add): if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the " raise ValueError("When not reduce the results, adding bias to the results can lead to incorrect results")
"results can lead to incorrect results")
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
torch.empty(self.output_size, dtype=params_dtype)) set_weight_attrs(
set_weight_attrs(self.bias, { self.bias,
{
"output_dim": 0, "output_dim": 0,
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
)
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
@@ -300,7 +308,7 @@ class AscendRowParallelLinear(RowParallelLinear):
self, self,
input_, input_,
**kwargs, **kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.custom_op is not None: if self.custom_op is not None:
return self.custom_op.apply(input_) return self.custom_op.apply(input_)
@@ -321,28 +329,27 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
bias: bool = True, bias: bool = True,
gather_output: bool = False, gather_output: bool = False,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: torch.dtype | None = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: QuantizationConfig | None = None,
output_sizes: Optional[list[int]] = None, output_sizes: list[int] | None = None,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
disable_tp: bool = False, disable_tp: bool = False,
): ):
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op( #
disable_tp, prefix, self, "column") self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(disable_tp, prefix, self, "column")
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after
# linear of vllm supports custom comm group
self.input_size_per_partition = input_size self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size) self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition] self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition. # If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"): if hasattr(self, "output_sizes"):
self.output_partition_sizes = [ self.output_partition_sizes = [divide(output_size, self.tp_size) for output_size in self.output_sizes]
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
AscendLinearBase.__init__(self, AscendLinearBase.__init__(
self,
input_size, input_size,
output_size, output_size,
skip_bias_add, skip_bias_add,
@@ -350,7 +357,8 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
quant_config, quant_config,
prefix, prefix,
return_bias=return_bias, return_bias=return_bias,
disable_tp=disable_tp) disable_tp=disable_tp,
)
self.gather_output = gather_output self.gather_output = gather_output
@@ -366,16 +374,20 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
output_size=self.output_size, output_size=self.output_size,
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
weight_loader=( weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__ self.weight_loader_v2
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
else self.weight_loader
),
)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(torch.empty(self.output_size_per_partition, dtype=params_dtype))
torch.empty(self.output_size_per_partition, set_weight_attrs(
dtype=params_dtype)) self.bias,
set_weight_attrs(self.bias, { {
"output_dim": 0, "output_dim": 0,
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
)
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
@@ -385,7 +397,7 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
def forward( def forward(
self, self,
input_, input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.custom_op is not None: if self.custom_op is not None:
return self.custom_op.apply(input_) return self.custom_op.apply(input_)
@@ -414,8 +426,8 @@ class AscendReplicatedLinear(ReplicatedLinear):
output_size: int, output_size: int,
bias: bool = True, bias: bool = True,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: torch.dtype | None = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
@@ -428,7 +440,8 @@ class AscendReplicatedLinear(ReplicatedLinear):
else: else:
self.output_partition_sizes = [output_size] self.output_partition_sizes = [output_size]
AscendLinearBase.__init__(self, AscendLinearBase.__init__(
self,
input_size, input_size,
output_size, output_size,
skip_bias_add, skip_bias_add,
@@ -436,24 +449,30 @@ class AscendReplicatedLinear(ReplicatedLinear):
quant_config, quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias, return_bias=return_bias,
disable_tp=disable_tp) disable_tp=disable_tp,
)
# All the linear layer supports quant method. # All the linear layer supports quant method.
assert self.quant_method is not None assert self.quant_method is not None
self.quant_method.create_weights(self, self.quant_method.create_weights(
self.input_size, [self.output_size], self,
self.input_size,
[self.output_size],
self.input_size, self.input_size,
self.output_size, self.output_size,
self.params_dtype, self.params_dtype,
weight_loader=self.weight_loader) weight_loader=self.weight_loader,
)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(torch.empty(self.output_size, dtype=self.params_dtype))
torch.empty(self.output_size, dtype=self.params_dtype)) set_weight_attrs(
set_weight_attrs(self.bias, { self.bias,
{
"output_dim": 0, "output_dim": 0,
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) },
)
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
@@ -463,7 +482,7 @@ class AscendReplicatedLinear(ReplicatedLinear):
def forward( def forward(
self, self,
input_, input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.custom_op is not None: if self.custom_op is not None:
return self.custom_op.apply(input_) return self.custom_op.apply(input_)

View File

@@ -31,16 +31,18 @@ CustomLinearOp
└── CustomReplicatedOp └── CustomReplicatedOp
How to extend a new linear op? Taking column parallel op as an example: How to extend a new linear op? Taking column parallel op as an example:
1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp 1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group method 2. [Optional] The default communication group is the TP group. If a custom communication group is needed,
override the comm_group method
3. Override the apply method according to requirements, which will replace the original linear.forward 3. Override the apply method according to requirements, which will replace the original linear.forward
4. Add selection logic for MyColumnParallelOp in the get_column_parallel_op method, typically based on prefix and configuration judgments 4. Add selection logic for MyColumnParallelOp in the get_column_parallel_op method, typically based on
Row parallel op follows a similar approach - inherit from RowColumnParallelOp and register the new class in get_row_parallel_op. prefix and configuration judgments
Row parallel op follows a similar approach - inherit from RowColumnParallelOp and register the new class in
get_row_parallel_op.
""" """
import re import re
from functools import lru_cache from functools import lru_cache
from types import SimpleNamespace from types import SimpleNamespace
from typing import Optional, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@@ -49,27 +51,37 @@ import torch_npu
from torch import nn from torch import nn
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.distributed import (split_tensor_along_last_dim, from vllm.distributed import (
split_tensor_along_last_dim,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter) tensor_model_parallel_reduce_scatter,
)
from vllm.distributed.parallel_state import get_tp_group from vllm.distributed.parallel_state import get_tp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm_ascend import envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group, from vllm_ascend.distributed.parallel_state import (
get_flashcomm2_odp_group,
get_flashcomm2_otp_group, get_flashcomm2_otp_group,
get_mlp_tp_group, get_mlp_tp_group,
get_otp_group) get_otp_group,
)
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import (enable_dsa_cp, enable_dsa_cp_with_layer_shard, enable_sp, flashcomm2_enable, from vllm_ascend.utils import (
enable_dsa_cp,
enable_dsa_cp_with_layer_shard,
enable_sp,
flashcomm2_enable,
get_flashcomm2_reorgnized_batch_ids, get_flashcomm2_reorgnized_batch_ids,
matmul_allreduce_enable, mlp_tp_enable, get_weight_prefetch_method,
oproj_tp_enable, shared_expert_dp_enabled, matmul_allreduce_enable,
get_weight_prefetch_method) mlp_tp_enable,
oproj_tp_enable,
shared_expert_dp_enabled,
)
class CustomLinearOp: class CustomLinearOp:
def __init__(self, layer): def __init__(self, layer):
self.layer = layer self.layer = layer
self.bias = None self.bias = None
@@ -112,7 +124,6 @@ class CustomLinearOp:
class CustomColumnParallelOp(CustomLinearOp): class CustomColumnParallelOp(CustomLinearOp):
def __init__(self, layer): def __init__(self, layer):
super().__init__(layer) super().__init__(layer)
self.gather_output = None self.gather_output = None
@@ -123,7 +134,6 @@ class CustomColumnParallelOp(CustomLinearOp):
class CustomRowParallelOp(CustomLinearOp): class CustomRowParallelOp(CustomLinearOp):
def __init__(self, layer): def __init__(self, layer):
super().__init__(layer) super().__init__(layer)
self.reduce_results = None self.reduce_results = None
@@ -140,7 +150,9 @@ class CustomRowParallelOp(CustomLinearOp):
output, output_bias = self.apply_impl(input_) output, output_bias = self.apply_impl(input_)
weight_prefetch_method = get_weight_prefetch_method() weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method: if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_GATE_UP, output, self.prefix) weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(
weight_prefetch_method.MLP_GATE_UP, output, self.prefix
)
if not self.return_bias: if not self.return_bias:
return output return output
@@ -148,7 +160,6 @@ class CustomRowParallelOp(CustomLinearOp):
class CustomReplicatedOp(CustomLinearOp): class CustomReplicatedOp(CustomLinearOp):
def apply_impl(self, input_): def apply_impl(self, input_):
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None assert self.quant_method is not None
@@ -160,7 +171,6 @@ class CustomReplicatedOp(CustomLinearOp):
class MLPColumnParallelOp(CustomColumnParallelOp): class MLPColumnParallelOp(CustomColumnParallelOp):
def __init__(self, layer): def __init__(self, layer):
super().__init__(layer) super().__init__(layer)
@@ -171,7 +181,7 @@ class MLPColumnParallelOp(CustomColumnParallelOp):
def apply_impl( def apply_impl(
self, self,
input_: torch.Tensor, input_: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
# Matrix multiply. # Matrix multiply.
assert self.quant_method is not None assert self.quant_method is not None
@@ -183,7 +193,6 @@ class MLPColumnParallelOp(CustomColumnParallelOp):
class MLPRowParallelOp(CustomRowParallelOp): class MLPRowParallelOp(CustomRowParallelOp):
def __init__(self, layer): def __init__(self, layer):
super().__init__(layer) super().__init__(layer)
@@ -191,22 +200,16 @@ class MLPRowParallelOp(CustomRowParallelOp):
def comm_group(self): def comm_group(self):
return get_mlp_tp_group() return get_mlp_tp_group()
def apply_impl( def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
else: else:
splitted_input = split_tensor_along_last_dim( splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous() input_parallel = splitted_input[self.tp_rank].contiguous()
assert self.quant_method is not None assert self.quant_method is not None
bias_ = None if (self.tp_rank > 0 bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.layer.bias
or self.skip_bias_add) else self.layer.bias output_parallel = self.quant_method.apply(self.layer, input_parallel, bias=bias_)
output_parallel = self.quant_method.apply(self.layer,
input_parallel,
bias=bias_)
output = self.comm_group.reduce_scatter(output_parallel, 0) output = self.comm_group.reduce_scatter(output_parallel, 0)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
@@ -214,7 +217,6 @@ class MLPRowParallelOp(CustomRowParallelOp):
class OProjRowParallelOp(CustomRowParallelOp): class OProjRowParallelOp(CustomRowParallelOp):
def __init__(self, layer): def __init__(self, layer):
super().__init__(layer) super().__init__(layer)
@@ -225,13 +227,11 @@ class OProjRowParallelOp(CustomRowParallelOp):
def apply_impl( def apply_impl(
self, self,
input_: torch.Tensor, input_: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
else: else:
splitted_input = split_tensor_along_last_dim( splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous() input_parallel = splitted_input[self.tp_rank].contiguous()
# Prepare tensors for all-to-all communication # Prepare tensors for all-to-all communication
@@ -241,27 +241,19 @@ class OProjRowParallelOp(CustomRowParallelOp):
# Reshape tensor for efficient cross-device transfer: # Reshape tensor for efficient cross-device transfer:
# [batch, dim] -> [tp_size, batch, chunk] -> flattened # [batch, dim] -> [tp_size, batch, chunk] -> flattened
send_buf = (input_parallel.reshape(-1, send_buf = input_parallel.reshape(-1, self.tp_size, chunk_size).transpose(0, 1).contiguous().view(-1)
self.tp_size, chunk_size).transpose(
0, 1).contiguous().view(-1))
# Create receive buffer # Create receive buffer
recv_buf = torch.empty(total_batch_size * chunk_size, recv_buf = torch.empty(total_batch_size * chunk_size, dtype=input_parallel.dtype, device=input_parallel.device)
dtype=input_parallel.dtype,
device=input_parallel.device)
# Perform all-to-all communication # Perform all-to-all communication
dist.all_to_all_single(recv_buf, dist.all_to_all_single(recv_buf, send_buf, group=self.comm_group.device_group)
send_buf,
group=self.comm_group.device_group)
input_parallel = recv_buf.view(total_batch_size, chunk_size) input_parallel = recv_buf.view(total_batch_size, chunk_size)
# Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1 # Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
assert self.quant_method is not None assert self.quant_method is not None
output_parallel = self.quant_method.apply(self.layer, output_parallel = self.quant_method.apply(self.layer, input_parallel, bias=bias_)
input_parallel,
bias=bias_)
# otp-specific: Combine partial results across devices # otp-specific: Combine partial results across devices
output = self.comm_group.reduce_scatter(output_parallel, dim=0) output = self.comm_group.reduce_scatter(output_parallel, dim=0)
@@ -278,14 +270,12 @@ class OProjRowParallelOp(CustomRowParallelOp):
class Flashcomm2OProjRowParallelOp(CustomRowParallelOp): class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
def __init__(self, layer): def __init__(self, layer):
super().__init__(layer) super().__init__(layer)
self.odp_group = get_flashcomm2_odp_group() self.odp_group = get_flashcomm2_odp_group()
self.odp_size = self.odp_group.world_size self.odp_size = self.odp_group.world_size
self.otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size self.otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids( self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(get_tp_group().world_size)
get_tp_group().world_size)
self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu() self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
self.layer._quant_comm_config = {} self.layer._quant_comm_config = {}
@@ -308,7 +298,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
def apply_impl( def apply_impl(
self, self,
input_: torch.Tensor, input_: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
"""Linear layer for Flashcomm2. """Linear layer for Flashcomm2.
Input.ahspe = [batchsize*seqlength, headnum*headdim/TP] Input.ahspe = [batchsize*seqlength, headnum*headdim/TP]
Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize] Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize]
@@ -318,22 +308,18 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
input_parallel = input_ input_parallel = input_
else: else:
tp_rank = self.tp_rank tp_rank = self.tp_rank
splitted_input = split_tensor_along_last_dim( splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous() input_parallel = splitted_input[tp_rank].contiguous()
# padding for all-to-all # padding for all-to-all
forward_context = get_forward_context() forward_context = get_forward_context()
num_padding_tokens = forward_context.pad_size num_padding_tokens = forward_context.pad_size
if num_padding_tokens > 0: if num_padding_tokens > 0:
input_parallel = nn.functional.pad(input_parallel, input_parallel = nn.functional.pad(input_parallel, (0, 0, 0, num_padding_tokens))
(0, 0, 0, num_padding_tokens))
def otp_maybe_quant_comm(x): def otp_maybe_quant_comm(x):
# Reorganize the tensor so that the batch id and rank id correspond to each other. # Reorganize the tensor so that the batch id and rank id correspond to each other.
chunk_num = len(self.reorgnized_batch_ids) * len( chunk_num = len(self.reorgnized_batch_ids) * len(self.reorgnized_batch_ids[0])
self.reorgnized_batch_ids[0])
batch_size = x.size(0) batch_size = x.size(0)
assert batch_size % chunk_num == 0, f"Batch_size({batch_size}) must be divisible by chunk_num({chunk_num})" assert batch_size % chunk_num == 0, f"Batch_size({batch_size}) must be divisible by chunk_num({chunk_num})"
@@ -352,26 +338,19 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
total_intermediate_size = local_intermediate_size * all2all_tp_size total_intermediate_size = local_intermediate_size * all2all_tp_size
# Create receive buffer # Create receive buffer
recv_buf = torch.empty(total_intermediate_size * chunk_size, recv_buf = torch.empty(total_intermediate_size * chunk_size, dtype=x.dtype, device=x.device)
dtype=x.dtype,
device=x.device)
# Perform all-to-all communication # Perform all-to-all communication
dist.all_to_all_single(recv_buf, dist.all_to_all_single(recv_buf, send_buf, group=self.odp_group.device_group)
send_buf,
group=self.odp_group.device_group)
return recv_buf.view(all2all_tp_size, chunk_size, return recv_buf.view(all2all_tp_size, chunk_size, -1).transpose(0, 1).reshape(chunk_size, -1)
-1).transpose(0, 1).reshape(chunk_size, -1)
if not hasattr(self, "_quant_comm_config"): if not hasattr(self, "_quant_comm_config"):
self.layer._quant_comm_config = {} self.layer._quant_comm_config = {}
self.layer._quant_comm_config[ self.layer._quant_comm_config["communication_fn"] = otp_maybe_quant_comm
"communication_fn"] = otp_maybe_quant_comm actual_quant_method = getattr(self.quant_method, "quant_method", self.quant_method)
actual_quant_method = getattr(self.quant_method, 'quant_method', from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod
self.quant_method)
from vllm_ascend.quantization.methods.w8a8_static import \
AscendW8A8LinearMethod
if not isinstance(actual_quant_method, AscendW8A8LinearMethod): if not isinstance(actual_quant_method, AscendW8A8LinearMethod):
# Check if w8a8 quantization is enabled. If not, communicate immediately. # Check if w8a8 quantization is enabled. If not, communicate immediately.
input_parallel = otp_maybe_quant_comm(input_parallel) input_parallel = otp_maybe_quant_comm(input_parallel)
@@ -382,9 +361,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
# bias will not get added more than once in TP>1 case) # bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self.layer, output_parallel = self.quant_method.apply(self.layer, input_parallel, bias=bias_)
input_parallel,
bias=bias_)
# output_parallel shape: [bs/(TP/flashcomm2_otp_size), hiddenstate] # output_parallel shape: [bs/(TP/flashcomm2_otp_size), hiddenstate]
if self.tp_size > 1: if self.tp_size > 1:
# flashcomm2 with reduce-scatter # flashcomm2 with reduce-scatter
@@ -408,8 +385,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
self.input_is_parallel = self.layer.input_is_parallel self.input_is_parallel = self.layer.input_is_parallel
self.input_size_per_partition = self.layer.input_size_per_partition self.input_size_per_partition = self.layer.input_size_per_partition
if flashcomm2_oshard_manager.flashcomm2_oshard_enable(): if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
flashcomm2_oshard_manager.register_layer(self.layer, flashcomm2_oshard_manager.register_layer(self.layer, prefetch_step=1)
prefetch_step=1)
class MatmulAllreduceRowParallelOp(CustomRowParallelOp): class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
@@ -419,28 +395,22 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
super().__init__(layer) super().__init__(layer)
self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group) self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)
def apply_impl( def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
else: else:
splitted_input = split_tensor_along_last_dim( splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous() input_parallel = splitted_input[self.tp_rank].contiguous()
"""Calculate the output tensor of forward by considering """Calculate the output tensor of forward by considering
fusing communication and computation.""" fusing communication and computation."""
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
output = torch_npu.npu_mm_all_reduce_base(input_parallel, output = torch_npu.npu_mm_all_reduce_base(
self.layer.weight.t(), input_parallel, self.layer.weight.t(), self.hcomm_info, bias=bias_
self.hcomm_info, )
bias=bias_)
else: else:
assert self.quant_method is not None assert self.quant_method is not None
output = self.quant_method.apply(self.layer, output = self.quant_method.apply(self.layer, input_parallel, bias=bias_)
input_parallel,
bias=bias_)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
return output, output_bias return output, output_bias
@@ -454,18 +424,14 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
rank = torch.distributed.get_rank(group) rank = torch.distributed.get_rank(group)
if torch.__version__ > "2.0": if torch.__version__ > "2.0":
global_rank = torch.distributed.get_global_rank(group, rank) global_rank = torch.distributed.get_global_rank(group, rank)
cls._HCOMM_INFO = group._get_backend( cls._HCOMM_INFO = group._get_backend(torch.device("npu")).get_hccl_comm_name(global_rank)
torch.device("npu")).get_hccl_comm_name(global_rank)
else: else:
cls._HCOMM_INFO = group.get_hccl_comm_name(rank) cls._HCOMM_INFO = group.get_hccl_comm_name(rank)
return cls._HCOMM_INFO return cls._HCOMM_INFO
class SequenceColumnParallelOp(CustomColumnParallelOp): class SequenceColumnParallelOp(CustomColumnParallelOp):
def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
def apply_impl(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Linear layer with column parallelism. """Linear layer with column parallelism.
Implemented multiple optimization projects for dense models, such as FlashComm and Implemented multiple optimization projects for dense models, such as FlashComm and
@@ -490,13 +456,10 @@ class SequenceColumnParallelOp(CustomColumnParallelOp):
class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp): class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
def __init__(self, layer): def __init__(self, layer):
super().__init__(layer) super().__init__(layer)
def apply_impl( def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Column-parallel linear with FlashComm2 OShard optimization.""" """Column-parallel linear with FlashComm2 OShard optimization."""
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
@@ -505,12 +468,10 @@ class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
assert self.quant_method is not None assert self.quant_method is not None
if enable_sp(): if enable_sp():
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
input_, True)
# Trigger async broadcast before matmul to overlap communication. # Trigger async broadcast before matmul to overlap communication.
flashcomm2_oshard_manager.trigger_broadcast_for_layer( flashcomm2_oshard_manager.trigger_broadcast_for_layer(self.layer.prefix)
self.layer.prefix)
output_parallel = self.quant_method.apply(self.layer, input_, bias) output_parallel = self.quant_method.apply(self.layer, input_, bias)
if self.gather_output and self.tp_size > 1: if self.gather_output and self.tp_size > 1:
@@ -523,14 +484,11 @@ class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
class SequenceRowParallelOp(CustomRowParallelOp): class SequenceRowParallelOp(CustomRowParallelOp):
def __init__(self, layer): def __init__(self, layer):
super().__init__(layer) super().__init__(layer)
self.unique_prefix = None self.unique_prefix = None
def apply_impl( def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Linear layer with column parallelism. """Linear layer with column parallelism.
Implemented multiple optimization projects for dense models, such as FlashComm and Implemented multiple optimization projects for dense models, such as FlashComm and
@@ -540,26 +498,21 @@ class SequenceRowParallelOp(CustomRowParallelOp):
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
else: else:
splitted_input = split_tensor_along_last_dim( splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous() input_parallel = splitted_input[self.tp_rank].contiguous()
assert self.quant_method is not None assert self.quant_method is not None
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if self.tp_size == 1 or not self.reduce_results: if self.tp_size == 1 or not self.reduce_results:
output = self.quant_method.apply(self.layer, output = self.quant_method.apply(self.layer, input_parallel, bias=bias_)
input_parallel,
bias=bias_)
else: else:
output = torch.ops.vllm.matmul_and_reduce(input_parallel, output = torch.ops.vllm.matmul_and_reduce(input_parallel, self.unique_prefix)
self.unique_prefix)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
return output, output_bias return output, output_bias
def matmul_and_reduce(self, input_parallel: torch.Tensor, def matmul_and_reduce(self, input_parallel: torch.Tensor, bias_: Parameter | None) -> torch.Tensor:
bias_: Optional[Parameter]) -> torch.Tensor:
assert self.quant_method is not None assert self.quant_method is not None
try: try:
forward_context = get_forward_context() forward_context = get_forward_context()
@@ -572,29 +525,24 @@ class SequenceRowParallelOp(CustomRowParallelOp):
x = input_parallel x = input_parallel
if not sp_enabled: if not sp_enabled:
output_parallel = self.layer.quant_method.apply(self.layer, output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_)
x,
bias=bias_)
return tensor_model_parallel_all_reduce(output_parallel) return tensor_model_parallel_all_reduce(output_parallel)
pad_size = forward_context.pad_size pad_size = forward_context.pad_size
if pad_size > 0 and not (enable_dsa_cp() if pad_size > 0 and not (enable_dsa_cp() and "o_proj" in self.layer.prefix):
and "o_proj" in self.layer.prefix):
x = F.pad(x, (0, 0, 0, pad_size)) x = F.pad(x, (0, 0, 0, pad_size))
world_size = self.layer.tp_size world_size = self.layer.tp_size
comm_mode = "aiv" comm_mode = "aiv"
hcom_name = get_tp_group().device_group._get_backend( hcom_name = get_tp_group().device_group._get_backend(torch.device("npu")).get_hccl_comm_name(self.layer.tp_rank)
torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.quantization.method_adapters import AscendLinearMethod from vllm_ascend.quantization.method_adapters import AscendLinearMethod
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
# For unquant # For unquant
if mmrs_fusion and isinstance(self.layer.quant_method, if mmrs_fusion and isinstance(self.layer.quant_method, UnquantizedLinearMethod):
UnquantizedLinearMethod):
output = torch_npu.npu_mm_reduce_scatter_base( output = torch_npu.npu_mm_reduce_scatter_base(
x, x,
self.layer.weight.t(), self.layer.weight.t(),
@@ -603,19 +551,22 @@ class SequenceRowParallelOp(CustomRowParallelOp):
reduce_op="sum", reduce_op="sum",
bias=None, bias=None,
comm_turn=0, comm_turn=0,
comm_mode=comm_mode) comm_mode=comm_mode,
)
if bias_ is not None: if bias_ is not None:
output.add_(bias_) output.add_(bias_)
# For w8a8 quant # For w8a8 quant
elif mmrs_fusion and ( elif mmrs_fusion and (
isinstance(self.layer.quant_method, AscendLinearMethod) isinstance(self.layer.quant_method, AscendLinearMethod)
and isinstance(self.layer.quant_method.quant_method, and isinstance(self.layer.quant_method.quant_method, AscendW8A8LinearMethod)
AscendW8A8LinearMethod)): ):
if x.dtype != torch.int8: if x.dtype != torch.int8:
x_quant = torch.ops.vllm.quantize( x_quant = torch.ops.vllm.quantize(
x, self.layer.aclnn_input_scale, x,
self.layer.aclnn_input_scale,
self.layer.aclnn_input_scale_reciprocal, self.layer.aclnn_input_scale_reciprocal,
self.layer.aclnn_input_offset) self.layer.aclnn_input_offset,
)
else: else:
x_quant = x x_quant = x
quant_bias = self.layer.quant_bias quant_bias = self.layer.quant_bias
@@ -631,14 +582,11 @@ class SequenceRowParallelOp(CustomRowParallelOp):
comm_turn=0, comm_turn=0,
x2_scale=deq_scale, x2_scale=deq_scale,
output_dtype=output_dtype, output_dtype=output_dtype,
comm_mode=comm_mode) comm_mode=comm_mode,
output = torch.add( )
output, output = torch.add(output, torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
else: else:
output_parallel = self.layer.quant_method.apply(self.layer, output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_)
x,
bias=bias_)
output = tensor_model_parallel_reduce_scatter(output_parallel, 0) output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
return output return output
@@ -651,13 +599,10 @@ class SequenceRowParallelOp(CustomRowParallelOp):
class ShardedCPRowParallelOp(CustomRowParallelOp): class ShardedCPRowParallelOp(CustomRowParallelOp):
@property @property
def comm_group(self): def comm_group(self):
# fake comm group to bypass tp logic # fake comm group to bypass tp logic
return SimpleNamespace(world_size=1, return SimpleNamespace(world_size=1, rank_in_group=0, device_group=None)
rank_in_group=0,
device_group=None)
def apply_impl( def apply_impl(
self, self,
@@ -677,13 +622,10 @@ class ShardedCPRowParallelOp(CustomRowParallelOp):
class ShardedCPColumnParallelOp(CustomColumnParallelOp): class ShardedCPColumnParallelOp(CustomColumnParallelOp):
@property @property
def comm_group(self): def comm_group(self):
# fake comm group to bypass tp logic # fake comm group to bypass tp logic
return SimpleNamespace(world_size=1, return SimpleNamespace(world_size=1, rank_in_group=0, device_group=None)
rank_in_group=0,
device_group=None)
def apply_impl( def apply_impl(
self, self,
@@ -700,12 +642,10 @@ class ShardedCPColumnParallelOp(CustomColumnParallelOp):
def _get_column_parallel_op( def _get_column_parallel_op(
prefix, layer prefix, layer
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp, ) -> MLPColumnParallelOp | SequenceColumnParallelOp | ShardedCPColumnParallelOp | Flashcomm2OshardQKVParallelOp | None:
ShardedCPColumnParallelOp, Flashcomm2OshardQKVParallelOp]]:
if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix): if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix):
return ShardedCPColumnParallelOp(layer) return ShardedCPColumnParallelOp(layer)
if "gate_up_proj" in prefix and mlp_tp_enable( if "gate_up_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
) and not is_moe_layer(prefix):
return MLPColumnParallelOp(layer) return MLPColumnParallelOp(layer)
if flashcomm2_oshard_manager.flashcomm2_oshard_enable(): if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")): if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
@@ -729,9 +669,15 @@ def _get_column_parallel_op(
def _get_row_parallel_op( def _get_row_parallel_op(
prefix, layer prefix, layer
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp, ) -> (
Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp, MLPRowParallelOp
SequenceRowParallelOp, ShardedCPRowParallelOp]]: | OProjRowParallelOp
| Flashcomm2OProjRowParallelOp
| MatmulAllreduceRowParallelOp
| SequenceRowParallelOp
| ShardedCPRowParallelOp
| None
):
if enable_dsa_cp_with_layer_shard() and "o_proj" in prefix: if enable_dsa_cp_with_layer_shard() and "o_proj" in prefix:
return ShardedCPRowParallelOp(layer) return ShardedCPRowParallelOp(layer)
if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix): if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
@@ -760,16 +706,21 @@ def _get_row_parallel_op(
def get_parallel_op(disable_tp, prefix, layer, direct): def get_parallel_op(disable_tp, prefix, layer, direct):
if disable_tp or ("shared_experts" in prefix if disable_tp or ("shared_experts" in prefix and shared_expert_dp_enabled()):
and shared_expert_dp_enabled()):
return None, 0, 1 return None, 0, 1
custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp, custom_op: (
MLPRowParallelOp, OProjRowParallelOp, MLPColumnParallelOp
Flashcomm2OProjRowParallelOp, | SequenceColumnParallelOp
Flashcomm2OshardQKVParallelOp, | MLPRowParallelOp
MatmulAllreduceRowParallelOp, | OProjRowParallelOp
SequenceRowParallelOp, ShardedCPRowParallelOp, | Flashcomm2OProjRowParallelOp
ShardedCPColumnParallelOp]] = None | Flashcomm2OshardQKVParallelOp
| MatmulAllreduceRowParallelOp
| SequenceRowParallelOp
| ShardedCPRowParallelOp
| ShardedCPColumnParallelOp
| None
) = None
if direct == "row": if direct == "row":
custom_op = _get_row_parallel_op(prefix, layer) custom_op = _get_row_parallel_op(prefix, layer)
@@ -782,8 +733,7 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
return None, get_tp_group().rank_in_group, get_tp_group().world_size return None, get_tp_group().rank_in_group, get_tp_group().world_size
def get_replicated_op(disable_tp, prefix, def get_replicated_op(disable_tp, prefix, layer) -> CustomReplicatedOp | None:
layer) -> Optional[Union[CustomReplicatedOp]]:
if disable_tp: if disable_tp:
return None return None
@@ -791,24 +741,22 @@ def get_replicated_op(disable_tp, prefix,
def is_moe_layer(prefix: str) -> bool: def is_moe_layer(prefix: str) -> bool:
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def get_moe_params(): def get_moe_params():
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
config = vllm_config.model_config.hf_text_config config = vllm_config.model_config.hf_text_config
n_routed_experts = getattr(config, 'n_routed_experts', 0) n_routed_experts = getattr(config, "n_routed_experts", 0)
first_k_dense_replace = getattr(config, 'first_k_dense_replace', first_k_dense_replace = getattr(config, "first_k_dense_replace", float("inf"))
float('inf')) moe_layer_freq = getattr(config, "moe_layer_freq", 1)
moe_layer_freq = getattr(config, 'moe_layer_freq', 1)
return n_routed_experts, first_k_dense_replace, moe_layer_freq return n_routed_experts, first_k_dense_replace, moe_layer_freq
match = re.search(r'layers\.(\d+)\.', prefix) match = re.search(r"layers\.(\d+)\.", prefix)
if match is None: if match is None:
return False return False
layer_idx = int(match.group(1)) layer_idx = int(match.group(1))
n_routed_experts, first_k_dense_replace, moe_layer_freq = get_moe_params() n_routed_experts, first_k_dense_replace, moe_layer_freq = get_moe_params()
return (n_routed_experts is not None and layer_idx >= first_k_dense_replace return n_routed_experts is not None and layer_idx >= first_k_dense_replace and layer_idx % moe_layer_freq == 0
and layer_idx % moe_layer_freq == 0)

View File

@@ -17,13 +17,15 @@
import math import math
import os import os
from typing import Optional, Tuple
import torch import torch
import torch_npu import torch_npu
from vllm.model_executor.layers.rotary_embedding import ( from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding, DeepseekScalingRotaryEmbedding,
YaRNScalingRotaryEmbedding) MRotaryEmbedding,
RotaryEmbedding,
YaRNScalingRotaryEmbedding,
)
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
@@ -31,8 +33,7 @@ if HAS_TRITON:
from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope
from vllm_ascend.platform import NPUPlatform from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op, from vllm_ascend.utils import AscendDeviceType, enable_custom_op, get_ascend_device_type, has_rope, is_vl_model
get_ascend_device_type, has_rope, is_vl_model)
# Currently, rope ops used on npu requires detached cos && sin as inputs. # Currently, rope ops used on npu requires detached cos && sin as inputs.
# However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable. # However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
@@ -54,17 +55,13 @@ _cos_slice: torch.Tensor = None
_sin_slice: torch.Tensor = None _sin_slice: torch.Tensor = None
def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, device):
device):
global _cos_mla global _cos_mla
global _sin_mla global _sin_mla
global _cos global _cos
global _sin global _sin
if _cos_mla is not None or \ if _cos_mla is not None or _sin_mla is not None or _cos is not None or _sin is not None:
_sin_mla is not None or \
_cos is not None or \
_sin is not None:
return return
model_config = vllm_config.model_config model_config = vllm_config.model_config
@@ -72,36 +69,15 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
if model_config.use_mla: if model_config.use_mla:
rope_dim = model_config.hf_text_config.qk_rope_head_dim rope_dim = model_config.hf_text_config.qk_rope_head_dim
_cos_mla = torch.ones(max_num_batched_tokens, _cos_mla = torch.ones(max_num_batched_tokens, 1, 1, rope_dim, dtype=dtype, device=device)
1, _sin_mla = torch.zeros(max_num_batched_tokens, 1, 1, rope_dim, dtype=dtype, device=device)
1,
rope_dim,
dtype=dtype,
device=device)
_sin_mla = torch.zeros(max_num_batched_tokens,
1,
1,
rope_dim,
dtype=dtype,
device=device)
elif not is_vl_model(vllm_config) and has_rope(vllm_config): elif not is_vl_model(vllm_config) and has_rope(vllm_config):
rope_dim = model_config.get_head_size() rope_dim = model_config.get_head_size()
# For models using partial rope like Qwen3-Next. # For models using partial rope like Qwen3-Next.
if hasattr(model_config.hf_text_config, "partial_rotary_factor"): if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
rope_dim = int(rope_dim * rope_dim = int(rope_dim * model_config.hf_text_config.partial_rotary_factor)
model_config.hf_text_config.partial_rotary_factor) _cos = torch.ones(1, max_num_batched_tokens, 1, rope_dim, dtype=dtype, device=device)
_cos = torch.ones(1, _sin = torch.zeros(1, max_num_batched_tokens, 1, rope_dim, dtype=dtype, device=device)
max_num_batched_tokens,
1,
rope_dim,
dtype=dtype,
device=device)
_sin = torch.zeros(1,
max_num_batched_tokens,
1,
rope_dim,
dtype=dtype,
device=device)
def get_cos_and_sin_mla(positions, use_cache=False): def get_cos_and_sin_mla(positions, use_cache=False):
@@ -139,8 +115,7 @@ def _record_cos_and_sin_cache_interleaved(cos_sin_cache):
if _cos_cache is not None or _sin_cache is not None: if _cos_cache is not None or _sin_cache is not None:
return return
hidden_dim = cos_sin_cache.shape[-1] // 2 hidden_dim = cos_sin_cache.shape[-1] // 2
cos_cache, sin_cache = cos_sin_cache.view(-1, 2, hidden_dim).repeat( cos_cache, sin_cache = cos_sin_cache.view(-1, 2, hidden_dim).repeat(1, 1, 2).chunk(2, dim=1)
1, 1, 2).chunk(2, dim=1)
_cos_cache = cos_cache.squeeze(1) _cos_cache = cos_cache.squeeze(1)
_sin_cache = sin_cache.squeeze(1) _sin_cache = sin_cache.squeeze(1)
@@ -151,16 +126,16 @@ def update_cos_sin(positions):
global _cos_slice global _cos_slice
global _sin_slice global _sin_slice
if _cos_sin_cache is None or \ if _cos_sin_cache is None or _cos is None or _sin is None:
_cos is None or \
_sin is None:
return return
num_tokens = positions.size(0) num_tokens = positions.size(0)
_cos[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view( _cos[:, :num_tokens] = (
num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[0] _cos_sin_cache.index_select(0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[0]
_sin[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view( )
num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[1] _sin[:, :num_tokens] = (
_cos_sin_cache.index_select(0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[1]
)
_cos_slice = _cos[:, :num_tokens] _cos_slice = _cos[:, :num_tokens]
_sin_slice = _sin[:, :num_tokens] _sin_slice = _sin[:, :num_tokens]
@@ -170,8 +145,7 @@ def get_cos_and_sin_slice():
def _custom_rotary_embedding_enabled(query, neox_style, head_size): def _custom_rotary_embedding_enabled(query, neox_style, head_size):
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op( return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op()
)
def _rope_forward_oot( def _rope_forward_oot(
@@ -180,8 +154,8 @@ def _rope_forward_oot(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
is_neox_style: bool, is_neox_style: bool,
offsets: Optional[torch.Tensor] = None offsets: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
query_shape, key_shape = query.shape, key.shape query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device: if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device) self.cos_sin_cache = self.cos_sin_cache.to(query.device)
@@ -189,8 +163,7 @@ def _rope_forward_oot(
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
cos, sin = get_cos_and_sin_slice() cos, sin = get_cos_and_sin_slice()
# adopt custom kernel path for rotary_embedding # adopt custom kernel path for rotary_embedding
if _custom_rotary_embedding_enabled( if _custom_rotary_embedding_enabled(query, is_neox_style, self.head_size):
query, is_neox_style, self.head_size):
query, key = torch.ops._C_ascend.rotary_embedding( query, key = torch.ops._C_ascend.rotary_embedding(
positions, positions,
query, query,
@@ -201,34 +174,31 @@ def _rope_forward_oot(
) )
return query.view(query_shape), key.view(key_shape) return query.view(query_shape), key.view(key_shape)
if offsets is not None: if offsets is not None:
raise NotImplementedError( raise NotImplementedError("Batched rotary embedding is currently not supported on NPU.")
"Batched rotary embedding is currently not supported on NPU.")
else: else:
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[ if (
-1] == 128 and cos is not None and sin is not None: is_neox_style
and self.head_size == 128
and self.cos_sin_cache.shape[-1] == 128
and cos is not None
and sin is not None
):
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation. # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
# This method requires head_size and rotary_dim equal 128 and neox_style is True # This method requires head_size and rotary_dim equal 128 and neox_style is True
query = query.contiguous().view(1, query.shape[0], -1, query = query.contiguous().view(1, query.shape[0], -1, self.head_size)
self.head_size)
key = key.contiguous().view(1, key.shape[0], -1, self.head_size) key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
# Although this function modifies in-place, please retain the function's return value. # Although this function modifies in-place, please retain the function's return value.
# Otherwise, the graph fusion operation may fail. # Otherwise, the graph fusion operation may fail.
query, key = torch_npu.npu_apply_rotary_pos_emb( query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
query, key, cos, sin)
elif self.rotary_dim < self.head_size: elif self.rotary_dim < self.head_size:
if HAS_TRITON: if HAS_TRITON:
cos = cos.view(-1, self.rotary_dim) cos = cos.view(-1, self.rotary_dim)
sin = sin.view(-1, self.rotary_dim) sin = sin.view(-1, self.rotary_dim)
q = query.contiguous().view(query.shape[0], -1, q = query.contiguous().view(query.shape[0], -1, self.head_size)
self.head_size)
k = key.contiguous().view(key.shape[0], -1, self.head_size) k = key.contiguous().view(key.shape[0], -1, self.head_size)
query, key = torch.ops.vllm.rope_forward_triton(q, query, key = torch.ops.vllm.rope_forward_triton(
k, q, k, cos, sin, rope_dim=self.rotary_dim, is_neox_style=True
cos, )
sin,
rope_dim=self.rotary_dim,
is_neox_style=True)
return query.view(query_shape), key.view(key_shape) return query.view(query_shape), key.view(key_shape)
else: else:
num_tokens = query.shape[0] num_tokens = query.shape[0]
@@ -271,7 +241,6 @@ def _rope_forward_oot(
class AscendRotaryEmbedding(RotaryEmbedding): class AscendRotaryEmbedding(RotaryEmbedding):
def __init__( def __init__(
self, self,
head_size: int, head_size: int,
@@ -281,8 +250,7 @@ class AscendRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
) -> None: ) -> None:
super().__init__(head_size, rotary_dim, max_position_embeddings, base, super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype)
is_neox_style, dtype)
_record_cos_sin_cache(self.cos_sin_cache) _record_cos_sin_cache(self.cos_sin_cache)
_record_cos_and_sin_cache_interleaved(self.cos_sin_cache) _record_cos_and_sin_cache_interleaved(self.cos_sin_cache)
@@ -291,18 +259,16 @@ class AscendRotaryEmbedding(RotaryEmbedding):
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: torch.Tensor | None = None,
is_neox_style_override: Optional[bool] = None, is_neox_style_override: bool | None = None,
): ):
is_neox_style = self.is_neox_style is_neox_style = self.is_neox_style
if is_neox_style_override is not None: if is_neox_style_override is not None:
is_neox_style = is_neox_style_override is_neox_style = is_neox_style_override
return _rope_forward_oot(self, positions, query, key, is_neox_style, return _rope_forward_oot(self, positions, query, key, is_neox_style, offsets)
offsets)
class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding): class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
def __init__( def __init__(
self, self,
head_size: int, head_size: int,
@@ -322,10 +288,11 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
"extrapolation_factor": extrapolation_factor, "extrapolation_factor": extrapolation_factor,
"attn_factor": attn_factor, "attn_factor": attn_factor,
"beta_fast": beta_fast, "beta_fast": beta_fast,
"beta_slow": beta_slow "beta_slow": beta_slow,
} }
super().__init__(head_size, rotary_dim, max_position_embeddings, base, super().__init__(
is_neox_style, scaling_factor, dtype, **extra_kwargs) head_size, rotary_dim, max_position_embeddings, base, is_neox_style, scaling_factor, dtype, **extra_kwargs
)
_record_cos_sin_cache(self.cos_sin_cache) _record_cos_sin_cache(self.cos_sin_cache)
def forward_oot( def forward_oot(
@@ -333,16 +300,13 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: torch.Tensor | None = None,
is_neox_style_override: Optional[bool] = None, is_neox_style_override: bool | None = None,
): ):
return AscendRotaryEmbedding.forward_oot(self, positions, query, key, return AscendRotaryEmbedding.forward_oot(self, positions, query, key, offsets, is_neox_style_override)
offsets,
is_neox_style_override)
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
def __init__( def __init__(
self, self,
head_size: int, head_size: int,
@@ -370,18 +334,17 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
self.beta_slow = beta_slow self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation. # Get n-d magnitude scaling corrected for interpolation.
self.mscale = float( self.mscale = float(
self._yarn_get_mscale(self.scaling_factor, float(mscale)) / self._yarn_get_mscale(self.scaling_factor, float(mscale))
self._yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * / self._yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
attn_factor) * attn_factor
super(DeepseekScalingRotaryEmbedding, )
self).__init__(head_size, rotary_dim, max_position_embeddings, super(DeepseekScalingRotaryEmbedding, self).__init__(
base, is_neox_style, dtype) head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
# NOTE: For ascend friendly computing, reorder sin and cos cache # NOTE: For ascend friendly computing, reorder sin and cos cache
self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor) self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
self._set_cos_sin_cache(self.max_seq_len, self._set_cos_sin_cache(self.max_seq_len, device=NPUPlatform.device_type, dtype=dtype)
device=NPUPlatform.device_type,
dtype=dtype)
def _yarn_get_mscale(self, scale: float = 1, mscale: float = 1) -> float: def _yarn_get_mscale(self, scale: float = 1, mscale: float = 1) -> float:
if scale <= 1: if scale <= 1:
@@ -398,48 +361,27 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
# Note: The if conditional branch is not used here # Note: The if conditional branch is not used here
# to solve MTP compilation error. # to solve MTP compilation error.
max_value += (min_value == max_value).float() * 0.001 max_value += (min_value == max_value).float() * 0.001
linear_func = (torch.arange(dim, dtype=torch.float32) - linear_func = (torch.arange(dim, dtype=torch.float32) - min_value) / (max_value - min_value)
min_value) / (max_value - min_value)
ramp_func = torch.clamp(linear_func, 0, 1) ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func return ramp_func
# Inverse dim formula to find dim based on number of rotations # Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(self, def _yarn_find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048):
num_rotations,
dim,
base=10000,
max_position_embeddings=2048):
# Note: use torch instead of math to solve MTP compilation error. # Note: use torch instead of math to solve MTP compilation error.
return (dim * torch.log( return (dim * torch.log(torch.tensor(max_position_embeddings) / (num_rotations * 2 * torch.pi))) / (
torch.tensor(max_position_embeddings) / 2 * torch.log(torch.tensor(base))
(num_rotations * 2 * torch.pi))) / (2 * )
torch.log(torch.tensor(base)))
# Find dim range bounds based on rotations # Find dim range bounds based on rotations
def _yarn_find_correction_range(self, def _yarn_find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
low_rot,
high_rot,
dim,
base=10000,
max_position_embeddings=2048):
# Note: use torch instead of math to solve MTP compilation error. # Note: use torch instead of math to solve MTP compilation error.
low = torch.floor( low = torch.floor(self._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
self._yarn_find_correction_dim(low_rot, dim, base, high = torch.ceil(self._yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
max_position_embeddings))
high = torch.ceil(
self._yarn_find_correction_dim(high_rot, dim, base,
max_position_embeddings))
# Note: use torch instead of max/min to solve MTP compilation error. # Note: use torch instead of max/min to solve MTP compilation error.
return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1) return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def _apply_rotary_pos_emb(self, def _apply_rotary_pos_emb(self, q, k, cos, sin, position_ids, unsqueeze_dim=1):
q,
k,
cos,
sin,
position_ids,
unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors. """Applies Rotary Position Embedding to the query and key tensors.
Args: Args:
q (`torch.Tensor`): The query tensor. q (`torch.Tensor`): The query tensor.
@@ -451,11 +393,11 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
used to pass offsetted position ids when working with a KV-cache. used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1): unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example,
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim].
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly,
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns: Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
""" """
@@ -488,10 +430,10 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
def _set_cos_sin_cache(self, max_seq_len, device, dtype): def _set_cos_sin_cache(self, max_seq_len, device, dtype):
dim = self.rotary_dim dim = self.rotary_dim
freq_extra = 1.0 / (self.base**( freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) freq_inter = 1.0 / (
freq_inter = 1.0 / (self.scaling_factor * self.base**( self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) )
low, high = self._yarn_find_correction_range( low, high = self._yarn_find_correction_range(
self.beta_fast, self.beta_fast,
@@ -500,10 +442,8 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
self.base, self.base,
self.max_position_embeddings, self.max_position_embeddings,
) )
inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask( inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32)
low, high, dim // 2).to(device=device, dtype=torch.float32) inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
inv_freq = freq_inter * (1 -
inv_freq_mask) + freq_extra * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(max_seq_len, device=device, dtype=torch.float32) t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
@@ -513,20 +453,16 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
cos_cached = cos_cached.to(dtype) cos_cached = cos_cached.to(dtype)
sin_cached = sin_cached.to(dtype) sin_cached = sin_cached.to(dtype)
cache = torch.cat( cache = torch.cat([freqs.cos() * self.mscale, freqs.sin() * self.mscale], dim=-1).to(dtype)
[freqs.cos() * self.mscale,
freqs.sin() * self.mscale], dim=-1).to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False) self.register_buffer("cos_sin_cache", cache, persistent=False)
self.register_buffer("cos_cached", cos_cached, persistent=False) self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False)
_record_cos_sin_cache(cache) _record_cos_sin_cache(cache)
_record_cos_and_sin_cache(cos_cached, sin_cached) _record_cos_and_sin_cache(cos_cached, sin_cached)
def forward(self, def forward(
positions: torch.Tensor, self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, offsets: torch.Tensor | None = None
query: torch.Tensor, ):
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None):
if len(key.shape) == 2: if len(key.shape) == 2:
key = key[:, None, :] key = key[:, None, :]
# Note: we implement the non neox_style method with shuffle the last dim and neox style # Note: we implement the non neox_style method with shuffle the last dim and neox style
@@ -535,26 +471,24 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
is_neox_style = True is_neox_style = True
if self.is_neox_style is False: if self.is_neox_style is False:
b, h_q, d = query.shape b, h_q, d = query.shape
query = query.view(b, h_q, d // 2, query = query.view(b, h_q, d // 2, 2).transpose(3, 2).reshape(b, h_q, d)
2).transpose(3, 2).reshape(b, h_q, d)
b, h_k, d = key.shape b, h_k, d = key.shape
key = key.view(b, h_k, d // 2, 2).transpose(3, key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
2).reshape(b, h_k, d) q_pe, k_pe = _rope_forward_oot(self, positions, query, key, is_neox_style, offsets)
q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
is_neox_style, offsets)
return q_pe, k_pe return q_pe, k_pe
class AscendMRotaryEmbedding(MRotaryEmbedding): class AscendMRotaryEmbedding(MRotaryEmbedding):
# Empirical safety threshold for large Triton grids on Ascend NPU # Empirical safety threshold for large Triton grids on Ascend NPU
_ASCEND_TRITON_GRID_LIMIT = 65535 _ASCEND_TRITON_GRID_LIMIT = 65535
def forward_triton(self, def forward_triton(
self,
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor | None = None, key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None): offsets: torch.Tensor | None = None,
):
assert positions.ndim == 2 assert positions.ndim == 2
assert key is not None assert key is not None
@@ -573,8 +507,7 @@ class AscendMRotaryEmbedding(MRotaryEmbedding):
# When the grid becomes large, enable TRITON_ALL_BLOCKS_PARALLEL # When the grid becomes large, enable TRITON_ALL_BLOCKS_PARALLEL
# to avoid scheduler/runtime failures. # to avoid scheduler/runtime failures.
if (query_shape[0] > self._ASCEND_TRITON_GRID_LIMIT and if query_shape[0] > self._ASCEND_TRITON_GRID_LIMIT and os.environ.get("TRITON_ALL_BLOCKS_PARALLEL") != "1":
os.environ.get("TRITON_ALL_BLOCKS_PARALLEL") != "1"):
os.environ["TRITON_ALL_BLOCKS_PARALLEL"] = "1" os.environ["TRITON_ALL_BLOCKS_PARALLEL"] = "1"
q, k = triton_mrope( q, k = triton_mrope(
@@ -600,35 +533,37 @@ class AscendMRotaryEmbedding(MRotaryEmbedding):
# todo: need cann update in 8.5.0 # todo: need cann update in 8.5.0
return self.forward_triton(positions, query, key) return self.forward_triton(positions, query, key)
if self.mrope_section != [16, 24, 24] or \ if self.mrope_section != [16, 24, 24] or get_ascend_device_type() == AscendDeviceType.A5:
get_ascend_device_type() == AscendDeviceType.A5:
return super().forward_oot(positions, query, key) return super().forward_oot(positions, query, key)
import torch_npu import torch_npu
mrope_section = [0, 0, 0
] if positions.ndim == 1 else self.mrope_section mrope_section = [0, 0, 0] if positions.ndim == 1 else self.mrope_section
if self.cos_sin_cache.device != query.device: # type: ignore if self.cos_sin_cache.device != query.device: # type: ignore
self.cos_sin_cache = self.cos_sin_cache.to( # type: ignore self.cos_sin_cache = self.cos_sin_cache.to( # type: ignore
query.device) # type: ignore query.device
) # type: ignore
if self.cos_sin_cache.dtype != query.dtype: # type: ignore if self.cos_sin_cache.dtype != query.dtype: # type: ignore
self.cos_sin_cache = self.cos_sin_cache.to( # type: ignore self.cos_sin_cache = self.cos_sin_cache.to( # type: ignore
query.dtype) # type: ignore query.dtype
) # type: ignore
query, key = torch_npu.npu_mrope(positions.contiguous(), query, key = torch_npu.npu_mrope(
positions.contiguous(),
query.contiguous(), query.contiguous(),
key.contiguous(), key.contiguous(),
self.cos_sin_cache.contiguous(), self.cos_sin_cache.contiguous(),
self.head_size, self.head_size,
mrope_section=mrope_section, mrope_section=mrope_section,
rotary_mode='half') rotary_mode="half",
)
return query, key return query, key
class AscendApplyRotaryEmb(ApplyRotaryEmb): class AscendApplyRotaryEmb(ApplyRotaryEmb):
def __init__( def __init__(
self, self,
enforce_enable: bool = False, enforce_enable: bool = False,
@@ -647,8 +582,7 @@ class AscendApplyRotaryEmb(ApplyRotaryEmb):
cos: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor, sin: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
x, cos, sin, origin_shape, origin_dtype = self._pre_process( x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin)
x, cos, sin)
head_dim = x.shape[-1] head_dim = x.shape[-1]
# cos, sin: [seq_len, head_dim // 2] # cos, sin: [seq_len, head_dim // 2]

View File

@@ -24,15 +24,12 @@ TOKEN_TYPE_MULTIPLIER = 1 << 30
TOKEN_MASK = TOKEN_TYPE_MULTIPLIER - 1 TOKEN_MASK = TOKEN_TYPE_MULTIPLIER - 1
def _encode_token_type_ids(input_ids: torch.Tensor, def _encode_token_type_ids(input_ids: torch.Tensor, token_type_ids: torch.Tensor) -> None:
token_type_ids: torch.Tensor) -> None:
# input_ids can be padded to the right # input_ids can be padded to the right
input_ids[:token_type_ids.shape[0]].bitwise_or_(token_type_ids * input_ids[: token_type_ids.shape[0]].bitwise_or_(token_type_ids * TOKEN_TYPE_MULTIPLIER)
TOKEN_TYPE_MULTIPLIER)
def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
token_type_ids = input_ids // TOKEN_TYPE_MULTIPLIER token_type_ids = input_ids // TOKEN_TYPE_MULTIPLIER
input_ids.bitwise_and_(TOKEN_MASK) input_ids.bitwise_and_(TOKEN_MASK)

View File

@@ -0,0 +1,54 @@
from itertools import islice
import torch
from vllm.distributed import get_pp_group
from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model, _get_llama_4_scaling
from vllm.sequence import IntermediateTensors
def forward(
self,
input_ids,
positions,
intermediate_tensors,
inputs_embeds,
):
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
# Compute llama 4 scaling once per forward pass if enabled
# Note(wxy): This is a hack fix to avoid graph mode error for torch 2.8
# We'll find a better way to remove this patch.
try:
llama_4_scaling_config = self.config.llama_4_scaling
except AttributeError:
llama_4_scaling_config = None
llama_4_scaling: torch.Tensor | None
if llama_4_scaling_config is not None:
llama_4_scaling = _get_llama_4_scaling(
original_max_position_embeddings=llama_4_scaling_config["original_max_position_embeddings"],
scaling_beta=llama_4_scaling_config["beta"],
positions=positions,
)
else:
llama_4_scaling = None
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual, llama_4_scaling)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states, "residual": residual})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
DeepseekV2Model.forward = forward

View File

@@ -15,29 +15,25 @@
# limitations under the License. # limitations under the License.
# #
from typing import List, Optional, Union
import torch import torch
import vllm import vllm
from torch.distributed import Backend from torch.distributed import Backend
from vllm.distributed.parallel_state import (GroupCoordinator, from vllm.distributed.parallel_state import GroupCoordinator, _get_unique_name, _register_group
_get_unique_name, _register_group)
from vllm_ascend.distributed.device_communicators.npu_communicator import \ from vllm_ascend.distributed.device_communicators.npu_communicator import NPUCommunicator
NPUCommunicator
from vllm_ascend.utils import create_hccl_pg_options from vllm_ascend.utils import create_hccl_pg_options
class GroupCoordinatorPatch(GroupCoordinator): class GroupCoordinatorPatch(GroupCoordinator):
def __init__( def __init__(
self, self,
group_ranks: list[list[int]], group_ranks: list[list[int]],
local_rank: int, local_rank: int,
torch_distributed_backend: Union[str, Backend], torch_distributed_backend: str | Backend,
use_device_communicator: bool, # whether to use device communicator use_device_communicator: bool, # whether to use device communicator
use_message_queue_broadcaster: bool = False, use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None, group_name: str | None = None,
): ):
group_name = group_name or "anonymous" group_name = group_name or "anonymous"
self.unique_name = _get_unique_name(group_name) self.unique_name = _get_unique_name(group_name)
@@ -52,9 +48,8 @@ class GroupCoordinatorPatch(GroupCoordinator):
for ranks in group_ranks: for ranks in group_ranks:
device_group = torch.distributed.new_group( device_group = torch.distributed.new_group(
ranks, ranks, backend=torch_distributed_backend, pg_options=hccl_pg_options
backend=torch_distributed_backend, )
pg_options=hccl_pg_options)
# a group with `gloo` backend, to allow direct coordination between # a group with `gloo` backend, to allow direct coordination between
# processes through the CPU. # processes through the CPU.
@@ -83,22 +78,23 @@ class GroupCoordinatorPatch(GroupCoordinator):
unique_name=self.unique_name, unique_name=self.unique_name,
) )
from vllm.distributed.device_communicators.shm_broadcast import \ from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
MessageQueue
self.mq_broadcaster: Optional[MessageQueue] = None self.mq_broadcaster: MessageQueue | None = None
if use_message_queue_broadcaster and self.world_size > 1: if use_message_queue_broadcaster and self.world_size > 1:
self.mq_broadcaster = MessageQueue.create_from_process_group( self.mq_broadcaster = MessageQueue.create_from_process_group(self.cpu_group, 1 << 22, 6)
self.cpu_group, 1 << 22, 6)
self.use_custom_op_call = False self.use_custom_op_call = False
self.use_cpu_custom_send_recv = False self.use_cpu_custom_send_recv = False
def all_to_all(self, def all_to_all(
self,
input_: torch.Tensor, input_: torch.Tensor,
scatter_dim: int = 0, scatter_dim: int = 0,
gather_dim: int = -1, gather_dim: int = -1,
scatter_sizes: Optional[List[int]] = None, scatter_sizes: list[int] | None = None,
gather_sizes: Optional[List[int]] = None) -> torch.Tensor: gather_sizes: list[int] | None = None,
) -> torch.Tensor:
if self.world_size == 1: if self.world_size == 1:
return input_ return input_
assert -input_.dim() <= scatter_dim < input_.dim(), ( assert -input_.dim() <= scatter_dim < input_.dim(), (
@@ -108,9 +104,7 @@ class GroupCoordinatorPatch(GroupCoordinator):
f"Invalid gather dim ({gather_dim}) for input tensor with shape {input_.size()}" f"Invalid gather dim ({gather_dim}) for input tensor with shape {input_.size()}"
) )
assert self.device_communicator is not None, "device_communicator should be initialized when world_size > 1" assert self.device_communicator is not None, "device_communicator should be initialized when world_size > 1"
return self.device_communicator.all_to_all(input_, scatter_dim, return self.device_communicator.all_to_all(input_, scatter_dim, gather_dim, scatter_sizes, gather_sizes)
gather_dim, scatter_sizes,
gather_sizes)
def all_reduce(self, input_): def all_reduce(self, input_):
if self.world_size == 1: if self.world_size == 1:

View File

@@ -19,9 +19,11 @@ from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
_original_call = HunYuanVLProcessor.__call__ _original_call = HunYuanVLProcessor.__call__
def _patched_call(self, images=None, text=None, videos=None, **kwargs): def _patched_call(self, images=None, text=None, videos=None, **kwargs):
"""Remove add_special_tokens requirement.""" """Remove add_special_tokens requirement."""
kwargs.pop("add_special_tokens", None) kwargs.pop("add_special_tokens", None)
return _original_call(self, images=images, text=text, videos=videos, **kwargs) return _original_call(self, images=images, text=text, videos=videos, **kwargs)
HunYuanVLProcessor.__call__ = _patched_call HunYuanVLProcessor.__call__ = _patched_call

View File

@@ -13,7 +13,6 @@ def _argsort(tensor, *args, **kwargs):
class _TorchWrapper: class _TorchWrapper:
def __init__(self): def __init__(self):
self._raw_torch = torch self._raw_torch = torch
@@ -32,5 +31,6 @@ def patch_torch_npu_argsort():
global _is_patched global _is_patched
if not _is_patched: if not _is_patched:
import vllm.v1.attention.backends.gdn_attn as gdn_attn import vllm.v1.attention.backends.gdn_attn as gdn_attn
gdn_attn.torch = _TorchWrapper() gdn_attn.torch = _TorchWrapper()
_is_patched = True _is_patched = True

View File

@@ -18,8 +18,7 @@
import torch import torch
import vllm import vllm
from vllm.model_executor.models.utils import (_embedding_count_expression, from vllm.model_executor.models.utils import _embedding_count_expression, _flatten_embeddings
_flatten_embeddings)
from vllm.multimodal import NestedTensors from vllm.multimodal import NestedTensors

View File

@@ -20,26 +20,20 @@ from einops import rearrange
from torch import nn from torch import nn
from vllm.config import CUDAGraphMode from vllm.config import CUDAGraphMode
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fla.ops import ( from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
chunk_gated_delta_rule, fused_recurrent_gated_delta_rule)
from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
from vllm.model_executor.models.qwen3_next import (Qwen3NextGatedDeltaNet,
fused_gdn_gating)
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.v1.attention.backend import AttentionMetadata # type: ignore from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import \ from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
fused_qkvzba_split_reshape_cat from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
from vllm_ascend.ops.triton.fla.sigmoid_gating import \
fused_sigmoid_gating_delta_rule_update
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase): class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -61,10 +55,8 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
forward_context = get_forward_context() forward_context = get_forward_context()
is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE
# triton grid should be less than 66536 # triton grid should be less than 66536
divide_grid = projected_states_qkvz.shape[0] * triton.cdiv( divide_grid = projected_states_qkvz.shape[0] * triton.cdiv(self.num_k_heads, self.tp_size)
self.num_k_heads, self.tp_size) if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph and divide_grid < 65536:
if self.num_v_heads // self.num_k_heads in [1, 2, 4] and \
is_cuda_graph and divide_grid < 65536:
mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat( mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
projected_states_qkvz, projected_states_qkvz,
projected_states_ba, projected_states_ba,
@@ -74,10 +66,8 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
self.head_v_dim, self.head_v_dim,
) )
else: else:
query, key, value, z, b, a = self.fix_query_key_value_ordering( query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
projected_states_qkvz, projected_states_ba) query, key, value = map(lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value))
query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
(query, key, value))
mixed_qkv = torch.cat((query, key, value), dim=-1) mixed_qkv = torch.cat((query, key, value), dim=-1)
# ============================================================ # ============================================================
@@ -150,16 +140,14 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
a = a[:num_actual_tokens] a = a[:num_actual_tokens]
# 1. Convolution sequence transformation # 1. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
self.conv1d.weight.size(2))
if spec_sequence_masks is not None: if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
mixed_qkv_spec = mixed_qkv mixed_qkv_spec = mixed_qkv
mixed_qkv_non_spec = None mixed_qkv_non_spec = None
else: else:
mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
mixed_qkv_non_spec = mixed_qkv.index_select( mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
0, non_spec_token_indx)
else: else:
mixed_qkv_spec = None mixed_qkv_spec = None
mixed_qkv_non_spec = mixed_qkv mixed_qkv_non_spec = mixed_qkv
@@ -172,8 +160,7 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
conv_weights, conv_weights,
self.conv1d.bias, self.conv1d.bias,
self.activation, self.activation,
conv_state_indices=spec_state_indices_tensor[:, 0] conv_state_indices=spec_state_indices_tensor[:, 0][: attn_metadata.num_spec_decodes],
[:attn_metadata.num_spec_decodes],
num_accepted_tokens=num_accepted_tokens, num_accepted_tokens=num_accepted_tokens,
query_start_loc=spec_query_start_loc, query_start_loc=spec_query_start_loc,
max_query_len=spec_state_indices_tensor.size(-1), max_query_len=spec_state_indices_tensor.size(-1),
@@ -204,21 +191,16 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
conv_weights, conv_weights,
self.conv1d.bias, self.conv1d.bias,
self.activation, self.activation,
conv_state_indices= conv_state_indices=non_spec_state_indices_tensor[: attn_metadata.num_actual_tokens],
non_spec_state_indices_tensor[:attn_metadata.
num_actual_tokens],
validate_data=True, validate_data=True,
) )
else: else:
mixed_qkv_non_spec = None mixed_qkv_non_spec = None
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv( query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
mixed_qkv_spec) query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(mixed_qkv_non_spec)
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
mixed_qkv_non_spec)
if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None: if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
g, beta = fused_gdn_gating_patch(self.A_log, a, b, g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
self.dt_bias)
if spec_sequence_masks is not None: if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
g_spec = g g_spec = g
@@ -248,8 +230,7 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
beta=beta_spec, beta=beta_spec,
initial_state=ssm_state, initial_state=ssm_state,
inplace_final_state=True, inplace_final_state=True,
cu_seqlens=spec_query_start_loc[:attn_metadata. cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
num_spec_decodes + 1],
ssm_state_indices=spec_state_indices_tensor, ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens, num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=True, use_qk_l2norm_in_kernel=True,
@@ -259,8 +240,7 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
# 2.2: Process the remaining part # 2.2: Process the remaining part
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
initial_state = ssm_state[ initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0 initial_state[~has_initial_state, ...] = 0
( (
core_attn_out_non_spec, core_attn_out_non_spec,
@@ -278,12 +258,9 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
use_qk_l2norm_in_kernel=True, use_qk_l2norm_in_kernel=True,
) )
# Init cache # Init cache
ssm_state[ ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype)
non_spec_state_indices_tensor] = last_recurrent_state.to(
ssm_state.dtype)
elif attn_metadata.num_decodes > 0: elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec, last_recurrent_state = ( core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
fused_recurrent_gated_delta_rule(
q=query_non_spec, q=query_non_spec,
k=key_non_spec, k=key_non_spec,
v=value_non_spec, v=value_non_spec,
@@ -291,11 +268,10 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
beta=beta_non_spec, beta=beta_non_spec,
initial_state=ssm_state, initial_state=ssm_state,
inplace_final_state=True, inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc[:attn_metadata. cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
num_decodes + 1],
ssm_state_indices=non_spec_state_indices_tensor, ssm_state_indices=non_spec_state_indices_tensor,
use_qk_l2norm_in_kernel=True, use_qk_l2norm_in_kernel=True,
)) )
else: else:
core_attn_out_non_spec, last_recurrent_state = None, None core_attn_out_non_spec, last_recurrent_state = None, None
@@ -324,14 +300,12 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
device=core_attn_out_non_spec.device, device=core_attn_out_non_spec.device,
) )
merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
merged_out.index_copy_(1, non_spec_token_indx, merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
core_attn_out_non_spec)
core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
elif spec_sequence_masks is not None: elif spec_sequence_masks is not None:
core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
else: else:
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze( core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
0)
Qwen3NextGatedDeltaNet.forward = AscendQwen3Next_GatedDeltaNet.forward Qwen3NextGatedDeltaNet.forward = AscendQwen3Next_GatedDeltaNet.forward

View File

@@ -1,6 +1,7 @@
import torch import torch
import vllm.v1.worker.utils as utils import vllm.v1.worker.utils as utils
from vllm.v1.worker.utils import defaultdict, extract_layer_index from vllm.v1.worker.utils import defaultdict, extract_layer_index
from vllm_ascend.utils import vllm_version_is from vllm_ascend.utils import vllm_version_is
if vllm_version_is("v0.15.0"): if vllm_version_is("v0.15.0"):
@@ -8,6 +9,7 @@ if vllm_version_is("v0.15.0"):
else: else:
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
# Without this patch, it will raise an exception when initialize kv_cache. # Without this patch, it will raise an exception when initialize kv_cache.
# TODO To remove the patch, we need check why the original bind_kv_cache raises an NotImplementedError. # TODO To remove the patch, we need check why the original bind_kv_cache raises an NotImplementedError.
def bind_kv_cache( def bind_kv_cache(
@@ -38,8 +40,7 @@ def bind_kv_cache(
# Convert kv_caches dict to a list of tensors in the order of layer_index. # Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list) index2name = defaultdict(list)
for layer_name in kv_caches: for layer_name in kv_caches:
index2name[extract_layer_index(layer_name, index2name[extract_layer_index(layer_name, num_attn_module)].append(layer_name)
num_attn_module)].append(layer_name)
for layer_index in sorted(index2name.keys()): for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index] layer_names = index2name[layer_index]

View File

@@ -1,8 +1,6 @@
import vllm.v1.sample.rejection_sampler as rs import vllm.v1.sample.rejection_sampler as rs
from vllm_ascend.sample.rejection_sampler import (apply_sampling_constraints, from vllm_ascend.sample.rejection_sampler import apply_sampling_constraints, expand_batch_to_tokens, rejection_sample
expand_batch_to_tokens,
rejection_sample)
# TODO: delete this patch after apply_sampling_constraints and rejection_sample # TODO: delete this patch after apply_sampling_constraints and rejection_sample
# are extracted to as class func of RejectionSampler # are extracted to as class func of RejectionSampler

View File

@@ -17,12 +17,10 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.model_executor.layers.rotary_embedding.base import \ from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbeddingBase
RotaryEmbeddingBase
class AscendRotaryEmbeddingBase(nn.Module): class AscendRotaryEmbeddingBase(nn.Module):
def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]: def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]:
cos_sin = self.cos_sin_cache[:seqlen] cos_sin = self.cos_sin_cache[:seqlen]
cos, sin = cos_sin.chunk(2, dim=-1) cos, sin = cos_sin.chunk(2, dim=-1)

View File

@@ -3,16 +3,15 @@ import vllm.v1.worker.gpu.sample.gumbel
from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule
from vllm_ascend.ops.triton.fla.layernorm_guard import LayerNormFn from vllm_ascend.ops.triton.fla.layernorm_guard import LayerNormFn
from vllm_ascend.ops.triton.fla.sigmoid_gating import \ from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_recurrent_gated_delta_rule_fwd_kernel
fused_recurrent_gated_delta_rule_fwd_kernel from vllm_ascend.ops.triton.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update_npu
from vllm_ascend.ops.triton.mamba.causal_conv1d import ( from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample as ascend_gumbel_sample
causal_conv1d_fn, causal_conv1d_update_npu)
from vllm_ascend.worker.v2.sample.gumbel import \
gumbel_sample as ascend_gumbel_sample
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn
vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = (
fused_recurrent_gated_delta_rule_fwd_kernel
)
vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn
vllm.model_executor.layers.fla.ops.chunk_gated_delta_rule = chunk_gated_delta_rule vllm.model_executor.layers.fla.ops.chunk_gated_delta_rule = chunk_gated_delta_rule
vllm.v1.worker.gpu.sample.gumbel.gumbel_sample = ascend_gumbel_sample vllm.v1.worker.gpu.sample.gumbel.gumbel_sample = ascend_gumbel_sample

View File

@@ -36,11 +36,14 @@ def unquantized_gemm_fake(
return torch.empty(output_shape, dtype=x.dtype, device=x.device) return torch.empty(output_shape, dtype=x.dtype, device=x.device)
direct_register_custom_op(op_name="unquantized_gemm", direct_register_custom_op(
op_name="unquantized_gemm",
op_func=unquantized_gemm, op_func=unquantized_gemm,
fake_impl=unquantized_gemm_fake, fake_impl=unquantized_gemm_fake,
mutates_args=[], mutates_args=[],
dispatch_key="PrivateUse1") dispatch_key="PrivateUse1",
)
def default_unquantized_gemm( def default_unquantized_gemm(
layer: torch.nn.Module, layer: torch.nn.Module,

View File

@@ -19,11 +19,10 @@
import numpy as np import numpy as np
import torch import torch
import vllm import vllm
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu.spec_decode.eagle import prepare_eagle_decode, prepare_eagle_inputs
from vllm.v1.worker.gpu.spec_decode.eagle import (prepare_eagle_decode,
prepare_eagle_inputs)
from vllm_ascend.worker.v2.attn_utils import build_attn_metadata from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
@@ -54,8 +53,7 @@ def propose(
# seq_lens) of the target model. # seq_lens) of the target model.
if aux_hidden_states: if aux_hidden_states:
assert self.method == "eagle3" assert self.method == "eagle3"
hidden_states = self.model.combine_hidden_states( hidden_states = self.model.combine_hidden_states(torch.cat(aux_hidden_states, dim=-1))
torch.cat(aux_hidden_states, dim=-1))
else: else:
hidden_states = last_hidden_states hidden_states = last_hidden_states
num_tokens = input_batch.num_tokens_after_padding num_tokens = input_batch.num_tokens_after_padding
@@ -95,19 +93,12 @@ def propose(
seeds = self.seeds[:num_reqs].clone() seeds = self.seeds[:num_reqs].clone()
pos = self.input_buffers.positions[:num_reqs].clone() pos = self.input_buffers.positions[:num_reqs].clone()
# Gather the values and copy them to the pre-allocated buffers. # Gather the values and copy them to the pre-allocated buffers.
torch.gather(sampling_metadata.temperature, torch.gather(sampling_metadata.temperature, 0, cu_num_logits, out=temperature)
0,
cu_num_logits,
out=temperature)
torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds) torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds)
torch.gather(input_batch.positions, 0, last_token_indices, out=pos) torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling. # used for draft and target sampling.
draft_tokens = gumbel_sample(logits, draft_tokens = gumbel_sample(logits, temperature, seeds, pos + 1, apply_temperature=True)
temperature,
seeds,
pos + 1,
apply_temperature=True)
if self.num_speculative_steps == 1: if self.num_speculative_steps == 1:
# Early exit. # Early exit.
return draft_tokens.view(-1, 1) return draft_tokens.view(-1, 1)
@@ -128,8 +119,7 @@ def propose(
) )
query_start_loc = self.input_buffers.query_start_loc query_start_loc = self.input_buffers.query_start_loc
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1] query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings( slot_mappings = self.block_tables.compute_slot_mappings(query_start_loc_gpu, pos)
query_start_loc_gpu, pos)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs) cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
if cudagraph_size is not None: if cudagraph_size is not None:
@@ -158,8 +148,7 @@ def propose(
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
) )
self.generate_draft(num_reqs, attn_metadata, self.generate_draft(num_reqs, attn_metadata, num_tokens_across_dp=None) # FIXME
num_tokens_across_dp=None) # FIXME
return self.draft_tokens[:num_reqs] return self.draft_tokens[:num_reqs]

View File

@@ -23,15 +23,13 @@ from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingStates from vllm.v1.pool.metadata import PoolingStates
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, from vllm.v1.sample.logits_processor import BatchUpdateBuilder, LogitsProcessors
LogitsProcessors)
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm_ascend.worker.block_table import MultiGroupBlockTable from vllm_ascend.worker.block_table import MultiGroupBlockTable
class NPUInputBatch(InputBatch): class NPUInputBatch(InputBatch):
def __init__( def __init__(
self, self,
max_num_reqs: int, max_num_reqs: int,
@@ -72,10 +70,9 @@ class NPUInputBatch(InputBatch):
pin_memory=False, pin_memory=False,
) )
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.is_token_ids_tensor = torch.zeros((max_num_reqs, max_model_len), self.is_token_ids_tensor = torch.zeros(
device="cpu", (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False
dtype=bool, )
pin_memory=False)
self.is_token_ids = self.is_token_ids_tensor.numpy() self.is_token_ids = self.is_token_ids_tensor.numpy()
# Store prompt embeddings per request to avoid OOM from large upfront # Store prompt embeddings per request to avoid OOM from large upfront
# allocation if max_model_len is big. # allocation if max_model_len is big.
@@ -90,8 +87,7 @@ class NPUInputBatch(InputBatch):
dtype=torch.int32, dtype=torch.int32,
pin_memory=pin_memory, pin_memory=pin_memory,
) )
self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy( self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()
)
# Block table. # Block table.
self.block_table = MultiGroupBlockTable( self.block_table = MultiGroupBlockTable(
@@ -107,34 +103,21 @@ class NPUInputBatch(InputBatch):
) )
# Sampling-related. # Sampling-related.
self.temperature = torch.empty((max_num_reqs, ), self.temperature = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
dtype=torch.float32, self.temperature_cpu_tensor = torch.empty(
device=device) (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), )
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.temperature_cpu = self.temperature_cpu_tensor.numpy() self.temperature_cpu = self.temperature_cpu_tensor.numpy()
self.greedy_reqs: set[str] = set() self.greedy_reqs: set[str] = set()
self.random_reqs: set[str] = set() self.random_reqs: set[str] = set()
self.top_p = torch.empty((max_num_reqs, ), self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
dtype=torch.float32, self.top_p_cpu_tensor = torch.empty((max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory)
device=device)
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.top_p_cpu = self.top_p_cpu_tensor.numpy() self.top_p_cpu = self.top_p_cpu_tensor.numpy()
self.top_p_reqs: set[str] = set() self.top_p_reqs: set[str] = set()
self.top_k = torch.empty((max_num_reqs, ), self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device)
dtype=torch.int32, self.top_k_cpu_tensor = torch.empty((max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory)
device=device)
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: set[str] = set() self.top_k_reqs: set[str] = set()
@@ -142,54 +125,37 @@ class NPUInputBatch(InputBatch):
self.spec_decode_unsupported_reqs: set[str] = set() self.spec_decode_unsupported_reqs: set[str] = set()
# Frequency penalty related data structures # Frequency penalty related data structures
self.frequency_penalties = torch.empty((max_num_reqs, ), self.frequency_penalties = torch.empty((max_num_reqs,), dtype=torch.float, device=device)
dtype=torch.float,
device=device)
self.frequency_penalties_cpu_tensor = torch.empty( self.frequency_penalties_cpu_tensor = torch.empty(
(max_num_reqs, ), (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy(
) )
self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_reqs: set[str] = set() self.frequency_penalties_reqs: set[str] = set()
# Presence penalty related data structures # Presence penalty related data structures
self.presence_penalties = torch.empty((max_num_reqs, ), self.presence_penalties = torch.empty((max_num_reqs,), dtype=torch.float, device=device)
dtype=torch.float, self.presence_penalties_cpu_tensor = torch.empty(
device=device) (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
) )
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
self.presence_penalties_reqs: set[str] = set() self.presence_penalties_reqs: set[str] = set()
# Repetition penalty related data structures # Repetition penalty related data structures
self.repetition_penalties = torch.empty((max_num_reqs, ), self.repetition_penalties = torch.empty((max_num_reqs,), dtype=torch.float, device=device)
dtype=torch.float,
device=device)
self.repetition_penalties_cpu_tensor = torch.empty( self.repetition_penalties_cpu_tensor = torch.empty(
(max_num_reqs, ), (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy(
) )
self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: set[str] = set() self.repetition_penalties_reqs: set[str] = set()
# Speculative decoding # Speculative decoding
self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ), self.num_accepted_tokens_cpu_tensor = torch.ones(
dtype=torch.int64, (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
device="cpu",
pin_memory=pin_memory)
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy(
) )
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
# lora related # lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ), self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {} self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: dict[int, LoRARequest] = {} self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
@@ -218,8 +184,7 @@ class NPUInputBatch(InputBatch):
# req_index -> bad_words_token_ids # req_index -> bad_words_token_ids
self.bad_words_token_ids: dict[int, list[list[int]]] = {} self.bad_words_token_ids: dict[int, list[list[int]]] = {}
self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)
dtype=bool)
self.req_output_token_ids: list[list[int] | None] = [] self.req_output_token_ids: list[list[int] | None] = []
@@ -229,8 +194,7 @@ class NPUInputBatch(InputBatch):
self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids
# Store last speculative tokens for sampler. # Store last speculative tokens for sampler.
self.spec_token_ids: list[list[int]] = [[] self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]
for _ in range(max_num_reqs)]
# This is updated each time the batch constituents change. # This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata() self.sampling_metadata = self._make_sampling_metadata()

View File

@@ -22,19 +22,16 @@ from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.cudagraph_utils import \ from vllm.v1.worker.gpu.cudagraph_utils import prepare_inputs_to_capture as prepare_inputs_to_capture_gpu
prepare_inputs_to_capture as prepare_inputs_to_capture_gpu
from vllm.v1.worker.gpu.input_batch import InputBuffers from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper from vllm_ascend.worker.v2.utils import torch_cuda_wrapper
class AclGraphManager(CudaGraphManager): class AclGraphManager(CudaGraphManager):
"""ACL Graph Manager for Ascend NPUs.""" """ACL Graph Manager for Ascend NPUs."""
@@ -51,7 +48,7 @@ class AclGraphManager(CudaGraphManager):
attn_metadata_builders: list[AttentionMetadataBuilder], attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
) -> None: ) -> None:
with (torch_cuda_wrapper(), prepare_capture_inputs_wrapper()): with torch_cuda_wrapper(), prepare_capture_inputs_wrapper():
super().capture_graph( super().capture_graph(
num_tokens, num_tokens,
model, model,

View File

@@ -18,19 +18,17 @@
# #
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Tuple from typing import Any
import numpy as np import numpy as np
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig
from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, AscendPrefillContextParallelMetadata
AscendPrefillContextParallelMetadata)
_ATTENTION_MASK_BUILDER = None _ATTENTION_MASK_BUILDER = None
@@ -59,8 +57,7 @@ def build_attn_metadata(
attn_state: Any | None = None, attn_state: Any | None = None,
graph_pad_size: int = -1, graph_pad_size: int = -1,
num_input_tokens: int = 0, num_input_tokens: int = 0,
prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata | None = None,
| None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Build attention metadata for Ascend NPUs.""" """Build attention metadata for Ascend NPUs."""
# TODO(Ronald1995): optimize AscendCommonAttentionMetadata. # TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
@@ -92,7 +89,8 @@ def build_attn_metadata(
graph_pad_size=graph_pad_size, graph_pad_size=graph_pad_size,
num_input_tokens=num_input_tokens, num_input_tokens=num_input_tokens,
prefill_context_parallel_metadata=prefill_context_parallel_metadata, prefill_context_parallel_metadata=prefill_context_parallel_metadata,
max_seq_len=max_seq_len) max_seq_len=max_seq_len,
)
attn_metadata_builder = attn_metadata_builders[i] attn_metadata_builder = attn_metadata_builders[i]
metadata = attn_metadata_builder.build( metadata = attn_metadata_builder.build(
@@ -126,16 +124,14 @@ def build_attn_state(
# but only one token is not hit in cache. # but only one token is not hit in cache.
elif np.all(num_scheduled_tokens == 1): elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly attn_state = AscendAttentionState.DecodeOnly
if (vllm_config.speculative_config if vllm_config.speculative_config and vllm_config.speculative_config.method == "mtp":
and vllm_config.speculative_config.method == 'mtp'):
# SpecDecoding now supports seq_len=1 and seq_len=2 # SpecDecoding now supports seq_len=1 and seq_len=2
# In Prefilling Decoding Disaggregation scenario, SpecDecoding # In Prefilling Decoding Disaggregation scenario, SpecDecoding
# need to supports seq_len=1 # need to supports seq_len=1
attn_state = AscendAttentionState.SpecDecoding attn_state = AscendAttentionState.SpecDecoding
# Speculative decoding. # Speculative decoding.
elif np.all(num_valid_tokens == 1): elif np.all(num_valid_tokens == 1):
if (vllm_config.speculative_config if vllm_config.speculative_config and vllm_config.speculative_config.method == "mtp":
and vllm_config.speculative_config.method == 'mtp'):
attn_state = AscendAttentionState.SpecDecoding attn_state = AscendAttentionState.SpecDecoding
else: else:
attn_state = AscendAttentionState.ChunkedPrefill attn_state = AscendAttentionState.ChunkedPrefill

View File

@@ -22,15 +22,16 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu.input_batch import (InputBatch, from vllm.v1.worker.gpu.input_batch import (
InputBatch,
combine_sampled_and_draft_tokens, combine_sampled_and_draft_tokens,
prepare_pos_seq_lens, prepare_pos_seq_lens,
prepare_prefill_inputs) prepare_prefill_inputs,
)
from vllm.v1.worker.gpu.model_runner import GPUModelRunner from vllm.v1.worker.gpu.model_runner import GPUModelRunner
from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata, from vllm_ascend.worker.v2.attn_utils import build_attn_metadata, build_attn_state
build_attn_state)
from vllm_ascend.worker.v2.input_batch import AscendInputBuffers from vllm_ascend.worker.v2.input_batch import AscendInputBuffers
from vllm_ascend.worker.v2.sample.sampler import AscendSampler from vllm_ascend.worker.v2.sample.sampler import AscendSampler
from vllm_ascend.worker.v2.spec_decode import init_speculator from vllm_ascend.worker.v2.spec_decode import init_speculator
@@ -45,7 +46,7 @@ class NPUModelRunner(GPUModelRunner):
"""Model runner for Ascend NPUs.""" """Model runner for Ascend NPUs."""
def __init__(self, vllm_config: VllmConfig, device: torch.device): def __init__(self, vllm_config: VllmConfig, device: torch.device):
with (torch_cuda_wrapper(), uva_wrapper()): with torch_cuda_wrapper(), uva_wrapper():
super().__init__(vllm_config, device) super().__init__(vllm_config, device)
# because we will override these attribute, delete these attribute to # because we will override these attribute, delete these attribute to
@@ -94,7 +95,8 @@ class NPUModelRunner(GPUModelRunner):
# we need to adjust triton operators in sampler, # we need to adjust triton operators in sampler,
# so reinitialize sampler here. # so reinitialize sampler here.
self.sampler: AscendSampler = AscendSampler( self.sampler: AscendSampler = AscendSampler(
logprobs_mode=self.model_config.logprobs_mode, ) logprobs_mode=self.model_config.logprobs_mode,
)
# we need to copy num_computed_tokens back to cpu to help # we need to copy num_computed_tokens back to cpu to help
# update actual seq_lens_cpu. gpu attention backend doesn't need these # update actual seq_lens_cpu. gpu attention backend doesn't need these
@@ -131,16 +133,12 @@ class NPUModelRunner(GPUModelRunner):
self._update_seq_lens_cpu(scheduler_output, req_ids) self._update_seq_lens_cpu(scheduler_output, req_ids)
num_scheduled_tokens = np.array( num_scheduled_tokens = np.array([scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32)
[scheduler_output.num_scheduled_tokens[i] for i in req_ids],
dtype=np.int32)
num_valid_tokens = num_scheduled_tokens num_valid_tokens = num_scheduled_tokens
if scheduler_output.scheduled_spec_decode_tokens: if scheduler_output.scheduled_spec_decode_tokens:
num_valid_tokens = np.array( num_valid_tokens = np.array(
[ [
num_tokens - len( num_tokens - len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
scheduler_output.scheduled_spec_decode_tokens.get(
i, []))
for num_tokens, i in zip(num_scheduled_tokens, req_ids) for num_tokens, i in zip(num_scheduled_tokens, req_ids)
], ],
dtype=np.int32, dtype=np.int32,
@@ -153,9 +151,7 @@ class NPUModelRunner(GPUModelRunner):
num_valid_tokens, num_valid_tokens,
) )
idx_mapping_list = [ idx_mapping_list = [self.req_states.req_id_to_index[req_id] for req_id in req_ids]
self.req_states.req_id_to_index[req_id] for req_id in req_ids
]
idx_mapping = self.input_buffers.idx_mapping idx_mapping = self.input_buffers.idx_mapping
idx_mapping.np[:num_reqs] = idx_mapping_list idx_mapping.np[:num_reqs] = idx_mapping_list
idx_mapping_np = idx_mapping.np[:num_reqs] idx_mapping_np = idx_mapping.np[:num_reqs]
@@ -167,16 +163,11 @@ class NPUModelRunner(GPUModelRunner):
# No draft token scheduled (common case). # No draft token scheduled (common case).
total_num_draft_tokens = 0 total_num_draft_tokens = 0
total_num_logits = num_reqs total_num_logits = num_reqs
cu_num_logits = torch.arange(num_reqs + 1, cu_num_logits = torch.arange(num_reqs + 1, device=self.device, dtype=torch.int32)
device=self.device,
dtype=torch.int32)
else: else:
draft_tokens = scheduler_output.scheduled_spec_decode_tokens draft_tokens = scheduler_output.scheduled_spec_decode_tokens
num_draft_tokens = np.array( num_draft_tokens = np.array(
[ [len(draft_tokens[req_id]) if req_id in draft_tokens else 0 for req_id in req_ids],
len(draft_tokens[req_id]) if req_id in draft_tokens else 0
for req_id in req_ids
],
dtype=np.int32, dtype=np.int32,
) )
total_num_draft_tokens = int(num_draft_tokens.sum()) total_num_draft_tokens = int(num_draft_tokens.sum())
@@ -186,8 +177,7 @@ class NPUModelRunner(GPUModelRunner):
num_draft_tokens + 1, num_draft_tokens + 1,
out=self.input_buffers.cu_num_logits.np[1 : num_reqs + 1], out=self.input_buffers.cu_num_logits.np[1 : num_reqs + 1],
) )
cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu( cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu(num_reqs + 1)
num_reqs + 1)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping_npu) block_tables = self.block_tables.gather_block_tables(idx_mapping_npu)
@@ -201,14 +191,9 @@ class NPUModelRunner(GPUModelRunner):
# Some attention backends like FA3 require query_start_loc to be non-decreasing. # Some attention backends like FA3 require query_start_loc to be non-decreasing.
self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
self.input_buffers.query_start_loc.copy_to_gpu() self.input_buffers.query_start_loc.copy_to_gpu()
query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
num_reqs + query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
1] query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[:
num_reqs +
1]
query_start_loc_np = self.input_buffers.query_start_loc.np[:num_reqs +
1]
# Get prefill tokens. # Get prefill tokens.
prepare_prefill_inputs( prepare_prefill_inputs(
@@ -249,7 +234,8 @@ class NPUModelRunner(GPUModelRunner):
# Compute slot mappings: [num_kv_cache_groups, num_tokens] # Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings( slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, self.input_buffers.positions[:num_tokens]) query_start_loc_gpu, self.input_buffers.positions[:num_tokens]
)
# Layer name -> attention metadata. # Layer name -> attention metadata.
# TODO(Ronald1995): try to add a new method `build_attn_metadata` in # TODO(Ronald1995): try to add a new method `build_attn_metadata` in
@@ -263,8 +249,7 @@ class NPUModelRunner(GPUModelRunner):
query_start_loc_cpu=query_start_loc_cpu, query_start_loc_cpu=query_start_loc_cpu,
seq_lens=self.input_buffers.seq_lens, seq_lens=self.input_buffers.seq_lens,
seq_lens_np=self.input_buffers.seq_lens_np, seq_lens_np=self.input_buffers.seq_lens_np,
num_computed_tokens_cpu=self.req_states. num_computed_tokens_cpu=self.req_states.num_computed_tokens_cpu[idx_mapping_cpu],
num_computed_tokens_cpu[idx_mapping_cpu],
block_tables=block_tables, block_tables=block_tables,
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
@@ -335,16 +320,13 @@ class NPUModelRunner(GPUModelRunner):
req_index = self.req_states.req_id_to_index[req_id] req_index = self.req_states.req_id_to_index[req_id]
# num_computed_tokens_cpu has reverted by num_rejected_tokens already. # num_computed_tokens_cpu has reverted by num_rejected_tokens already.
# in super postprocess method. # in super postprocess method.
self.req_states.num_computed_tokens_cpu[ self.req_states.num_computed_tokens_cpu[req_index] = self.num_computed_tokens_cpu[req_index]
req_index] = self.num_computed_tokens_cpu[req_index]
# update seq_lens_cpu # update seq_lens_cpu
for i, req_id in enumerate(req_ids): for i, req_id in enumerate(req_ids):
req_index = self.req_states.req_id_to_index[req_id] req_index = self.req_states.req_id_to_index[req_id]
num_computed_tokens = self.req_states.num_computed_tokens_cpu[ num_computed_tokens = self.req_states.num_computed_tokens_cpu[req_index]
req_index] self.input_buffers.seq_lens_cpu[i] = num_computed_tokens + num_scheduled_tokens[req_id]
self.input_buffers.seq_lens_cpu[
i] = num_computed_tokens + num_scheduled_tokens[req_id]
def eplb_warmup(self): def eplb_warmup(self):
# TODO(Ronald1995): just define the method in case calling error in # TODO(Ronald1995): just define the method in case calling error in

View File

@@ -76,8 +76,7 @@ def _gumbel_sample_kernel(
idx = tl.argmax(logits, axis=0) idx = tl.argmax(logits, axis=0)
token_id = block_idx * BLOCK_SIZE + idx token_id = block_idx * BLOCK_SIZE + idx
value = tl.max(logits, axis=0) value = tl.max(logits, axis=0)
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
token_id)
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value) tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)

View File

@@ -68,8 +68,7 @@ def _penalties_and_temperature_kernel(
if use_penalty: if use_penalty:
req_state_idx = tl.load(idx_mapping_ptr + batch_idx) req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
output_bin_counts = tl.load( output_bin_counts = tl.load(
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
block,
mask=mask, mask=mask,
) )
# to use vector core, if use > 0 will use scalar to slow down performance # to use vector core, if use > 0 will use scalar to slow down performance
@@ -77,11 +76,9 @@ def _penalties_and_temperature_kernel(
# Apply repetition penalties. # Apply repetition penalties.
if use_rep_penalty: if use_rep_penalty:
packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange( packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
0, BLOCK_SIZE // 32)
packed_mask = tl.load( packed_mask = tl.load(
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block,
packed_block,
mask=packed_block < tl.cdiv(vocab_size, 32), mask=packed_block < tl.cdiv(vocab_size, 32),
) )
# the compiler itself does not optimize right-shift operations, so we change the same func # the compiler itself does not optimize right-shift operations, so we change the same func
@@ -97,8 +94,7 @@ def _penalties_and_temperature_kernel(
prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE) prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op. # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty. # If logits are positive, divide by penalty, otherwise multiply by penalty.
logits *= tl.where(logits > 0, 1.0 / scale, scale) logits *= tl.where(logits > 0, 1.0 / scale, scale)

View File

@@ -16,18 +16,16 @@
# #
import torch import torch
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.sample.min_p import apply_min_p from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.sampler import Sampler from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample
from vllm_ascend.worker.v2.sample.penalties import \ from vllm_ascend.worker.v2.sample.penalties import apply_penalties_and_temperature
apply_penalties_and_temperature
class AscendSampler(Sampler): class AscendSampler(Sampler):
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
@@ -45,8 +43,7 @@ class AscendSampler(Sampler):
if sampling_metadata.min_p is not None: if sampling_metadata.min_p is not None:
apply_min_p(logits, sampling_metadata.min_p) apply_min_p(logits, sampling_metadata.min_p)
# Apply top_k and/or top_p. This might return a new tensor. # Apply top_k and/or top_p. This might return a new tensor.
logits = apply_top_k_top_p(logits, sampling_metadata.top_k, logits = apply_top_k_top_p(logits, sampling_metadata.top_k, sampling_metadata.top_p)
sampling_metadata.top_p)
sampled = gumbel_sample( sampled = gumbel_sample(
logits, logits,

View File

@@ -30,9 +30,7 @@ def init_speculator(
speculative_config = vllm_config.speculative_config speculative_config = vllm_config.speculative_config
assert speculative_config is not None assert speculative_config is not None
if speculative_config.use_eagle(): if speculative_config.use_eagle():
from vllm_ascend.worker.v2.spec_decode.eagle import \ from vllm_ascend.worker.v2.spec_decode.eagle import AscendEagleSpeculator
AscendEagleSpeculator
return AscendEagleSpeculator(vllm_config, device) return AscendEagleSpeculator(vllm_config, device)
raise NotImplementedError( raise NotImplementedError(f"{speculative_config.method} is not supported yet.")
f"{speculative_config.method} is not supported yet.")

View File

@@ -30,7 +30,6 @@ from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
class AscendEagleSpeculator(EagleSpeculator): class AscendEagleSpeculator(EagleSpeculator):
def __init__(self, vllm_config: VllmConfig, device: torch.device): def __init__(self, vllm_config: VllmConfig, device: torch.device):
"""Override GPU EagleSpeculator.__init__ for Ascend NPUs. """Override GPU EagleSpeculator.__init__ for Ascend NPUs.
attnention metadata building in Ascend backend needs more information, attnention metadata building in Ascend backend needs more information,

View File

@@ -63,8 +63,8 @@ class AscendRequestState(RequestState):
# NOTE(Ronald1995): Ascend NPUs do not support UVA yet, # NOTE(Ronald1995): Ascend NPUs do not support UVA yet,
# so we use CpuGpuBuffer to allocate prefill_token_ids buffer. # so we use CpuGpuBuffer to allocate prefill_token_ids buffer.
self.prefill_token_ids: CpuGpuBuffer = self._make_buffer( # type: ignore self.prefill_token_ids: CpuGpuBuffer = self._make_buffer( # type: ignore
(self.max_num_reqs, self.max_model_len), (self.max_num_reqs, self.max_model_len), dtype=torch.int32
dtype=torch.int32) )
def add_request( def add_request(
self, self,
@@ -75,7 +75,6 @@ class AscendRequestState(RequestState):
sampling_params, sampling_params,
lora_request, lora_request,
): ):
super().add_request( super().add_request(
req_id, req_id,
prompt_len, prompt_len,
@@ -93,7 +92,6 @@ def uva_wrapper():
"""Context manager to disable UVA for Ascend NPUs.""" """Context manager to disable UVA for Ascend NPUs."""
class UvaBufferWrapper: class UvaBufferWrapper:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass

View File

@@ -20,7 +20,6 @@
import copy import copy
import gc import gc
from types import NoneType from types import NoneType
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -29,12 +28,9 @@ import vllm.envs as envs_vllm
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
from torch_npu.profiler import dynamic_profile as dp from torch_npu.profiler import dynamic_profile as dp
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import ensure_model_parallel_initialized, init_distributed_environment
init_distributed_environment)
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized, get_kv_transfer_group, has_kv_transfer_group
get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger from vllm.logger import logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@@ -44,8 +40,7 @@ from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
DraftTokenIds, ModelRunnerOutput)
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
@@ -56,28 +51,28 @@ from vllm_ascend.cpu_binding import bind_cpus
from vllm_ascend.device_allocator.camem import CaMemAllocator from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
from vllm_ascend.utils import (AscendDeviceType, check_ascend_device_type, from vllm_ascend.utils import (
enable_sp, get_ascend_device_type, AscendDeviceType,
register_ascend_customop) check_ascend_device_type,
enable_sp,
get_ascend_device_type,
register_ascend_customop,
)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402 torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402
from torch._dynamo.variables import TorchInGraphFunctionVariable # noqa: E402 from torch._dynamo.variables import TorchInGraphFunctionVariable # noqa: E402
from vllm.utils.torch_utils import set_random_seed # noqa: E402
from vllm.utils.torch_utils import set_random_seed
torch_non_c_binding_in_graph_functions_npu = dict.fromkeys( torch_non_c_binding_in_graph_functions_npu = dict.fromkeys(
["torch.npu.current_stream"], ["torch.npu.current_stream"],
TorchInGraphFunctionVariable, TorchInGraphFunctionVariable,
) # noqa: E402 ) # noqa: E402
torch_non_c_binding_in_graph_functions_npu[ torch_non_c_binding_in_graph_functions_npu["torch.npu.stream"] = TorchInGraphFunctionVariable # noqa: E402
"torch.npu.stream"] = TorchInGraphFunctionVariable # noqa: E402 torch._dynamo.trace_rules.torch_name_rule_map.append(torch_non_c_binding_in_graph_functions_npu) # noqa: E402
torch._dynamo.trace_rules.torch_name_rule_map.append(
torch_non_c_binding_in_graph_functions_npu) # noqa: E402
class NPUWorker(WorkerBase): class NPUWorker(WorkerBase):
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
@@ -86,7 +81,8 @@ class NPUWorker(WorkerBase):
distributed_init_method: str, distributed_init_method: str,
is_driver_worker: bool = False, is_driver_worker: bool = False,
# Additional parameters for compatibility with vllm # Additional parameters for compatibility with vllm
**kwargs): **kwargs,
):
"""Initialize the worker for Ascend.""" """Initialize the worker for Ascend."""
if not envs_ascend.COMPILE_CUSTOM_KERNELS: if not envs_ascend.COMPILE_CUSTOM_KERNELS:
logger.warning( logger.warning(
@@ -96,14 +92,17 @@ class NPUWorker(WorkerBase):
# register patch for vllm # register patch for vllm
from vllm_ascend.utils import adapt_patch from vllm_ascend.utils import adapt_patch
adapt_patch() adapt_patch()
# Import _inductor for graph mode execution with triton # Import _inductor for graph mode execution with triton
# This lazy import avoids torch_npu re-initialization in patch # This lazy import avoids torch_npu re-initialization in patch
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
if HAS_TRITON: if HAS_TRITON:
import torch_npu._inductor # noqa: F401 import torch_npu._inductor # noqa: F401
# Register ops when worker init. # Register ops when worker init.
from vllm_ascend import ops from vllm_ascend import ops
ops.register_dummy_fusion_op() ops.register_dummy_fusion_op()
if get_ascend_device_type() != AscendDeviceType.A5: if get_ascend_device_type() != AscendDeviceType.A5:
_register_atb_extensions() _register_atb_extensions()
@@ -112,17 +111,18 @@ class NPUWorker(WorkerBase):
init_ascend_config(vllm_config) init_ascend_config(vllm_config)
check_ascend_device_type() check_ascend_device_type()
super().__init__(vllm_config=vllm_config, super().__init__(
vllm_config=vllm_config,
local_rank=local_rank, local_rank=local_rank,
rank=rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker) is_driver_worker=is_driver_worker,
)
if self.cache_config.cache_dtype == "auto": if self.cache_config.cache_dtype == "auto":
self.cache_dtype = self.model_config.dtype self.cache_dtype = self.model_config.dtype
else: else:
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype]
self.cache_config.cache_dtype]
self.profiler = self._init_profiler() self.profiler = self._init_profiler()
if vllm_config.model_config and vllm_config.model_config.enable_sleep_mode: if vllm_config.model_config and vllm_config.model_config.enable_sleep_mode:
@@ -130,8 +130,8 @@ class NPUWorker(WorkerBase):
self._sleep_saved_buffers: dict[str, torch.Tensor] = {} self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
# FixMe: this is a patch to fix the issue cause by https://github.com/vllm-project/vllm/commit/de94289a98d7ec52a5ef02719e01a1db8b505170 # FixMe: this is a patch to fix the issue cause by https://github.com/vllm-project/vllm/commit/de94289a98d7ec52a5ef02719e01a1db8b505170
from vllm.model_executor.layers.linear import \ from vllm.model_executor.layers.linear import WEIGHT_LOADER_V2_SUPPORTED
WEIGHT_LOADER_V2_SUPPORTED
if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED: if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED:
WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod") WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
@@ -151,33 +151,33 @@ class NPUWorker(WorkerBase):
# Either SIGTERM or SIGINT will terminate the worker # Either SIGTERM or SIGINT will terminate the worker
import signal import signal
signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
def uninstall_static_kernel(self): def uninstall_static_kernel(self):
import os
import fcntl import fcntl
import os
import subprocess import subprocess
ascend_home_path = os.environ["ASCEND_HOME_PATH"] ascend_home_path = os.environ["ASCEND_HOME_PATH"]
static_kernel_dir_path = os.path.join(ascend_home_path, 'opp/static_kernel') static_kernel_dir_path = os.path.join(ascend_home_path, "opp/static_kernel")
uninstall_script_path = os.path.join(static_kernel_dir_path, 'ai_core/uninstall.sh') uninstall_script_path = os.path.join(static_kernel_dir_path, "ai_core/uninstall.sh")
lock_file_path = os.path.join(static_kernel_dir_path, 'uninstall.lock') lock_file_path = os.path.join(static_kernel_dir_path, "uninstall.lock")
if not os.path.exists(uninstall_script_path): if not os.path.exists(uninstall_script_path):
return return
with open(lock_file_path, 'w') as lock_fd: with open(lock_file_path, "w") as lock_fd:
try: try:
fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB) fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
subprocess.Popen( subprocess.Popen(
['bash', uninstall_script_path], ["bash", uninstall_script_path],
stdin=subprocess.DEVNULL, stdin=subprocess.DEVNULL,
stdout=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
start_new_session=True start_new_session=True,
) )
except (BlockingIOError, OSError) as e: except (BlockingIOError, OSError):
return return
finally: finally:
try: try:
@@ -187,16 +187,12 @@ class NPUWorker(WorkerBase):
except Exception: except Exception:
return return
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
free_bytes_before_sleep = torch.npu.mem_get_info()[0] free_bytes_before_sleep = torch.npu.mem_get_info()[0]
# Save the buffers before level 2 sleep # Save the buffers before level 2 sleep
if level == 2: if level == 2:
model = self.model_runner.model model = self.model_runner.model
self._sleep_saved_buffers = { self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()}
name: buffer.cpu().clone()
for name, buffer in model.named_buffers()
}
allocator = CaMemAllocator.get_instance() allocator = CaMemAllocator.get_instance()
allocator.sleep(offload_tags=("weights",) if level == 1 else tuple()) allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())
free_bytes_after_sleep, total = torch.npu.mem_get_info() free_bytes_after_sleep, total = torch.npu.mem_get_info()
@@ -204,15 +200,17 @@ class NPUWorker(WorkerBase):
used_bytes = total - free_bytes_after_sleep used_bytes = total - free_bytes_after_sleep
assert freed_bytes >= 0, "Memory usage increased after sleeping." assert freed_bytes >= 0, "Memory usage increased after sleeping."
logger.info( logger.info(
"Sleep mode freed %.2f GiB memory, " "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, freed_bytes / GiB_bytes,
used_bytes / GiB_bytes) used_bytes / GiB_bytes,
)
def wake_up(self, tags: Optional[list[str]] = None) -> None: def wake_up(self, tags: list[str] | None = None) -> None:
if envs_ascend.VLLM_ASCEND_ENABLE_NZ: if envs_ascend.VLLM_ASCEND_ENABLE_NZ:
raise ValueError( raise ValueError(
"FRACTAL_NZ mode is enabled. This may cause model parameter precision issues " "FRACTAL_NZ mode is enabled. This may cause model parameter precision issues "
"in the RL scenarios. Please set VLLM_ASCEND_ENABLE_NZ=0.") "in the RL scenarios. Please set VLLM_ASCEND_ENABLE_NZ=0."
)
allocator = CaMemAllocator.get_instance() allocator = CaMemAllocator.get_instance()
allocator.wake_up(tags=tags) allocator.wake_up(tags=tags)
@@ -220,22 +218,21 @@ class NPUWorker(WorkerBase):
model = self.model_runner.model model = self.model_runner.model
if tags is None or "weights" in tags: if tags is None or "weights" in tags:
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if 'w2_weight' in name and param.shape[2] == hidden_size: if "w2_weight" in name and param.shape[2] == hidden_size:
parts = name.split('.') parts = name.split(".")
param_name = parts[-1] param_name = parts[-1]
parent_module = model.get_submodule(".".join(parts[:-1])) parent_module = model.get_submodule(".".join(parts[:-1]))
w2_data = param.transpose(1, 2) w2_data = param.transpose(1, 2)
w2_data = torch.nn.Parameter(w2_data, requires_grad=False) w2_data = torch.nn.Parameter(w2_data, requires_grad=False)
setattr(parent_module, param_name, w2_data) setattr(parent_module, param_name, w2_data)
elif 'w13_weight' in name and param.shape[1] == hidden_size: elif "w13_weight" in name and param.shape[1] == hidden_size:
parts = name.split('.') parts = name.split(".")
param_name = parts[-1] param_name = parts[-1]
parent_module = model.get_submodule(".".join(parts[:-1])) parent_module = model.get_submodule(".".join(parts[:-1]))
w13_data = param.transpose(1, 2) w13_data = param.transpose(1, 2)
w13_data = torch.nn.Parameter(w13_data, w13_data = torch.nn.Parameter(w13_data, requires_grad=False)
requires_grad=False)
setattr(parent_module, param_name, w13_data) setattr(parent_module, param_name, w13_data)
# Restore the buffers after level 2 sleep # Restore the buffers after level 2 sleep
@@ -245,8 +242,7 @@ class NPUWorker(WorkerBase):
buffer.data.copy_(self._sleep_saved_buffers[name].data) buffer.data.copy_(self._sleep_saved_buffers[name].data)
self._sleep_saved_buffers = {} self._sleep_saved_buffers = {}
def initialize_cache(self, num_gpu_blocks: int, def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks
@@ -255,18 +251,19 @@ class NPUWorker(WorkerBase):
torch.npu.set_device(device) torch.npu.set_device(device)
torch.npu.empty_cache() torch.npu.empty_cache()
if (self.parallel_config.data_parallel_size > 1 if (
self.parallel_config.data_parallel_size > 1
and self.parallel_config.data_parallel_size_local > 0 and self.parallel_config.data_parallel_size_local > 0
and self.parallel_config.distributed_executor_backend and self.parallel_config.distributed_executor_backend not in ["ray", "external_launcher"]
not in ["ray", "external_launcher"] and and self.vllm_config.parallel_config.data_parallel_backend != "ray"
self.vllm_config.parallel_config.data_parallel_backend != "ray" and self.vllm_config.parallel_config.nnodes_within_dp == 1
and self.vllm_config.parallel_config.nnodes_within_dp == 1): ):
visible_device_count = (torch.npu.device_count() visible_device_count = torch.npu.device_count() if torch.npu.is_available() else 0
if torch.npu.is_available() else 0)
assert self.parallel_config.local_world_size <= visible_device_count, ( assert self.parallel_config.local_world_size <= visible_device_count, (
f"local_world_size ({self.parallel_config.local_world_size}) must " f"local_world_size ({self.parallel_config.local_world_size}) must "
f"be less than or equal to the number of visible devices " f"be less than or equal to the number of visible devices "
f"({visible_device_count}).") f"({visible_device_count})."
)
self.init_npu_memory = torch.npu.mem_get_info()[0] self.init_npu_memory = torch.npu.mem_get_info()[0]
# Initialize the distributed environment. # Initialize the distributed environment.
@@ -281,9 +278,7 @@ class NPUWorker(WorkerBase):
try: try:
bind_cpus(self.local_rank) bind_cpus(self.local_rank)
except Exception as e: except Exception as e:
logger.warning( logger.warning(f"Bind cpus failed in rank{self.local_rank}: {e} Skip binding cpu.")
f"Bind cpus failed in rank{self.local_rank}: {e} Skip binding cpu."
)
return device return device
def init_device(self): def init_device(self):
@@ -296,11 +291,9 @@ class NPUWorker(WorkerBase):
init_workspace_manager(self.device, num_ubatches) init_workspace_manager(self.device, num_ubatches)
# Init ModelRunner here, so that we have access to self.device. # Init ModelRunner here, so that we have access to self.device.
if self.use_v2_model_runner: if self.use_v2_model_runner:
logger.warning( logger.warning("npu model runner v2 is in developing, some features doesn't work for now.")
"npu model runner v2 is in developing, some features doesn't work for now." from vllm_ascend.worker.v2.model_runner import NPUModelRunner as NPUModelRunnerV2
)
from vllm_ascend.worker.v2.model_runner import \
NPUModelRunner as NPUModelRunnerV2
self.model_runner = NPUModelRunnerV2(self.vllm_config, self.device) self.model_runner = NPUModelRunnerV2(self.vllm_config, self.device)
else: else:
self.model_runner = NPUModelRunner(self.vllm_config, self.device) self.model_runner = NPUModelRunner(self.vllm_config, self.device)
@@ -327,27 +320,22 @@ class NPUWorker(WorkerBase):
"Error in memory profiling. " "Error in memory profiling. "
f"Initial free memory {self.init_npu_memory}, current free memory" f"Initial free memory {self.init_npu_memory}, current free memory"
f" {free_npu_memory}. This happens when the NPU memory was " f" {free_npu_memory}. This happens when the NPU memory was "
"not properly cleaned up before initializing the vLLM instance.") "not properly cleaned up before initializing the vLLM instance."
)
# Get the peak memory allocation recorded by torch # Get the peak memory allocation recorded by torch
peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"] peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
# TODO: don`t need impl this func after empty_cache in # TODO: don`t need impl this func after empty_cache in
# Worker.determine_num_available_blocks() unified` # Worker.determine_num_available_blocks() unified`
torch.npu.empty_cache() torch.npu.empty_cache()
torch_allocated_bytes = torch_npu.npu.memory_stats( torch_allocated_bytes = torch_npu.npu.memory_stats()["allocated_bytes.all.current"]
)["allocated_bytes.all.current"] total_allocated_bytes = torch_npu.npu.mem_get_info()[1] - torch_npu.npu.mem_get_info()[0]
total_allocated_bytes = torch_npu.npu.mem_get_info(
)[1] - torch_npu.npu.mem_get_info()[0]
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
if non_torch_allocations > 0: if non_torch_allocations > 0:
peak_memory += non_torch_allocations peak_memory += non_torch_allocations
available_kv_cache_memory = int( available_kv_cache_memory = int(total_npu_memory * self.cache_config.gpu_memory_utilization - peak_memory)
total_npu_memory * self.cache_config.gpu_memory_utilization -
peak_memory)
available_kv_cache_memory = int(max(available_kv_cache_memory, 0)) available_kv_cache_memory = int(max(available_kv_cache_memory, 0))
logger.info( logger.info(f"Available memory: {available_kv_cache_memory}, total memory: {total_npu_memory}")
f"Available memory: {available_kv_cache_memory}, total memory: {total_npu_memory}"
)
return available_kv_cache_memory return available_kv_cache_memory
def execute_model( def execute_model(
@@ -361,32 +349,30 @@ class NPUWorker(WorkerBase):
intermediate_tensors = None intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0 forward_pass = scheduler_output.total_num_scheduled_tokens > 0
if forward_pass and not get_pp_group().is_first_rank: if forward_pass and not get_pp_group().is_first_rank:
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise it will conflict with the all-gather operation in flashcomm1. # If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
# it will conflict with the all-gather operation in flashcomm1.
if enable_sp(): if enable_sp():
all_gather_group = None all_gather_group = None
else: else:
all_gather_group = get_tp_group() all_gather_group = get_tp_group()
intermediate_tensors = IntermediateTensors( intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict( get_pp_group().recv_tensor_dict(all_gather_group=all_gather_group)
all_gather_group=all_gather_group)) )
output = self.model_runner.execute_model(scheduler_output, output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
intermediate_tensors) if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput, NoneType)):
if isinstance(output,
(ModelRunnerOutput, AsyncModelRunnerOutput, NoneType)):
return output return output
assert isinstance(output, IntermediateTensors) assert isinstance(output, IntermediateTensors)
parallel_config = self.vllm_config.parallel_config parallel_config = self.vllm_config.parallel_config
assert parallel_config.distributed_executor_backend != ( assert parallel_config.distributed_executor_backend != ("external_launcher") and not get_pp_group().is_last_rank
"external_launcher") and not get_pp_group().is_last_rank # If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise it will conflict with the all-gather operation in flashcomm1. # it will conflict with the all-gather operation in flashcomm1.
if enable_sp(): if enable_sp():
all_gather_group = None all_gather_group = None
else: else:
all_gather_group = get_tp_group() all_gather_group = get_tp_group()
get_pp_group().send_tensor_dict(output.tensors, get_pp_group().send_tensor_dict(output.tensors, all_gather_group=all_gather_group)
all_gather_group=all_gather_group)
kv_connector_output = output.kv_connector_output kv_connector_output = output.kv_connector_output
if not kv_connector_output: if not kv_connector_output:
@@ -394,28 +380,24 @@ class NPUWorker(WorkerBase):
# In case of PP with kv transfer, we need to pass through the # In case of PP with kv transfer, we need to pass through the
# kv_connector_output # kv_connector_output
if (not kv_connector_output.finished_sending if not kv_connector_output.finished_sending and not kv_connector_output.finished_recving:
and not kv_connector_output.finished_recving):
return EMPTY_MODEL_RUNNER_OUTPUT return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output output.kv_connector_output = kv_connector_output
return output return output
@torch.inference_mode() @torch.inference_mode()
def sample_tokens( def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput | AsyncModelRunnerOutput:
self, grammar_output: "GrammarOutput"
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
return self.model_runner.sample_tokens(grammar_output) return self.model_runner.sample_tokens(grammar_output)
def load_model(self) -> None: def load_model(self) -> None:
if self.vllm_config.model_config.enable_sleep_mode: if self.vllm_config.model_config.enable_sleep_mode:
allocator = CaMemAllocator.get_instance() allocator = CaMemAllocator.get_instance()
assert allocator.get_current_usage() == 0, ( assert allocator.get_current_usage() == 0, "Sleep mode can only be used for one instance per process."
"Sleep mode can only be "
"used for one instance per process.")
context = allocator.use_memory_pool(tag="weights") context = allocator.use_memory_pool(tag="weights")
else: else:
from contextlib import nullcontext from contextlib import nullcontext
context = nullcontext() # type: ignore context = nullcontext() # type: ignore
with context, set_current_vllm_config(self.vllm_config): with context, set_current_vllm_config(self.vllm_config):
@@ -423,19 +405,15 @@ class NPUWorker(WorkerBase):
def compile_or_warm_up_model(self) -> None: def compile_or_warm_up_model(self) -> None:
# Note: need to adapt for graph mode. # Note: need to adapt for graph mode.
warmup_sizes = (self.vllm_config.compilation_config.compile_sizes warmup_sizes = (self.vllm_config.compilation_config.compile_sizes or []).copy()
or []).copy()
if not self.model_config.enforce_eager: if not self.model_config.enforce_eager:
cg_capture_sizes: list[int] = [] cg_capture_sizes: list[int] = []
if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
cg_capture_sizes = [] if cg_sizes is None else cg_sizes cg_capture_sizes = [] if cg_sizes is None else cg_sizes
warmup_sizes = [ warmup_sizes = [x for x in warmup_sizes if x not in cg_capture_sizes]
x for x in warmup_sizes if x not in cg_capture_sizes
]
compile_ranges = self.vllm_config.compilation_config.get_compile_ranges( compile_ranges = self.vllm_config.compilation_config.get_compile_ranges()
)
# For each compile_range, if none of the batch sizes # For each compile_range, if none of the batch sizes
# in warmup_sizes or cudagraph_capture_sizes are in the range, # in warmup_sizes or cudagraph_capture_sizes are in the range,
# add the end of the range to ensure compilation/warmup. # add the end of the range to ensure compilation/warmup.
@@ -467,7 +445,7 @@ class NPUWorker(WorkerBase):
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model_runner.get_model() return self.model_runner.get_model()
def get_kv_connector_handshake_metadata(self) -> Optional[dict]: def get_kv_connector_handshake_metadata(self) -> dict | None:
"""Get KV connector metadata from this worker if available.""" """Get KV connector metadata from this worker if available."""
if not has_kv_transfer_group(): if not has_kv_transfer_group():
return None return None
@@ -503,6 +481,7 @@ class NPUWorker(WorkerBase):
context = allocator.use_memory_pool(tag="kv_cache") context = allocator.use_memory_pool(tag="kv_cache")
else: else:
from contextlib import nullcontext from contextlib import nullcontext
context = nullcontext() # type: ignore context = nullcontext() # type: ignore
with context: with context:
self.model_runner.initialize_kv_cache(kv_cache_config) self.model_runner.initialize_kv_cache(kv_cache_config)
@@ -528,21 +507,20 @@ class NPUWorker(WorkerBase):
return self.model_runner.pin_lora(lora_id) return self.model_runner.pin_lora(lora_id)
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
self.model_runner._dummy_run( self.model_runner._dummy_run(num_tokens=self.model_runner.decode_token_per_req, uniform_decode=True)
num_tokens=self.model_runner.decode_token_per_req,
uniform_decode=True)
def _init_worker_distributed_environment(self) -> None: def _init_worker_distributed_environment(self) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
init_batch_invariance() init_batch_invariance()
init_distributed_environment(self.parallel_config.world_size, init_distributed_environment(
self.rank, self.distributed_init_method, self.parallel_config.world_size, self.rank, self.distributed_init_method, self.local_rank, "hccl"
self.local_rank, "hccl") )
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size, self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size, self.parallel_config.pipeline_parallel_size,
self.parallel_config.prefill_context_parallel_size, self.parallel_config.prefill_context_parallel_size,
self.parallel_config.decode_context_parallel_size) self.parallel_config.decode_context_parallel_size,
)
init_ascend_model_parallel(self.parallel_config) init_ascend_model_parallel(self.parallel_config)
ensure_kv_transfer_initialized(self.vllm_config) ensure_kv_transfer_initialized(self.vllm_config)
ensure_ec_transfer_initialized(self.vllm_config) ensure_ec_transfer_initialized(self.vllm_config)
@@ -553,12 +531,9 @@ class NPUWorker(WorkerBase):
profiler_config = self.vllm_config.profiler_config profiler_config = self.vllm_config.profiler_config
if profiler_config.profiler == "torch" and profiler_config.torch_profiler_dir: if profiler_config.profiler == "torch" and profiler_config.torch_profiler_dir:
if envs_ascend.MSMONITOR_USE_DAEMON: if envs_ascend.MSMONITOR_USE_DAEMON:
raise RuntimeError( raise RuntimeError("MSMONITOR_USE_DAEMON and torch profiler cannot be both enabled at the same time.")
"MSMONITOR_USE_DAEMON and torch profiler cannot be both enabled at the same time."
)
torch_profiler_trace_dir = profiler_config.torch_profiler_dir torch_profiler_trace_dir = profiler_config.torch_profiler_dir
logger.info("Profiling enabled. Traces will be saved to: %s", logger.info("Profiling enabled. Traces will be saved to: %s", torch_profiler_trace_dir)
torch_profiler_trace_dir)
experimental_config = torch_npu.profiler._ExperimentalConfig( experimental_config = torch_npu.profiler._ExperimentalConfig(
export_type=torch_npu.profiler.ExportType.Text, export_type=torch_npu.profiler.ExportType.Text,
@@ -583,8 +558,8 @@ class NPUWorker(WorkerBase):
# The with_stack option in torch_npu.profiler introduces significant time overhead. # The with_stack option in torch_npu.profiler introduces significant time overhead.
with_modules=profiler_config.torch_profiler_with_stack, with_modules=profiler_config.torch_profiler_with_stack,
experimental_config=experimental_config, experimental_config=experimental_config,
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler( on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_profiler_trace_dir),
torch_profiler_trace_dir)) )
else: else:
return None return None
@@ -594,5 +569,5 @@ class NPUWorker(WorkerBase):
def get_supported_tasks(self) -> "tuple[SupportedTask, ...]": def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
return self.model_runner.get_supported_tasks() return self.model_runner.get_supported_tasks()
def take_draft_token_ids(self) -> Optional[DraftTokenIds]: def take_draft_token_ids(self) -> DraftTokenIds | None:
return self.model_runner.take_draft_token_ids() return self.model_runner.take_draft_token_ids()

View File

@@ -14,49 +14,44 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from typing import Any, Callable, Tuple from collections.abc import Callable
from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import (get_ep_group, from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size, get_world_group
get_tensor_model_parallel_world_size,
get_world_group)
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import logger from vllm.logger import logger
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from xlite._C import (AttnMHA, Model, ModelAttnMeta, ModelConfig, Runtime, # type: ignore[attr-defined] from xlite._C import ( # type: ignore[attr-defined]
ScoringFuncSoftmax) AttnMHA,
Model,
ModelAttnMeta,
ModelConfig,
Runtime,
ScoringFuncSoftmax,
)
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import (AscendAttentionState, from vllm_ascend.attention.attention_v1 import AscendAttentionState, AscendMetadata
AscendMetadata)
class XliteModel: class XliteModel:
def initialize(self, runnable: nn.Module, vllm_config: VllmConfig) -> tuple[Model, int, int, torch.dtype]:
def initialize( raise NotImplementedError("Xlite Model initialize function not implemented.")
self, runnable: nn.Module,
vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
raise NotImplementedError(
"Xlite Model initialize function not implemented.")
class LlamaXliteModel(XliteModel): class LlamaXliteModel(XliteModel):
def initialize(self, runnable: nn.Module, vllm_config: VllmConfig) -> tuple[Model, int, int, torch.dtype]:
def initialize(
self, runnable: nn.Module,
vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
dtype = vllm_config.model_config.dtype dtype = vllm_config.model_config.dtype
config = self._build_model_config(vllm_config) config = self._build_model_config(vllm_config)
xlite_model = self._build_model(runnable, vllm_config, config) xlite_model = self._build_model(runnable, vllm_config, config)
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
xlite_model.init(config, rank) xlite_model.init(config, rank)
freq_cis = self._precompute_freqs_cis(config.head_dim, freq_cis = self._precompute_freqs_cis(config.head_dim, config.max_seq_len, dtype, config.rope_theta)
config.max_seq_len, dtype,
config.rope_theta)
return (xlite_model, freq_cis, config.hidden_size, dtype) return (xlite_model, freq_cis, config.hidden_size, dtype)
@@ -96,8 +91,7 @@ class LlamaXliteModel(XliteModel):
config.block_size = vllm_config.cache_config.block_size config.block_size = vllm_config.cache_config.block_size
return config return config
def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig, def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig, config: ModelConfig) -> Model:
config: ModelConfig) -> Model:
params_dict = dict(runnable.named_parameters()) params_dict = dict(runnable.named_parameters())
if hasattr(runnable, "language_model"): if hasattr(runnable, "language_model"):
@@ -108,48 +102,33 @@ class LlamaXliteModel(XliteModel):
model_prefix = "" model_prefix = ""
xlite_model = Model() xlite_model = Model()
xlite_model.embed = params_dict.get(model_prefix + xlite_model.embed = params_dict.get(model_prefix + "model.embed_tokens.weight")
"model.embed_tokens.weight")
xlite_model.norm = params_dict.get(model_prefix + "model.norm.weight") xlite_model.norm = params_dict.get(model_prefix + "model.norm.weight")
if vllm_config.model_config.hf_text_config.tie_word_embeddings: if vllm_config.model_config.hf_text_config.tie_word_embeddings:
xlite_model.head = xlite_model.embed xlite_model.head = xlite_model.embed
else: else:
xlite_model.head = params_dict.get(model_prefix + "lm_head.weight") xlite_model.head = params_dict.get(model_prefix + "lm_head.weight")
xlite_model.attn_norm = [ xlite_model.attn_norm = [layer.input_layernorm.weight for layer in layers]
layer.input_layernorm.weight for layer in layers xlite_model.attn_out = [layer.self_attn.o_proj.weight for layer in layers]
] xlite_model.mha_qkv = [layer.self_attn.qkv_proj.weight for layer in layers]
xlite_model.attn_out = [ xlite_model.mlp_norm = [layer.post_attention_layernorm.weight for layer in layers]
layer.self_attn.o_proj.weight for layer in layers
]
xlite_model.mha_qkv = [
layer.self_attn.qkv_proj.weight for layer in layers
]
xlite_model.mlp_norm = [
layer.post_attention_layernorm.weight for layer in layers
]
xlite_model.mlp_up_gate = [ xlite_model.mlp_up_gate = [
layer.mlp.gate_up_proj.weight for layer in layers layer.mlp.gate_up_proj.weight
if hasattr(layer.mlp, "gate_up_proj") for layer in layers
and layer.mlp.gate_up_proj.weight is not None if hasattr(layer.mlp, "gate_up_proj") and layer.mlp.gate_up_proj.weight is not None
] ]
xlite_model.mlp_down = [ xlite_model.mlp_down = [
layer.mlp.down_proj.weight for layer in layers layer.mlp.down_proj.weight
if hasattr(layer.mlp, "down_proj") for layer in layers
and layer.mlp.down_proj.weight is not None if hasattr(layer.mlp, "down_proj") and layer.mlp.down_proj.weight is not None
] ]
mha_qkv_bias = [ mha_qkv_bias = [
layer.self_attn.qkv_proj.bias for layer in layers layer.self_attn.qkv_proj.bias
if hasattr(layer.self_attn.qkv_proj, "bias") for layer in layers
and layer.self_attn.qkv_proj.bias is not None if hasattr(layer.self_attn.qkv_proj, "bias") and layer.self_attn.qkv_proj.bias is not None
]
q_norm = [
layer.self_attn.q_norm.weight for layer in layers
if hasattr(layer.self_attn, "q_norm")
]
k_norm = [
layer.self_attn.k_norm.weight for layer in layers
if hasattr(layer.self_attn, "k_norm")
] ]
q_norm = [layer.self_attn.q_norm.weight for layer in layers if hasattr(layer.self_attn, "q_norm")]
k_norm = [layer.self_attn.k_norm.weight for layer in layers if hasattr(layer.self_attn, "k_norm")]
if len(mha_qkv_bias) != config.n_layers: if len(mha_qkv_bias) != config.n_layers:
config.qkv_bias = False config.qkv_bias = False
@@ -157,7 +136,7 @@ class LlamaXliteModel(XliteModel):
config.qkv_bias = True config.qkv_bias = True
xlite_model.mha_qkv_bias = mha_qkv_bias xlite_model.mha_qkv_bias = mha_qkv_bias
if (len(q_norm) != config.n_layers or len(k_norm) != config.n_layers): if len(q_norm) != config.n_layers or len(k_norm) != config.n_layers:
config.qk_norm = False config.qk_norm = False
else: else:
config.qk_norm = True config.qk_norm = True
@@ -166,39 +145,28 @@ class LlamaXliteModel(XliteModel):
return xlite_model return xlite_model
def _precompute_freqs_cis(self, def _precompute_freqs_cis(self, dim: int, end: int, dtype: torch.dtype, theta: float = 10000.0):
dim: int, freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device="cpu")[: (dim // 2)] / dim))
end: int,
dtype: torch.dtype,
theta: float = 10000.0):
freqs = 1.0 / (theta**(torch.arange(
0, dim, 2, dtype=torch.float32, device='cpu')[:(dim // 2)] / dim))
t = torch.arange(end, device=freqs.device) # type: ignore t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore freqs = torch.outer(t, freqs).float() # type: ignore
cos_cache = freqs.cos().to(dtype) cos_cache = freqs.cos().to(dtype)
sin_cache = freqs.sin().to(dtype) sin_cache = freqs.sin().to(dtype)
freq_cis = torch.cat((cos_cache, sin_cache), dim=-1) freq_cis = torch.cat((cos_cache, sin_cache), dim=-1)
return freq_cis.to(device='npu') return freq_cis.to(device="npu")
class QwenMoeXliteModel(LlamaXliteModel): class QwenMoeXliteModel(LlamaXliteModel):
def initialize(self, runnable: nn.Module, vllm_config: VllmConfig) -> tuple[Model, int, int, torch.dtype]:
def initialize(
self, runnable: nn.Module,
vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
if envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2: if envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2:
architecture = vllm_config.model_config.architectures[0] architecture = vllm_config.model_config.architectures[0]
raise ValueError( raise ValueError(f"{architecture} not support VLLM_ASCEND_ENABLE_NZ = 2!")
f"{architecture} not support VLLM_ASCEND_ENABLE_NZ = 2!")
dtype = vllm_config.model_config.dtype dtype = vllm_config.model_config.dtype
config = self._build_model_config(vllm_config) config = self._build_model_config(vllm_config)
xlite_model = self._build_model(runnable, vllm_config, config) xlite_model = self._build_model(runnable, vllm_config, config)
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
xlite_model.init(config, rank) xlite_model.init(config, rank)
freq_cis = super()._precompute_freqs_cis(config.head_dim, freq_cis = super()._precompute_freqs_cis(config.head_dim, config.max_seq_len, dtype, config.rope_theta)
config.max_seq_len, dtype,
config.rope_theta)
return (xlite_model, freq_cis, config.hidden_size, dtype) return (xlite_model, freq_cis, config.hidden_size, dtype)
@@ -220,26 +188,21 @@ class QwenMoeXliteModel(LlamaXliteModel):
config.scoring_func = ScoringFuncSoftmax # type: ignore config.scoring_func = ScoringFuncSoftmax # type: ignore
return config return config
def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig, def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig, config: ModelConfig) -> Model:
config: ModelConfig) -> Model:
xlite_model = super()._build_model(runnable, vllm_config, config) xlite_model = super()._build_model(runnable, vllm_config, config)
layers = runnable.model.layers layers = runnable.model.layers
xlite_model.gate = [layer.mlp.gate.weight for layer in layers] xlite_model.gate = [layer.mlp.gate.weight for layer in layers]
xlite_model.re_up_gate = [ xlite_model.re_up_gate = [
layer.mlp.experts.w13_weight[i] for layer in layers layer.mlp.experts.w13_weight[i] for layer in layers for i in range(layer.mlp.experts.local_num_experts)
for i in range(layer.mlp.experts.local_num_experts)
] ]
xlite_model.re_down = [ xlite_model.re_down = [
layer.mlp.experts.w2_weight[i] for layer in layers layer.mlp.experts.w2_weight[i] for layer in layers for i in range(layer.mlp.experts.local_num_experts)
for i in range(layer.mlp.experts.local_num_experts)
] ]
return xlite_model return xlite_model
def xlite_model_init( def xlite_model_init(runnable: nn.Module, vllm_config: VllmConfig) -> tuple[Model, int, int, torch.dtype]:
runnable: nn.Module,
vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
strategy_map = { strategy_map = {
"LlamaForCausalLM": LlamaXliteModel, "LlamaForCausalLM": LlamaXliteModel,
"Qwen2ForCausalLM": LlamaXliteModel, "Qwen2ForCausalLM": LlamaXliteModel,
@@ -266,33 +229,26 @@ class XliteWrapper:
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
local_rank = get_world_group().local_rank local_rank = get_world_group().local_rank
self.xlite_rt = Runtime(local_rank, 0, rank, self.xlite_rt = Runtime(
get_tensor_model_parallel_world_size(), local_rank, 0, rank, get_tensor_model_parallel_world_size(), vllm_config.parallel_config.data_parallel_size
vllm_config.parallel_config.data_parallel_size) )
(self.xlite_model, self.freq_cis, hidden_size, (self.xlite_model, self.freq_cis, hidden_size, dtype) = xlite_model_init(runnable, vllm_config)
dtype) = xlite_model_init(runnable, vllm_config)
rt_pool_size = self.xlite_model.get_tensor_pool_size() rt_pool_size = self.xlite_model.get_tensor_pool_size()
if rank == 0: if rank == 0:
logger.info(f"xlite runtime pool size: {rt_pool_size} MB") logger.info(f"xlite runtime pool size: {rt_pool_size} MB")
if self.xlite_rt.init_tensor_pool(rt_pool_size) != 0: if self.xlite_rt.init_tensor_pool(rt_pool_size) != 0:
raise ValueError( raise ValueError(f"xlite wrapper init failed! runtime pool size: {rt_pool_size} MB")
f"xlite wrapper init failed! runtime pool size: {rt_pool_size} MB"
)
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.hidden_states = torch.empty(max_num_tokens, self.hidden_states = torch.empty(max_num_tokens, hidden_size, device=f"npu:{local_rank}", dtype=dtype)
hidden_size,
device=f"npu:{local_rank}",
dtype=dtype)
def __getattr__(self, key: str): def __getattr__(self, key: str):
# allow accessing the attributes of the runnable. # allow accessing the attributes of the runnable.
if hasattr(self.runnable, key): if hasattr(self.runnable, key):
return getattr(self.runnable, key) return getattr(self.runnable, key)
raise AttributeError(f"Attribute {key} not exists in the runnable of " raise AttributeError(f"Attribute {key} not exists in the runnable of xlite wrapper: {self.runnable}")
f"xlite wrapper: {self.runnable}")
def unwrap(self) -> Callable: def unwrap(self) -> Callable:
# in case we need to access the original runnable. # in case we need to access the original runnable.
@@ -307,22 +263,19 @@ class XliteWrapper:
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
list[torch.Tensor]]:
forward_context = get_forward_context() forward_context = get_forward_context()
attn_metadata: Any = forward_context.attn_metadata attn_metadata: Any = forward_context.attn_metadata
if attn_metadata is None: if attn_metadata is None:
return self.runnable(input_ids, positions, intermediate_tensors, return self.runnable(input_ids, positions, intermediate_tensors, inputs_embeds)
inputs_embeds)
attn_metadata = next(iter(attn_metadata.values()), None) attn_metadata = next(iter(attn_metadata.values()), None)
if attn_metadata is None or not isinstance(attn_metadata, if attn_metadata is None or not isinstance(attn_metadata, AscendMetadata):
AscendMetadata): return self.runnable(input_ids, positions, intermediate_tensors, inputs_embeds)
return self.runnable(input_ids, positions, intermediate_tensors,
inputs_embeds)
with_prefill = attn_metadata.attn_state not in [ with_prefill = attn_metadata.attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding AscendAttentionState.DecodeOnly,
AscendAttentionState.SpecDecoding,
] ]
if not with_prefill or self.full_mode: if not with_prefill or self.full_mode:
@@ -335,11 +288,7 @@ class XliteWrapper:
num_prefills = attn_metadata.num_prefills num_prefills = attn_metadata.num_prefills
batch = num_prefills + num_decodes batch = num_prefills + num_decodes
seq_lens = attn_metadata.seq_lens[:batch] seq_lens = attn_metadata.seq_lens[:batch]
seq_tensor = torch.cat([ seq_tensor = torch.cat([torch.tensor([0]), torch.tensor(attn_metadata.actual_seq_lengths_q)], dim=0)
torch.tensor([0]),
torch.tensor(attn_metadata.actual_seq_lengths_q)
],
dim=0)
query_lens = seq_tensor[1:] - seq_tensor[:-1] query_lens = seq_tensor[1:] - seq_tensor[:-1]
query_lens = query_lens[:batch] query_lens = query_lens[:batch]
cached_lens = seq_lens - query_lens cached_lens = seq_lens - query_lens
@@ -347,23 +296,19 @@ class XliteWrapper:
xlite_attn_metadata = ModelAttnMeta() xlite_attn_metadata = ModelAttnMeta()
xlite_attn_metadata.lens = query_lens.tolist() xlite_attn_metadata.lens = query_lens.tolist()
xlite_attn_metadata.cached_lens = cached_lens.tolist() xlite_attn_metadata.cached_lens = cached_lens.tolist()
xlite_attn_metadata.is_prefills = [False] * num_decodes + [ xlite_attn_metadata.is_prefills = [False] * num_decodes + [True] * num_prefills
True xlite_attn_metadata.block_tables = attn_metadata.block_tables.cpu().tolist()
] * num_prefills
xlite_attn_metadata.block_tables = attn_metadata.block_tables.cpu(
).tolist()
h = self.hidden_states[: attn_metadata.num_actual_tokens] h = self.hidden_states[: attn_metadata.num_actual_tokens]
stream = torch.npu.current_stream().npu_stream stream = torch.npu.current_stream().npu_stream
if inputs_embeds is None: if inputs_embeds is None:
self.xlite_model.forward(self.xlite_rt, input_ids, self.xlite_model.forward(
xlite_attn_metadata, self.kv_caches, self.xlite_rt, input_ids, xlite_attn_metadata, self.kv_caches, self.freq_cis, h, stream
self.freq_cis, h, stream) )
else: else:
self.xlite_model.forward_with_inputs_embeds( self.xlite_model.forward_with_inputs_embeds(
self.xlite_rt, inputs_embeds, xlite_attn_metadata, self.xlite_rt, inputs_embeds, xlite_attn_metadata, self.kv_caches, self.freq_cis, h, stream
self.kv_caches, self.freq_cis, h, stream) )
return h return h
else: else:
return self.runnable(input_ids, positions, intermediate_tensors, return self.runnable(input_ids, positions, intermediate_tensors, inputs_embeds)
inputs_embeds)

View File

@@ -22,13 +22,13 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
class XliteModelRunner(NPUModelRunner): class XliteModelRunner(NPUModelRunner):
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model.unwrap() return self.model.unwrap()
def load_model(self) -> None: def load_model(self) -> None:
super().load_model() super().load_model()
from vllm_ascend.xlite.xlite import XliteWrapper from vllm_ascend.xlite.xlite import XliteWrapper
self.model = XliteWrapper(self.model, self.vllm_config) self.model = XliteWrapper(self.model, self.vllm_config)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: