Merge pull request #80 from liwei109/aicapx-quant
[fix]matmul not support cuda graph
This commit is contained in:
@@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
|||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
|
from vllm_kunlun.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
|
||||||
from vllm.model_executor.models.qwen3_vl import (Qwen3_VisionTransformer, Qwen3VLDummyInputsBuilder,
|
from vllm.model_executor.models.qwen3_vl import (Qwen3_VisionTransformer, Qwen3VLDummyInputsBuilder,
|
||||||
Qwen3VLForConditionalGeneration,
|
Qwen3VLForConditionalGeneration,
|
||||||
Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
|
Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
|
||||||
|
|||||||
@@ -64,8 +64,9 @@ def apply(
|
|||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
scoring_func=scoring_func,
|
scoring_func=scoring_func,
|
||||||
e_score_correction_bias=e_score_correction_bias,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
w1_bias = layer.w13_bias,
|
w1_bias=getattr(layer, 'w13_bias', None),
|
||||||
w2_bias = layer.w2_bias)
|
w2_bias=getattr(layer, 'w2_bias', None),
|
||||||
|
)
|
||||||
|
|
||||||
UnquantizedFusedMoEMethod.apply = apply
|
UnquantizedFusedMoEMethod.apply = apply
|
||||||
|
|
||||||
|
|||||||
@@ -1616,17 +1616,21 @@ def scaled_int8_quant_cuda(
|
|||||||
return x_q, scale, azp, static
|
return x_q, scale, azp, static
|
||||||
|
|
||||||
|
|
||||||
def fake_scaled_int8_quant(
|
def _fake_scaled_int8_quant(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
scale: torch.Tensor,
|
scale: torch.Tensor,
|
||||||
azp: Optional[torch.Tensor] = None,
|
azp: Optional[torch.Tensor] = None,
|
||||||
symmetric: bool = True,
|
symmetric: bool = True,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]:
|
||||||
x_q = torch.ones(x.shape, dtype=torch.int8, device=x.device)
|
x_q = torch.empty_like(x, dtype=torch.int8, device=x.device)
|
||||||
|
scale = torch.empty(
|
||||||
|
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
|
||||||
return x_q, scale, azp, False
|
return x_q, scale, azp, False
|
||||||
|
|
||||||
|
|
||||||
scaled_int8_quant.register_fake(fake_scaled_int8_quant)
|
scaled_int8_quant.register_fake(_fake_scaled_int8_quant)
|
||||||
|
|
||||||
|
|
||||||
######################################################
|
######################################################
|
||||||
@@ -1821,7 +1825,7 @@ def _fake_matmul(
|
|||||||
w_pc_max: torch.Tensor = None,
|
w_pc_max: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.empty(
|
return torch.empty(
|
||||||
(x.shape[0], w.shape[0]),
|
(x.shape[0], w.shape[0] if w_trans else w.shape[1]),
|
||||||
dtype=out_dtype,
|
dtype=out_dtype,
|
||||||
device=x.device,
|
device=x.device,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user