Merge pull request #80 from liwei109/aicapx-quant
[fix]matmul not support cuda graph
This commit is contained in:
@@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
|
||||
from vllm_kunlun.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
|
||||
from vllm.model_executor.models.qwen3_vl import (Qwen3_VisionTransformer, Qwen3VLDummyInputsBuilder,
|
||||
Qwen3VLForConditionalGeneration,
|
||||
Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
|
||||
|
||||
@@ -64,8 +64,9 @@ def apply(
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
w1_bias = layer.w13_bias,
|
||||
w2_bias = layer.w2_bias)
|
||||
w1_bias=getattr(layer, 'w13_bias', None),
|
||||
w2_bias=getattr(layer, 'w2_bias', None),
|
||||
)
|
||||
|
||||
UnquantizedFusedMoEMethod.apply = apply
|
||||
|
||||
|
||||
@@ -1616,17 +1616,21 @@ def scaled_int8_quant_cuda(
|
||||
return x_q, scale, azp, static
|
||||
|
||||
|
||||
def fake_scaled_int8_quant(
|
||||
def _fake_scaled_int8_quant(
|
||||
x: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
azp: Optional[torch.Tensor] = None,
|
||||
symmetric: bool = True,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]:
|
||||
x_q = torch.ones(x.shape, dtype=torch.int8, device=x.device)
|
||||
x_q = torch.empty_like(x, dtype=torch.int8, device=x.device)
|
||||
scale = torch.empty(
|
||||
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
|
||||
)
|
||||
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
|
||||
return x_q, scale, azp, False
|
||||
|
||||
|
||||
scaled_int8_quant.register_fake(fake_scaled_int8_quant)
|
||||
scaled_int8_quant.register_fake(_fake_scaled_int8_quant)
|
||||
|
||||
|
||||
######################################################
|
||||
@@ -1821,7 +1825,7 @@ def _fake_matmul(
|
||||
w_pc_max: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(
|
||||
(x.shape[0], w.shape[0]),
|
||||
(x.shape[0], w.shape[0] if w_trans else w.shape[1]),
|
||||
dtype=out_dtype,
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user