Qwen FP8/NVFP4 ModelOPT Quantization support (#7912)
Co-authored-by: Jingyu Xin <jingyux@nvidia.com>
This commit is contained in:
@@ -517,6 +517,39 @@ class ModelOptFp4Config(QuantizationConfig):
|
|||||||
def get_config_filenames(cls) -> List[str]:
|
def get_config_filenames(cls) -> List[str]:
|
||||||
return ["hf_quant_config.json"]
|
return ["hf_quant_config.json"]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def common_group_size(cfg: dict) -> int:
|
||||||
|
"""Return the unique group_size across the config; raise if missing/mismatched."""
|
||||||
|
sizes = set()
|
||||||
|
|
||||||
|
# Top-level and 'quantization' block
|
||||||
|
v = cfg.get("group_size")
|
||||||
|
if isinstance(v, int):
|
||||||
|
sizes.add(v)
|
||||||
|
q = cfg.get("quantization")
|
||||||
|
if isinstance(q, dict):
|
||||||
|
v = q.get("group_size")
|
||||||
|
if isinstance(v, int):
|
||||||
|
sizes.add(v)
|
||||||
|
|
||||||
|
# config_groups: accept group-level or nested dicts (e.g., weights/input_activations)
|
||||||
|
for g in (cfg.get("config_groups") or {}).values():
|
||||||
|
if isinstance(g, dict):
|
||||||
|
v = g.get("group_size")
|
||||||
|
if isinstance(v, int):
|
||||||
|
sizes.add(v)
|
||||||
|
for sub in g.values():
|
||||||
|
if isinstance(sub, dict):
|
||||||
|
v = sub.get("group_size")
|
||||||
|
if isinstance(v, int):
|
||||||
|
sizes.add(v)
|
||||||
|
|
||||||
|
if not sizes:
|
||||||
|
raise ValueError("No group_size found in config.")
|
||||||
|
if len(sizes) > 1:
|
||||||
|
raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
|
||||||
|
return next(iter(sizes))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
|
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
|
||||||
# Handle two different config formats:
|
# Handle two different config formats:
|
||||||
@@ -549,7 +582,7 @@ class ModelOptFp4Config(QuantizationConfig):
|
|||||||
else:
|
else:
|
||||||
kv_cache_quant_algo = "auto"
|
kv_cache_quant_algo = "auto"
|
||||||
|
|
||||||
group_size = config.get("group_size")
|
group_size = ModelOptFp4Config.common_group_size(config)
|
||||||
exclude_modules = config.get("ignore", [])
|
exclude_modules = config.get("ignore", [])
|
||||||
else:
|
else:
|
||||||
# Fall back to nested format (hf_quant_config.json - legacy format)
|
# Fall back to nested format (hf_quant_config.json - legacy format)
|
||||||
@@ -559,7 +592,7 @@ class ModelOptFp4Config(QuantizationConfig):
|
|||||||
kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
|
kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
|
||||||
if not kv_cache_quant_algo:
|
if not kv_cache_quant_algo:
|
||||||
kv_cache_quant_algo = "auto"
|
kv_cache_quant_algo = "auto"
|
||||||
group_size = quant_config.get("group_size")
|
group_size = ModelOptFp4Config.common_group_size(config)
|
||||||
exclude_modules = quant_config.get("exclude_modules", [])
|
exclude_modules = quant_config.get("exclude_modules", [])
|
||||||
except (ValueError, KeyError):
|
except (ValueError, KeyError):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -24,7 +24,10 @@ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
|||||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import (
|
||||||
|
default_weight_loader,
|
||||||
|
maybe_remap_kv_scale_name,
|
||||||
|
)
|
||||||
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
|
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
|
||||||
from sglang.srt.models.qwen2 import Qwen2Model
|
from sglang.srt.models.qwen2 import Qwen2Model
|
||||||
from sglang.srt.utils import add_prefix, is_cuda
|
from sglang.srt.utils import add_prefix, is_cuda
|
||||||
@@ -458,7 +461,10 @@ class Qwen3ForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
if "scale" in name:
|
||||||
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
|
|||||||
Reference in New Issue
Block a user