diff --git a/vllm_kunlun/models/qwen3_vl_moe.py b/vllm_kunlun/models/qwen3_vl_moe.py
index e0643c7..76e09e5 100644
--- a/vllm_kunlun/models/qwen3_vl_moe.py
+++ b/vllm_kunlun/models/qwen3_vl_moe.py
@@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
 
-from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
+from vllm_kunlun.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
 from vllm.model_executor.models.qwen3_vl import (Qwen3_VisionTransformer,
     Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration,
     Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
diff --git a/vllm_kunlun/ops/fused_moe/layer.py b/vllm_kunlun/ops/fused_moe/layer.py
index 953fbc0..8120868 100644
--- a/vllm_kunlun/ops/fused_moe/layer.py
+++ b/vllm_kunlun/ops/fused_moe/layer.py
@@ -64,8 +64,9 @@ def apply(
             topk_group=topk_group,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            w1_bias = layer.w13_bias,
-            w2_bias = layer.w2_bias)
+            w1_bias=getattr(layer, 'w13_bias', None),
+            w2_bias=getattr(layer, 'w2_bias', None),
+        )
 
 
 UnquantizedFusedMoEMethod.apply = apply
diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py
index faeb1e0..2753f71 100644
--- a/vllm_kunlun/vllm_utils_wrapper.py
+++ b/vllm_kunlun/vllm_utils_wrapper.py
@@ -1616,17 +1616,21 @@ def scaled_int8_quant_cuda(
     return x_q, scale, azp, static
 
 
-def fake_scaled_int8_quant(
+def _fake_scaled_int8_quant(
     x: torch.Tensor,
     scale: torch.Tensor,
     azp: Optional[torch.Tensor] = None,
     symmetric: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]:
-    x_q = torch.ones(x.shape, dtype=torch.int8, device=x.device)
+    x_q = torch.empty_like(x, dtype=torch.int8, device=x.device)
+    scale = torch.empty(
+        (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
+    )
+    azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
     return x_q, scale, azp, False
 
 
-scaled_int8_quant.register_fake(fake_scaled_int8_quant)
+scaled_int8_quant.register_fake(_fake_scaled_int8_quant)
 
 
 ######################################################
@@ -1821,7 +1825,7 @@ def _fake_matmul(
     w_pc_max: torch.Tensor = None,
 ) -> torch.Tensor:
     return torch.empty(
-        (x.shape[0], w.shape[0]),
+        (x.shape[0], w.shape[0] if w_trans else w.shape[1]),
         dtype=out_dtype,
         device=x.device,
     )