Files
xc-llm-ascend/vllm_ascend/patch/worker/patch_roberta.py

92 lines
3.2 KiB
Python
Raw Normal View History

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
[1/N][Refactor] Refactor code to adapt with vllm main (#3612) ### What this PR does / why we need it? This is the step 1 of refactoring code to adapt with vllm main, and this pr aligned with https://github.com/vllm-project/vllm/commit/17c540a993af88204ad1b78345c8a865cf58ce44 1. refactor deepseek to the latest code arch as of https://github.com/vllm-project/vllm/commit/17c540a993af88204ad1b78345c8a865cf58ce44 2. bunches of fixes due to vllm changes - Fix `AscendScheduler` `__post_init__`, caused by https://github.com/vllm-project/vllm/pull/25075 - Fix `AscendScheduler` init got an unexpected arg `block_size`, caused by https://github.com/vllm-project/vllm/pull/26296 - Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by https://github.com/vllm-project/vllm/pull/23485 - Fix `MLAAttention` import,caused by https://github.com/vllm-project/vllm/pull/25103 - Fix `SharedFusedMoE` import, caused by https://github.com/vllm-project/vllm/pull/26145 - Fix `LazyLoader` improt, caused by https://github.com/vllm-project/vllm/pull/27022 - Fix `vllm.utils.swap_dict_values` improt, caused by https://github.com/vllm-project/vllm/pull/26990 - Fix `Backend` enum import, caused by https://github.com/vllm-project/vllm/pull/25893 - Fix `CompilationLevel` renaming to `CompilationMode` issue introduced by https://github.com/vllm-project/vllm/pull/26355 - Fix fused_moe ops, caused by https://github.com/vllm-project/vllm/pull/24097 - Fix bert model because of `inputs_embeds`, caused by https://github.com/vllm-project/vllm/pull/25922 - Fix MRope because of `get_input_positions_tensor` to `get_mrope_input_positions`, caused by https://github.com/vllm-project/vllm/pull/24172 - Fix `splitting_ops` changes introduced by https://github.com/vllm-project/vllm/pull/25845 - Fix multi-modality changes introduced by https://github.com/vllm-project/vllm/issues/16229 - Fix lora bias dropping issue introduced by https://github.com/vllm-project/vllm/pull/25807 - Fix structured ouput break 
introduced by https://github.com/vllm-project/vllm/issues/26737 ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? CI passed with existing test. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: MengqingCao <cmq0113@163.com> Signed-off-by: Icey <1790571317@qq.com> Co-authored-by: Icey <1790571317@qq.com>
2025-10-24 16:55:08 +08:00
from typing import Optional, Union
import torch
from vllm.model_executor.models.roberta import (
RobertaEmbedding, RobertaForSequenceClassification,
replace_roberta_positions)
from vllm.sequence import IntermediateTensors
# Token type ids are packed into the bits of ``input_ids`` starting at bit
# TOKEN_TYPE_SHIFT. aclgraph does not support the shift operator for now, so
# the tensor code in this module multiplies / floor-divides by
# TOKEN_TYPE_MULTIPLIER instead of shifting.
# TODO: revert me when aclgraph supports shift operator
TOKEN_TYPE_SHIFT = 30
# Power of two equivalent to a left shift by TOKEN_TYPE_SHIFT bits.
TOKEN_TYPE_MULTIPLIER = 1 << TOKEN_TYPE_SHIFT
# Mask that keeps only the bits below TOKEN_TYPE_SHIFT.
TOKEN_MASK = TOKEN_TYPE_MULTIPLIER - 1
def _encode_token_type_ids(input_ids: torch.Tensor,
                           token_type_ids: torch.Tensor) -> None:
    """Pack ``token_type_ids`` into the high bits of ``input_ids`` in place.

    Each token type id is multiplied by ``TOKEN_TYPE_MULTIPLIER`` and OR-ed
    into the matching leading entry of ``input_ids``. Nothing is returned;
    ``input_ids`` is modified in place.

    Args:
        input_ids: Tensor receiving the packed values. It may be padded to
            the right, so only its first ``token_type_ids.shape[0]`` entries
            are updated.
        token_type_ids: Token type ids to pack.
    """
    # input_ids can be longer than token_type_ids because of right padding.
    num_typed_tokens = token_type_ids.shape[0]
    # Multiplication stands in for a left shift by TOKEN_TYPE_SHIFT bits.
    shifted_type_ids = token_type_ids * TOKEN_TYPE_MULTIPLIER
    input_ids[:num_typed_tokens].bitwise_or_(shifted_type_ids)
def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Split packed ``input_ids`` into token type ids and plain token ids.

    Args:
        input_ids: Tensor whose entries may carry token type ids in the bits
            at and above ``TOKEN_TYPE_SHIFT``. It is masked in place with
            ``TOKEN_MASK``, so after the call only the low bits remain.

    Returns:
        A new tensor holding ``input_ids // TOKEN_TYPE_MULTIPLIER``, computed
        before the mask is applied.
    """
    # Extract the high part first: the in-place mask below discards it.
    # Floor division stands in for a right shift by TOKEN_TYPE_SHIFT bits.
    decoded_type_ids = input_ids // TOKEN_TYPE_MULTIPLIER
    input_ids.bitwise_and_(TOKEN_MASK)
    return decoded_type_ids
def roberta_for_sequence_classification_forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Replacement ``forward`` for ``RobertaForSequenceClassification``.

    Calls ``replace_roberta_positions`` on the inputs, packs any
    ``token_type_ids`` into ``input_ids`` in place, and then delegates to
    ``self.roberta``.

    Args:
        input_ids: Token ids; must not be ``None`` when ``token_type_ids``
            is given, because the token types are packed into it.
        positions: Position ids passed to ``replace_roberta_positions`` and
            to ``self.roberta``.
        intermediate_tensors: Forwarded unchanged to ``self.roberta``.
        inputs_embeds: Forwarded unchanged to ``self.roberta``.
        token_type_ids: Optional token type ids to pack into ``input_ids``.

    Returns:
        Whatever ``self.roberta`` returns for the prepared inputs.
    """
    # NOTE(review): presumably adjusts ``positions`` in place for RoBERTa's
    # padding offset - confirm against the vLLM helper.
    replace_roberta_positions(input_ids=input_ids,
                              position_ids=positions,
                              padding_idx=self.padding_idx)

    if token_type_ids is not None:
        # Vocabulary ids must fit below the bits used for token types.
        vocab_limit = 1 << TOKEN_TYPE_SHIFT
        assert self.roberta.config.vocab_size < vocab_limit
        assert input_ids is not None
        _encode_token_type_ids(input_ids, token_type_ids)

    encoder_kwargs = {
        "input_ids": input_ids,
        "positions": positions,
        "inputs_embeds": inputs_embeds,
        "intermediate_tensors": intermediate_tensors,
    }
    return self.roberta(**encoder_kwargs)
def roberta_embedding_forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    inputs_embeds: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """Replacement ``forward`` for ``RobertaEmbedding``.

    Recovers the token type ids packed into ``input_ids``, then sums the
    word, token-type and position embeddings and applies ``self.LayerNorm``.

    Args:
        input_ids: Token ids that may carry packed token type ids. They are
            unpacked in place, so the tensor holds plain token ids after
            the call.
        position_ids: Position ids looked up in ``self.position_embeddings``.
        inputs_embeds: Optional precomputed word embeddings. When ``None``,
            they are looked up from ``input_ids`` via
            ``self.word_embeddings``.

    Returns:
        The layer-normalised sum of the three embeddings.
    """
    # Must run before the word-embedding lookup: decoding strips the packed
    # token-type bits from input_ids in place.
    token_type_ids = _decode_token_type_ids(input_ids)

    if inputs_embeds is None:
        inputs_embeds = self.word_embeddings(input_ids)
    position_embeddings = self.position_embeddings(position_ids)
    token_type_embeddings = self.token_type_embeddings(token_type_ids)

    embeddings = inputs_embeds + token_type_embeddings + position_embeddings
    embeddings = self.LayerNorm(embeddings)
    return embeddings
# Monkey-patch the vLLM RoBERTa classes: importing this module replaces their
# ``forward`` methods with the implementations defined in this file.
RobertaEmbedding.forward = roberta_embedding_forward
RobertaForSequenceClassification.forward = roberta_for_sequence_classification_forward