support e4m3 kvcache in qwen2 & add kv scaling factor json (#2894)
Co-authored-by: bjmsong <bjmsong@126.com>
This commit is contained in:
@@ -22,7 +22,10 @@ import torch
|
||||
from torch import nn
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -39,7 +42,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
kv_cache_scales_loader,
|
||||
)
|
||||
from sglang.srt.utils import make_layers
|
||||
|
||||
Qwen2Config = None
|
||||
@@ -265,6 +271,29 @@ class Qwen2Model(nn.Module):
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
    """Load per-layer KV cache scaling factors for this TP shard.

    Reads the scaling-factor JSON at *quantization_param_path* via
    ``kv_cache_scales_loader`` (sliced by tensor-parallel rank/size and
    layer count) and writes each factor onto the corresponding layer's
    attention module as ``k_scale`` and ``v_scale``.

    Raises:
        RuntimeError: if a layer's attention module exposes no
            ``k_scale`` attribute to receive the factor.
    """
    world_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()
    layer_scales = kv_cache_scales_loader(
        quantization_param_path,
        rank,
        world_size,
        self.config.num_hidden_layers,
        self.config.__class__.model_type,
    )
    for layer_idx, scaling_factor in layer_scales:
        layer = self.layers[layer_idx]
        # NOTE(review): nn.Identity entries look like placeholders for
        # layers not materialized here — they carry no attention to update.
        if isinstance(layer, nn.Identity):
            continue
        attn = layer.self_attn.attn
        if not hasattr(attn, "k_scale"):
            raise RuntimeError(
                "Self attention has no KV cache scaling " "factor attribute!"
            )
        # The same factor is applied to both K and V caches.
        attn.k_scale = scaling_factor
        attn.v_scale = scaling_factor
|
||||
|
||||
|
||||
class Qwen2ForCausalLM(nn.Module):
|
||||
|
||||
@@ -373,5 +402,8 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
    """Forward KV cache scale loading to the wrapped Qwen2Model."""
    return self.model.load_kv_cache_scales(quantization_param_path)
|
||||
|
||||
|
||||
EntryClass = Qwen2ForCausalLM
|
||||
|
||||
Reference in New Issue
Block a user