Add INT8 support for the MTP NextN function (#3911)
This commit is contained in:
@@ -30,6 +30,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
|||||||
block_quant_to_tensor_quant,
|
block_quant_to_tensor_quant,
|
||||||
normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.int8_utils import (
|
||||||
|
block_dequant as int8_block_dequant,
|
||||||
|
)
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
@@ -291,6 +294,23 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
weight, weight_scale, weight_block_size
|
weight, weight_scale, weight_block_size
|
||||||
)
|
)
|
||||||
self_attn.w_scale = scale
|
self_attn.w_scale = scale
|
||||||
|
if w.dtype == torch.int8:
|
||||||
|
if hasattr(self.quant_config, "weight_block_size"):
|
||||||
|
# block-wise int8 needs the weight dequantized with its block scales
|
||||||
|
weight_block_size = self.quant_config.weight_block_size
|
||||||
|
if weight_block_size is not None:
|
||||||
|
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
|
||||||
|
weight = w
|
||||||
|
weight_scale = self_attn.kv_b_proj.weight_scale_inv
|
||||||
|
w = int8_block_dequant(
|
||||||
|
weight, weight_scale, weight_block_size
|
||||||
|
).to(torch.bfloat16)
|
||||||
|
else:
|
||||||
|
# channel-wise int8 needs the weight multiplied by its weight scale
|
||||||
|
assert hasattr(self_attn.kv_b_proj, "weight_scale")
|
||||||
|
w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
|
||||||
|
torch.bfloat16
|
||||||
|
)
|
||||||
w_kc, w_vc = w.unflatten(
|
w_kc, w_vc = w.unflatten(
|
||||||
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
||||||
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
||||||
|
|||||||
Reference in New Issue
Block a user