Merge pull request #80 from liwei109/aicapx-quant

[fix] matmul does not support CUDA graph
This commit is contained in:
baoqian426
2026-01-06 17:49:09 +08:00
committed by GitHub
3 changed files with 12 additions and 7 deletions

View File

@@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
from vllm_kunlun.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
from vllm.model_executor.models.qwen3_vl import (Qwen3_VisionTransformer, Qwen3VLDummyInputsBuilder,
Qwen3VLForConditionalGeneration,
Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)

View File

@@ -64,8 +64,9 @@ def apply(
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
w1_bias = layer.w13_bias,
w2_bias = layer.w2_bias)
w1_bias=getattr(layer, 'w13_bias', None),
w2_bias=getattr(layer, 'w2_bias', None),
)
UnquantizedFusedMoEMethod.apply = apply

View File

@@ -1616,17 +1616,21 @@ def scaled_int8_quant_cuda(
return x_q, scale, azp, static
def fake_scaled_int8_quant(
def _fake_scaled_int8_quant(
x: torch.Tensor,
scale: torch.Tensor,
azp: Optional[torch.Tensor] = None,
symmetric: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]:
x_q = torch.ones(x.shape, dtype=torch.int8, device=x.device)
x_q = torch.empty_like(x, dtype=torch.int8, device=x.device)
scale = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
return x_q, scale, azp, False
scaled_int8_quant.register_fake(fake_scaled_int8_quant)
scaled_int8_quant.register_fake(_fake_scaled_int8_quant)
######################################################
@@ -1821,7 +1825,7 @@ def _fake_matmul(
w_pc_max: torch.Tensor = None,
) -> torch.Tensor:
return torch.empty(
(x.shape[0], w.shape[0]),
(x.shape[0], w.shape[0] if w_trans else w.shape[1]),
dtype=out_dtype,
device=x.device,
)