vlm: adapt internvl to VisionAttention (#6870)
@@ -1,15 +1,17 @@
 from __future__ import annotations
 
+import dataclasses
+import functools
 import math
-from functools import lru_cache, wraps
-from typing import Optional, Tuple
+from functools import lru_cache
+from typing import Any, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, print_info_once
 
 _is_cuda = is_cuda()
 
@@ -29,29 +31,42 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import add_prefix, logger
+from sglang.srt.utils import add_prefix
 
 ROTARY_EMBED_CLASSES = {
     "normal": apply_rotary_pos_emb,
 }
 
 
-def execute_once(func):
-    has_run = None
-
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        nonlocal has_run
-        if not has_run:
-            func(*args, **kwargs)
-            has_run = True
-
-    return wrapper
-
-
-@execute_once
-def info_once(message: str):
-    logger.info(message)
+@dataclasses.dataclass
+class SingletonCache:
+    data: Any = None
+
+    def set_data(self, value: Any) -> None:
+        self.data = value
+
+    def get_data(self) -> Optional[Any]:
+        return self.data
+
+    def empty(self) -> bool:
+        return self.get_data() is None
+
+
+# TODO: requires real seqlens from images
+@functools.lru_cache(maxsize=128)
+def _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Tensor:
+    """
+    Generates cumulative sequence lengths (cu_seqlens) for a given batch_size, seqlen, and device.
+    Caches the result based on these parameters.
+    """
+    cu_seqlens = torch.arange(
+        0,
+        (batch_size + 1) * seqlen,
+        step=seqlen,
+        dtype=torch.int32,
+        device=device,
+    )
+    return cu_seqlens
 
 
 class VisionSdpaAttention(nn.Module):
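
Note on the helpers added above: `_get_cu_seqlens_for_shape` memoizes the cumulative sequence-length tensor per (batch_size, seqlen, device) via `functools.lru_cache`, and `SingletonCache` is a one-slot holder that lets callers share the tensor once it has been computed. A minimal sketch of the behavior; the concrete numbers and the print are illustrative, not part of the patch:

    import torch

    # batch_size=2, seqlen=3 -> one offset per sequence boundary, as varlen attention expects
    cu_seqlens = torch.arange(0, (2 + 1) * 3, step=3, dtype=torch.int32)
    print(cu_seqlens)  # tensor([0, 3, 6], dtype=torch.int32)

    # SingletonCache is empty() until set_data() is called; get_data() then returns the stored value.
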
@@ -265,8 +280,9 @@ class VisionFlash3Attention(nn.Module):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        cu_seqlens: Optional[torch.Tensor],
-        attention_mask: Optional[torch.Tensor] = None,
+        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+        bsz: int,
+        seq_len: int,
         **kwargs,
     ) -> torch.Tensor:
         r"""
@@ -275,7 +291,16 @@ class VisionFlash3Attention(nn.Module):
         Returns:
             [b * s, h, head_size]
         """
-        cu_seqlens = cu_seqlens.to(dtype=torch.int32).cuda()
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+        elif isinstance(cu_seqlens, SingletonCache):
+            if cu_seqlens.empty():
+                cu_seqlens.set_data(
+                    _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+                )
+            cu_seqlens = cu_seqlens.get_data()
+
+        cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
         output = flash_attn_varlen_func(
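
The new branch accepts `cu_seqlens` in three forms: `None` (derive it from `bsz` and `seq_len`), a `SingletonCache` (fill it on first use, reuse afterwards), or a plain tensor (use as-is). A condensed restatement of that control flow as a standalone helper; `resolve_cu_seqlens` is a hypothetical name used only for illustration, and the import assumes the patched module above:

    from typing import Optional, Union

    import torch

    from sglang.srt.layers.attention.vision import SingletonCache, _get_cu_seqlens_for_shape

    def resolve_cu_seqlens(
        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
        bsz: int,
        seq_len: int,
        device: torch.device,
    ) -> torch.Tensor:
        # Mirrors the branch added to VisionFlash3Attention.forward above.
        if cu_seqlens is None:
            return _get_cu_seqlens_for_shape(bsz, seq_len, device=device)
        if isinstance(cu_seqlens, SingletonCache):
            if cu_seqlens.empty():
                cu_seqlens.set_data(_get_cu_seqlens_for_shape(bsz, seq_len, device=device))
            return cu_seqlens.get_data()
        return cu_seqlens
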
@@ -346,11 +371,11 @@ class VisionAttention(nn.Module):
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
-            info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
+            print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
         else:
             qkv_backend = global_server_args_dict["mm_attention_backend"]
 
-        info_once(f"Using {qkv_backend} as multimodal attention backend.")
+        print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
@@ -423,15 +448,16 @@ class VisionAttention(nn.Module):
             # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
             qkv, _ = self.qkv_proj(x)
 
-            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
+            # [s, b, head, head_dim_sum]
             new_x_shape = qkv.size()[:-1] + (
                 head,
-                3 * self.hidden_size_per_attention_head,
+                self.q_size + 2 * self.kv_size,
             )
             qkv = qkv.view(*new_x_shape)
 
             # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
-            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
             # [s, b, head, head_size] --> [b, s, head, head_size]
             q, k, v = [
                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
@@ -468,6 +494,7 @@ class VisionAttention(nn.Module):
             k=k,
             v=v,
             bsz=bsz,
+            seq_len=s,
             cu_seqlens=cu_seqlens,
             attention_mask=attention_mask,
         )
@@ -11,21 +11,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
 # Adapted from https://raw.githubusercontent.com/hehesangsj/sglang/refs/heads/internvl/python/sglang/srt/models/internvl.py
 import torch.nn.functional as F
-from einops import rearrange, repeat
-from sgl_kernel.flash_attn import flash_attn_varlen_func
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 
+from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -40,75 +38,12 @@ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 from sglang.utils import logger
 
 
-class FlashAttention(nn.Module):
-    """Implement the scaled dot product attention with softmax.
-    Arguments
-    ---------
-        softmax_scale: The temperature to use for the softmax attention.
-            (default: 1/sqrt(d_keys) where d_keys is computed at
-            runtime)
-        attention_dropout: The dropout rate to apply to the attention
-            (default: 0.0)
-    """
-
-    def __init__(
-        self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
-    ):
-        super().__init__()
-        self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
-
-    def forward(
-        self,
-        qkv,
-        causal=False,
-        max_s=None,
-    ):
-        """Implements the multihead softmax attention.
-        Arguments
-        ---------
-            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
-                if unpadded: (nnz, 3, h, d)
-        """
-        assert qkv.dtype in [torch.float16, torch.bfloat16]
-        assert qkv.is_cuda
-
-        batch_size, seqlen, _, nheads, d = qkv.shape
-        if batch_size == 0 or seqlen == 0:
-            output_shape = (batch_size, seqlen, nheads, d)
-            return (
-                torch.zeros(output_shape, dtype=qkv.dtype, device=qkv.device),
-                None,
-            )
-
-        qkv_reshaped = rearrange(qkv, "b s three h d -> (b s) three h d", three=3)
-        q, k, v = qkv_reshaped.unbind(1)
-
-        max_s = seqlen
-        cu_seqlens = torch.arange(
-            0,
-            (batch_size + 1) * seqlen,
-            step=seqlen,
-            dtype=torch.int32,
-            device=qkv.device,
-        )
-        output_reshaped = flash_attn_varlen_func(
-            q,
-            k,
-            v,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            softmax_scale=self.softmax_scale,
-            causal=causal,
-        )
-        output = rearrange(output_reshaped, "(b s) h d -> b s h d", b=batch_size)
-        return output, None
-
-
 class InternAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(
+        self,
+        config,
+        quant_config: QuantizationConfig = None,
+    ):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -116,7 +51,19 @@ class InternAttention(nn.Module):
         self.head_dim = self.embed_dim // self.num_heads
 
         self.scale = self.head_dim**-0.5
-        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
+        self.attn = VisionAttention(
+            qkv_backend="fa3",
+            embed_dim=self.embed_dim,
+            num_heads=self.num_heads,
+            projection_size=self.embed_dim,
+            use_qkv_parallel=True,
+            quant_config=quant_config,
+            dropout=getattr(config, "dropout", 0.0),
+            proj_bias=getattr(config, "qkv_bias", True),
+            flatten_batch=False,
+        )
+
         self.proj_drop = nn.Dropout(config.dropout)
 
         self.qk_normalization = config.qk_normalization
@@ -125,36 +72,15 @@ class InternAttention(nn.Module):
             self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
             self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
 
-        self.inner_attn = FlashAttention(softmax_scale=self.scale)
-
-        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
-
-    def _flash_attn(
-        self,
-        x,
-    ):
-        qkv = self.qkv(x)
-        qkv = rearrange(
-            qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
-        )
-
-        if self.qk_normalization:
-            q, k, v = qkv.unbind(2)
-            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
-            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
-            qkv = torch.stack([q, k, v], dim=2)
-
-        context, _ = self.inner_attn(
-            qkv,
-        )
-        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
-        outs = self.proj_drop(outs)
-        return outs
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        x = self._flash_attn(hidden_states)
-        return x
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+    ) -> torch.Tensor:
+        out = self.attn(hidden_states, cu_seqlens=cu_seqlens)
+        outs = self.proj_drop(out)
+        return outs
 
 
 class InternVisionEmbeddings(nn.Module):
     def __init__(self, config: PretrainedConfig):
@@ -286,6 +212,7 @@ class InternVisionEncoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
     ) -> Tuple[
         torch.FloatTensor,
         Optional[torch.FloatTensor],
@@ -295,8 +222,12 @@ class InternVisionEncoderLayer(nn.Module):
         Args:
             hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
         """
 
         hidden_states = hidden_states + self.drop_path1(
-            self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
+            self.attn(
+                self.norm1(hidden_states).to(hidden_states.dtype), cu_seqlens=cu_seqlens
+            )
+            * self.ls1
         )
 
         hidden_states = hidden_states + self.drop_path2(
@@ -363,12 +294,12 @@ class InternVisionEncoder(nn.Module):
         encoder_states = () if output_hidden_states else None
         hidden_states = inputs_embeds
 
+        cu_seqlens = SingletonCache()
+
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            layer_outputs = encoder_layer(
-                hidden_states,
-            )
+            layer_outputs = encoder_layer(hidden_states, cu_seqlens=cu_seqlens)
             hidden_states = layer_outputs
 
             if output_hidden_states:
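
With this change the encoder builds one `SingletonCache` per forward pass and hands the same object to every layer, so the cu_seqlens tensor is computed once (inside the first layer's attention call) and reused by all later layers instead of being rebuilt per layer. A toy sketch of the sharing semantics; the inline fill and the literal values are illustrative only, and the import assumes the patched module:

    import torch

    from sglang.srt.layers.attention.vision import SingletonCache

    cache = SingletonCache()
    for _ in range(3):  # stands in for self.layers
        if cache.empty():  # true only on the first iteration
            cache.set_data(torch.tensor([0, 4], dtype=torch.int32))  # done inside the attention backend in the patch
        cu_seqlens = cache.get_data()  # later iterations reuse the same tensor
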
@@ -625,6 +556,7 @@ class InternVLChatModel(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
@@ -641,6 +573,11 @@ class InternVLChatModel(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                if "vision_model" in name:
+                    # adapt to VisionAttention
+                    name = name.replace(r"attn.", r"attn.attn.")
+                    name = name.replace(r"qkv.", r"qkv_proj.")
+
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
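
The two `replace` calls rewrite vision-tower checkpoint keys onto the new module layout: the original `InternAttention.qkv` linear now lives as the fused `qkv_proj` inside a nested `VisionAttention` exposed as `.attn`. A sketch with an assumed example key; the key itself is illustrative:

    name = "vision_model.encoder.layers.0.attn.qkv.weight"
    name = name.replace(r"attn.", r"attn.attn.")  # InternAttention.attn is now a VisionAttention module
    name = name.replace(r"qkv.", r"qkv_proj.")    # VisionAttention names its fused projection qkv_proj
    print(name)  # vision_model.encoder.layers.0.attn.attn.qkv_proj.weight
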
@@ -665,6 +602,13 @@ class InternVLChatModel(nn.Module):
                     param, "weight_loader", default_weight_loader
                 )
                 weight_loader(param, loaded_weight)
+                loaded_params.add(name)
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            raise RuntimeError(
+                f"Some weights are not initialized from checkpoints: {unloaded_params}"
+            )
+        return loaded_params
 
 
 EntryClass = InternVLChatModel
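
`load_weights` now records every parameter it touches and fails loudly if anything in `params_dict` was never initialized. A minimal sketch of the set-difference check with made-up parameter names:

    params_dict = {"a.weight": None, "b.weight": None}
    loaded_params = {"a.weight"}
    unloaded_params = params_dict.keys() - loaded_params
    if unloaded_params:
        raise RuntimeError(f"Some weights are not initialized from checkpoints: {unloaded_params}")
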
@@ -17,6 +17,7 @@ import base64
 import builtins
 import ctypes
 import dataclasses
+import functools
 import importlib
 import io
 import ipaddress
@@ -1386,6 +1387,11 @@ def print_warning_once(msg: str) -> None:
     logger.warning(msg, stacklevel=2)
 
 
+@functools.lru_cache(None)
+def print_info_once(msg: str) -> None:
+    logger.info(msg)
+
+
 def get_device_name(device_id: int = 0) -> str:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_name(device_id)
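
`print_info_once` relies on `functools.lru_cache(None)` for its log-once behavior: the first call with a given message runs the body, while repeated calls with the same string hit the cache and never reach the logger. A self-contained sketch; the logger setup here is illustrative:

    import functools
    import logging

    logger = logging.getLogger(__name__)

    @functools.lru_cache(None)
    def print_info_once(msg: str) -> None:
        logger.info(msg)

    print_info_once("Using sdpa as multimodal attention backend.")  # logged
    print_info_once("Using sdpa as multimodal attention backend.")  # cached, not logged again
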