[XPU][CPU] Enable the native path of DeepSeek (#4086)
Co-authored-by: Zhang, Liangang <liangang.zhang@intel.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import sys
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
@@ -8,6 +9,7 @@ from vllm.scalar_type import scalar_types
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.utils import get_device_capability
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -90,7 +92,20 @@ class GPTQConfig(QuantizationConfig):
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
    """Return the minimum device capability required by this GPTQ config.

    Returns:
        ``60`` when ``torch.cuda`` exists and reports an available device.
        ``sys.maxsize`` on every other backend, as a placeholder until a
        vendor supplies a real threshold (presumably so that any
        "capability >= minimum" comparison fails there -- confirm against
        the callers).
    """
    # The hasattr guard tolerates torch builds that do not ship a cuda module.
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        return 60

    # Vendors can update
    return sys.maxsize
|
||||
|
||||
@classmethod
|
||||
def get_availability(cls) -> bool:
    """Report whether GPTQ quantization can be used on the current device.

    Returns:
        ``True`` when CUDA is available and the device capability, encoded
        as ``major * 10 + minor``, is at least 60. ``False`` on every other
        backend until a vendor adds support.
    """
    major, minor = get_device_capability()
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        # Inclusive bound: a device sitting exactly at the minimum capability
        # (60, the value get_min_capability reports for CUDA) is supported.
        return major * 10 + minor >= 60

    # Vendors can update
    return False
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
@@ -209,7 +224,20 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
    """Return the minimum device capability required by this GPTQ-Marlin config.

    Returns:
        ``80`` when ``torch.cuda`` exists and reports an available device.
        ``sys.maxsize`` on every other backend, as a placeholder until a
        vendor supplies a real threshold (presumably so that any
        "capability >= minimum" comparison fails there -- confirm against
        the callers).
    """
    # The hasattr guard tolerates torch builds that do not ship a cuda module.
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        return 80

    # Vendors can update
    return sys.maxsize
|
||||
|
||||
@classmethod
|
||||
def get_availability(cls) -> bool:
    """Report whether GPTQ-Marlin quantization can be used on the current device.

    Returns:
        ``True`` when CUDA is available and the device capability, encoded
        as ``major * 10 + minor``, is at least 80. ``False`` on every other
        backend until a vendor adds support.
    """
    major, minor = get_device_capability()
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        # Inclusive bound: a device sitting exactly at the minimum capability
        # (80, the value get_min_capability reports for CUDA) is supported.
        return major * 10 + minor >= 80

    # Vendors can update
    return False
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
@@ -371,7 +399,20 @@ class MarlinConfig(QuantizationConfig):
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
    """Return the minimum device capability required by this Marlin config.

    Returns:
        ``80`` when ``torch.cuda`` exists and reports an available device.
        ``sys.maxsize`` on every other backend, as a placeholder until a
        vendor supplies a real threshold (presumably so that any
        "capability >= minimum" comparison fails there -- confirm against
        the callers).
    """
    # The hasattr guard tolerates torch builds that do not ship a cuda module.
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        return 80

    # Vendors can update
    return sys.maxsize
|
||||
|
||||
@classmethod
|
||||
def get_availability(cls) -> bool:
    """Report whether Marlin quantization can be used on the current device.

    Returns:
        ``True`` when CUDA is available and the device capability, encoded
        as ``major * 10 + minor``, is at least 80. ``False`` on every other
        backend until a vendor adds support.
    """
    major, minor = get_device_capability()
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        # Inclusive bound: a device sitting exactly at the minimum capability
        # (80, the value get_min_capability reports for CUDA) is supported.
        return major * 10 + minor >= 80

    # Vendors can update
    return False
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
|
||||
Reference in New Issue
Block a user