diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py
index 6f7375299..43303bf4c 100644
--- a/python/sglang/lang/chat_template.py
+++ b/python/sglang/lang/chat_template.py
@@ -17,15 +17,23 @@ class ChatTemplate:
     image_token: str = ""
     style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
 
-    def get_prefix_and_suffix(self, role: str, hist_messages: List[Dict]) -> Tuple[str, str]:
+    def get_prefix_and_suffix(
+        self, role: str, hist_messages: List[Dict]
+    ) -> Tuple[str, str]:
         prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))
-
+
         if self.style == ChatTemplateStyle.LLAMA2:
             if role == "system" and not hist_messages:
                 user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
-                system_prefix, system_suffix = self.role_prefix_and_suffix.get("system", ("", ""))
+                system_prefix, system_suffix = self.role_prefix_and_suffix.get(
+                    "system", ("", "")
+                )
                 return (user_prefix + system_prefix, system_suffix)
-            elif role == "user" and len(hist_messages) == 1 and hist_messages[0]["content"] is not None:
+            elif (
+                role == "user"
+                and len(hist_messages) == 1
+                and hist_messages[0]["content"] is not None
+            ):
                 return ("", suffix)
 
         return prefix, suffix
@@ -171,6 +179,19 @@ register_chat_template(
     )
 )
 
+register_chat_template(
+    ChatTemplate(
+        name="gemma-it",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("", ""),
+            "user": ("<start_of_turn>user\n", "<end_of_turn>\n"),
+            "assistant": ("<start_of_turn>model\n", "<end_of_turn>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+    )
+)
+
 
 @register_chat_template_matching_function
 def match_vicuna(model_path: str):
@@ -211,6 +232,13 @@ def match_chat_yi(model_path: str):
         return get_chat_template("yi")
 
 
+@register_chat_template_matching_function
+def match_gemma_it(model_path: str):
+    model_path = model_path.lower()
+    if "gemma" in model_path and "it" in model_path:
+        return get_chat_template("gemma-it")
+
+
 if __name__ == "__main__":
     messages = [
         {"role": "system", "content": None},  # None means default
diff --git a/python/sglang/srt/layers/context_flashattention_nopad.py b/python/sglang/srt/layers/context_flashattention_nopad.py
index ef254a1b2..2ac3d39e9 100644
--- a/python/sglang/srt/layers/context_flashattention_nopad.py
+++ b/python/sglang/srt/layers/context_flashattention_nopad.py
@@ -129,7 +129,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
 
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk and Lk == Lv
-    assert Lk in {16, 32, 64, 128}
+    assert Lk in {16, 32, 64, 128, 256}
 
     sm_scale = 1.0 / (Lq**0.5)
     batch, head = b_seq_len.shape[0], q.shape[1]
diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py
index 0c5cebf5a..62167a582 100644
--- a/python/sglang/srt/layers/extend_attention.py
+++ b/python/sglang/srt/layers/extend_attention.py
@@ -181,19 +181,20 @@ def extend_attention_fwd(
     k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
     """
 
-    if CUDA_CAPABILITY[0] >= 8:
-        BLOCK_M, BLOCK_N = 128, 128
-    else:
-        BLOCK_M, BLOCK_N = 64, 64
-
     Lq, Lk, Lv, Lo = (
         q_extend.shape[-1],
         k_extend.shape[-1],
        v_extend.shape[-1],
         o_extend.shape[-1],
     )
+
     assert Lq == Lk and Lk == Lv and Lv == Lo
-    assert Lq in {16, 32, 64, 128}
+    assert Lq in {16, 32, 64, 128, 256}
+
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
+    else:
+        BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
 
     sm_scale = 1.0 / (Lq**0.5)
     batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 9b9525ac1..457cabc88 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -1,15 +1,9 @@
-from typing import List
-
 import torch
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 from torch import nn
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 
 
 class RadixAttention(nn.Module):
@@ -21,9 +15,9 @@ class RadixAttention(nn.Module):
         self.head_dim = head_dim
         self.layer_id = layer_id
 
-        from sglang.srt.managers.router.model_runner import global_model_mode
+        from sglang.srt.managers.router.model_runner import global_server_args
 
-        self.use_flashinfer = "flashinfer" in global_model_mode
+        self.use_flashinfer = "flashinfer" in global_server_args.model_mode
 
         if self.use_flashinfer:
             self.prefill_forward = self.prefill_forward_flashinfer
diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py
index 8ac4ed959..c8a80fd32 100644
--- a/python/sglang/srt/layers/token_attention.py
+++ b/python/sglang/srt/layers/token_attention.py
@@ -5,6 +5,14 @@ import torch
 import triton
 import triton.language as tl
 from sglang.srt.utils import wrap_kernel_launcher
+from sglang.srt.managers.router.model_runner import global_server_args
+
+if global_server_args.attention_reduce_in_fp32:
+    REDUCE_TRITON_TYPE = tl.float32
+    REDUCE_TORCH_TYPE = torch.float32
+else:
+    REDUCE_TRITON_TYPE = tl.float16
+    REDUCE_TORCH_TYPE = torch.float16
 
 
 @triton.jit
@@ -49,7 +57,7 @@ def _fwd_kernel_stage1(
     block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
 
     for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + off_q + start_mark)
+        q = tl.load(Q + off_q + start_mark).to(REDUCE_TRITON_TYPE)
         offs_n_new = cur_batch_start_index + offs_n
         k_loc = tl.load(
             Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -65,7 +73,7 @@
             K_Buffer + offs_buf_k,
            mask=offs_n_new[:, None] < cur_batch_end_index,
             other=0.0,
-        )
+        ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
         att_value *= sm_scale
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
@@ -161,7 +169,7 @@ def _token_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128}
+    assert Lk in {16, 32, 64, 128, 256}
     sm_scale = 1.0 / (Lk**0.5)
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
@@ -299,7 +307,7 @@ def token_attention_fwd(
 ):
     if att_m is None:
         att_m = torch.empty(
-            (q.shape[-2], total_num_tokens), dtype=q.dtype, device="cuda"
+            (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
         )
 
     _token_att_m_fwd(
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 40ae6dd98..a651f2c7a 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -57,17 +57,19 @@ class ModelRpcServer(rpyc.Service):
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
-            server_args.model_path, server_args.trust_remote_code, context_length=server_args.context_length
+            server_args.model_path,
+            server_args.trust_remote_code,
+            context_length=server_args.context_length,
         )
         self.model_runner = ModelRunner(
-            self.model_config,
-            server_args.mem_fraction_static,
-            tp_rank,
-            server_args.tp_size,
-            port_args.nccl_port,
-            server_args.load_format,
-            server_args.trust_remote_code,
-            server_args.model_mode,
+            model_config=self.model_config,
+            mem_fraction_static=server_args.mem_fraction_static,
+            tp_rank=tp_rank,
+            tp_size=server_args.tp_size,
+            nccl_port=port_args.nccl_port,
+            server_args=server_args,
+            load_format=server_args.load_format,
+            trust_remote_code=server_args.trust_remote_code,
         )
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(
@@ -435,7 +437,7 @@ class ModelRpcServer(rpyc.Service):
                 # If logprob_start_len > 0, then first logprob_start_len prompt tokens
                 # will be ignored.
                 prompt_token_len = len(req.logprob)
-                token_ids = req.input_ids[-prompt_token_len :] + [next_token_ids[i]]
+                token_ids = req.input_ids[-prompt_token_len:] + [next_token_ids[i]]
                 token_logprobs = req.logprob + [last_logprobs[i]]
                 req.token_logprob = list(zip(token_ids, token_logprobs))
                 if req.logprob_start_len == 0:
@@ -553,8 +555,7 @@ class ModelRpcServer(rpyc.Service):
                 "completion_tokens": len(req.input_ids)
                 + len(req.output_ids)
                 - req.prompt_tokens,
-                "completion_tokens_wo_jump_forward":
-                    req.completion_tokens_wo_jump_forward
+                "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
             }
             if req.return_logprob:
                 meta_info["prompt_logprob"] = req.logprob
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index 7e4a393f1..20d163947 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -3,7 +3,6 @@ import logging
 from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
-from typing import List
 
 import numpy as np
 import torch
@@ -23,8 +22,8 @@ QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}
 
 logger = logging.getLogger("model_runner")
 
-# for model_mode
-global_model_mode: List[str] = []
+# for server args in model endpoints
+global_server_args = None
 
 
 @lru_cache()
@@ -81,7 +80,6 @@ class InputMetadata:
     return_logprob: bool = False
 
     # for flashinfer
-    use_flashinfer: bool = False
     qo_indptr: torch.Tensor = None
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
@@ -224,8 +222,7 @@ class InputMetadata:
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()
 
-        ret.use_flashinfer = "flashinfer" in model_runner.model_mode
-        if ret.use_flashinfer:
+        if "flashinfer" in global_server_args.model_mode:
             ret.init_flashinfer_args(tp_size)
 
         return ret
@@ -239,9 +236,9 @@ class ModelRunner:
         tp_rank,
         tp_size,
         nccl_port,
+        server_args,
         load_format="auto",
         trust_remote_code=True,
-        model_mode: List[str] = (),
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
@@ -250,10 +247,9 @@ class ModelRunner:
         self.nccl_port = nccl_port
         self.load_format = load_format
         self.trust_remote_code = trust_remote_code
-        self.model_mode = model_mode
 
-        global global_model_mode
-        global_model_mode = model_mode
+        global global_server_args
+        global_server_args = server_args
 
         # Init torch distributed
         torch.cuda.set_device(self.tp_rank)
@@ -319,9 +315,7 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.tp_rank, distributed=self.tp_size > 1
         ) * (1 << 30)
-        head_dim = (
-            self.model_config.hidden_size // self.model_config.num_attention_heads
-        )
+        head_dim = self.model_config.head_dim
         head_num = self.model_config.num_key_value_heads // self.tp_size
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (
@@ -346,8 +340,7 @@ class ModelRunner:
             self.max_total_num_token,
             dtype=torch.float16,
             head_num=self.model_config.num_key_value_heads // self.tp_size,
-            head_dim=self.model_config.hidden_size
-            // self.model_config.num_attention_heads,
+            head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
         )
 
diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py
index 504f499dc..e27d5b63f 100644
--- a/python/sglang/srt/model_config.py
+++ b/python/sglang/srt/model_config.py
@@ -1,7 +1,5 @@
-import os
-from typing import Optional, Union
+from typing import Optional
 
-import torch
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 
 
@@ -17,14 +15,18 @@ class ModelConfig:
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.hf_config = get_config(self.path, trust_remote_code, revision)
-
+
         if context_length is not None:
             self.context_len = context_length
         else:
             self.context_len = get_context_length(self.hf_config)
 
         # Unify the config keys for hf_config
-        self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
+        self.head_dim = getattr(
+            self.hf_config,
+            "head_dim",
+            self.hf_config.hidden_size // self.hf_config.num_attention_heads,
+        )
         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
         if self.num_key_value_heads is None:
diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py
new file mode 100644
index 000000000..e2f4492dd
--- /dev/null
+++ b/python/sglang/srt/models/gemma.py
@@ -0,0 +1,340 @@
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/d65fac2738f0287a41955b45df76a2d5a919bff6/vllm/model_executor/models/gemma.py
+"""Inference-only Gemma model compatible with HuggingFace weights."""
+from typing import Optional, Tuple
+
+import torch
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from torch import nn
+from transformers import GemmaConfig
+from vllm.config import LoRAConfig
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import GeluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_world_size,
+)
+from vllm.model_executor.weight_utils import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
+
+
+class GemmaMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size, hidden_size, bias=False, linear_method=linear_method
+        )
+        self.act_fn = GeluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class GemmaAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        layer_id: int = 0,
+        max_position_embeddings: int = 8192,
+        rope_theta: float = 10000,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = head_dim
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=self.rope_theta,
+            is_neox_style=True,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class GemmaDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: GemmaConfig,
+        layer_id: int = 0,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = GemmaAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            head_dim=config.head_dim,
+            layer_id=layer_id,
+            max_position_embeddings=config.max_position_embeddings,
+            rope_theta=config.rope_theta,
+            linear_method=linear_method,
+        )
+        self.mlp = GemmaMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            linear_method=linear_method,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class GemmaModel(nn.Module):
+    def __init__(
+        self,
+        config: GemmaConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                GemmaDecoderLayer(config, i, linear_method)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        skip_embed: bool = False,
+    ) -> torch.Tensor:
+        if not skip_embed:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_ids
+
+        # Normalize the embedding by sqrt(hidden_size)
+        hidden_states *= self.config.hidden_size**0.5
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                input_metadata,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class GemmaForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer.
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: GemmaConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ) -> None:
+        del lora_config  # Unused.
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = GemmaModel(config, linear_method)
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        skip_embed: bool = False,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
+        return self.logits_processor(
+            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
+        )
+
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params = set()
+        for name, loaded_weight in hf_model_weights_iterator(
+            model_name_or_path, cache_dir, load_format, revision
+        ):
+            for param_name, shard_name, shard_id in stacked_params_mapping:
+                if shard_name not in name:
+                    continue
+                name = name.replace(shard_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # GemmaRMSNorm is different from Llama's in that it multiplies
+                # (1 + weight) to the output, instead of just weight.
+                if "norm.weight" in name:
+                    loaded_weight += 1.0
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            raise RuntimeError(
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}"
+            )
+
+
+EntryClass = GemmaForCausalLM
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 73583d1fa..95ce19087 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -28,6 +28,7 @@ class ServerArgs:
     log_level: str = "info"
     disable_regex_jump_forward: bool = False
    disable_disk_cache: bool = False
+    attention_reduce_in_fp32: bool = False
 
     def __post_init__(self):
         if self.tokenizer_path is None:
@@ -189,6 +190,11 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--attention-reduce-in-fp32",
+            action="store_true",
+            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
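Usage sketch (not part of the patch). The "--attention-reduce-in-fp32" flag and the "gemma-it" template name come from the changes above; the launch command, model path, and port below are illustrative assumptions, not something this diff defines:

# Launching the server for an instruction-tuned Gemma checkpoint could look like
# (hypothetical model path and port):
#
#   python -m sglang.launch_server --model-path google/gemma-7b-it \
#       --port 30000 --attention-reduce-in-fp32
#
# match_gemma_it() then auto-selects the "gemma-it" template because the
# model path contains both "gemma" and "it".

from sglang.lang.chat_template import get_chat_template

# Inspect the registered template: each user turn is wrapped in
# "<start_of_turn>user\n ... <end_of_turn>\n" and the assistant replies
# under the "model" role, following Gemma's chat format.
template = get_chat_template("gemma-it")
user_prefix, user_suffix = template.role_prefix_and_suffix["user"]
print(repr(user_prefix), repr(user_suffix))

The fp32 reduction only affects the intermediate scores in token_attention.py; the KV cache and model weights stay in fp16.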