Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -70,7 +70,7 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer,
|
||||
from vllm_kunlun.ops.activation import SiluAndMul
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import get_masked_input_and_mask
|
||||
import xtorch_ops
|
||||
import kunlun_ops
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend="aot_eager")
|
||||
@@ -640,7 +640,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype).view(
|
||||
last_recurrent_state.shape[0], -1, last_recurrent_state.shape[-1])
|
||||
cast_ssm_state = ssm_state.view(ssm_state.shape[0], 1, -1, ssm_state.shape[-1])
|
||||
xtorch_ops.reshape_and_cache_flash(
|
||||
kunlun_ops.reshape_and_cache_flash(
|
||||
last_recurrent_state,
|
||||
last_recurrent_state,
|
||||
cast_ssm_state,
|
||||
|
||||
@@ -85,7 +85,7 @@ from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM, Qwen3Model
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
|
||||
import xtorch_ops
|
||||
import kunlun_ops
|
||||
from einops import repeat
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
Reference in New Issue
Block a user