Enable CPU device on SGLang (#2806)
@@ -40,6 +40,10 @@ srt_xpu = ["sglang[runtime_common]"]
 #For Intel Gaudi(device : hpu) follow the installation guide
 #https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
 srt_hpu = ["sglang[runtime_common]"]
+# CPU: currently, there are no pre-built vllm wheels for CPU.
+# To install vllm for CPU, please follow the instruction here:
+# https://docs.vllm.ai/en/latest/getting_started/installation/cpu/index.html
+srt_cpu = ["sglang[runtime_common]", "torch"]

 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
@@ -57,11 +61,13 @@ all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
+all_cpu = ["sglang[srt_cpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]

 dev = ["sglang[all]", "sglang[test]"]
 dev_hip = ["sglang[all_hip]", "sglang[test]"]
 dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
 dev_hpu = ["sglang[all_hpu]", "sglang[test]"]
+dev_cpu = ["sglang[all_cpu]", "sglang[test]"]

 [project.urls]
 "Homepage" = "https://github.com/sgl-project/sglang"
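Not part of the commit: a small pre-flight check one might run in the resulting CPU environment. The srt_cpu extra pulls in torch, but, as the comments above note, vllm still has to be built from source for CPU; nothing SGLang-specific is assumed here.

import importlib.util

# Quick sanity check for a CPU-only install: torch comes from the srt_cpu extra,
# while vllm must be built manually following the linked CPU installation guide.
for pkg in ("torch", "vllm"):
    found = importlib.util.find_spec(pkg) is not None
    print(f"{pkg}: {'found' if found else 'missing - build/install it for CPU first'}")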
@@ -10,7 +10,7 @@ class DeviceConfig:
     device: Optional[torch.device]

     def __init__(self, device: str = "cuda") -> None:
-        if device in ["cuda", "xpu", "hpu"]:
+        if device in ["cuda", "xpu", "hpu", "cpu"]:
             self.device_type = device
         else:
             raise RuntimeError(f"Not supported device type: {device}")
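A minimal usage sketch (not from the commit), assuming the class lives at sglang.srt.configs.device_config as in the SGLang tree: with "cpu" added to the allow-list, constructing the config on a CPU-only host no longer raises.

import torch
from sglang.srt.configs.device_config import DeviceConfig  # module path assumed

cfg = DeviceConfig(device="cpu")   # previously: RuntimeError("Not supported device type: cpu")
print(cfg.device_type)             # -> "cpu"
print(torch.device(cfg.device_type))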
@@ -8,6 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F

 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.moe.topk import select_experts


@@ -44,3 +45,71 @@ def fused_moe_forward_native(
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
+
+
+def moe_forward_native(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+
+    topk_weights, topk_ids = select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+        correction_bias=correction_bias,
+        torch_native=True,
+    )
+
+    # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
+    len_experts = layer.num_experts
+
+    cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
+    cnts.scatter_(1, topk_ids.to(torch.int64), 1)
+    tokens_per_expert = cnts.sum(dim=0)
+    idxs = topk_ids.view(-1).argsort()
+
+    sorted_tokens = x[idxs // topk_ids.shape[1]]
+    tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+    outputs = []
+    start_idx = 0
+    for i, num_tokens in enumerate(tokens_per_expert):
+        end_idx = start_idx + num_tokens
+        if num_tokens == 0:
+            continue
+        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+
+        layer_w13_weight = layer.w13_weight[i]
+        layer_w2_weight = layer.w2_weight[i]
+
+        gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
+        gate_up = SiluAndMul()(gate_up)
+        expert_out = F.linear(gate_up, layer_w2_weight)
+        outputs.append(expert_out)
+        start_idx = end_idx
+
+    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+    new_x = torch.empty_like(outs)
+
+    new_x[idxs] = outs
+    final_out = (
+        new_x.view(*topk_ids.shape, -1)
+        .type(topk_weights.dtype)
+        .mul_(topk_weights.unsqueeze(dim=-1))
+        .sum(dim=1)
+        .type(new_x.dtype)
+    )
+    return final_out
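The new moe_forward_native path routes tokens with select_experts, then sorts the (token, expert) pairs so each expert processes one contiguous slice of the input. The snippet below is an illustration only, not part of the commit: plain torch, made-up shapes, and simple softmax top-k routing in place of select_experts, showing the same sort / per-expert GEMM / scatter-back pattern in isolation.

import torch
import torch.nn.functional as F

T, H, I, E, K = 8, 16, 32, 4, 2          # tokens, hidden, intermediate, experts, top-k
x = torch.randn(T, H)
w13 = torch.randn(E, 2 * I, H)           # fused gate/up projection per expert
w2 = torch.randn(E, H, I)                # down projection per expert
logits = torch.randn(T, E)

topk_weights, topk_ids = logits.softmax(dim=-1).topk(K, dim=-1)

# Sort the (token, expert) pairs so each expert sees one contiguous slice.
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // K]
tokens_per_expert = topk_ids.view(-1).bincount(minlength=E).tolist()

outputs, start = [], 0
for e, n in enumerate(tokens_per_expert):
    if n == 0:
        continue
    chunk = sorted_tokens[start : start + n]
    gate, up = F.linear(chunk, w13[e]).chunk(2, dim=-1)
    outputs.append(F.linear(F.silu(gate) * up, w2[e]))   # SiluAndMul + down projection
    start += n

outs = torch.cat(outputs, dim=0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs                       # undo the sort
y = (new_x.view(T, K, H) * topk_weights.unsqueeze(-1)).sum(dim=1)
print(y.shape)                           # torch.Size([8, 16])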
@@ -19,7 +19,10 @@ from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip

 is_hip_flag = False
 if not is_hip():
-    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+    if torch.cuda.is_available():
+        from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+    else:
+        sgl_moe_align_block_size = None

     is_hip_flag = False
 else:
@@ -13,6 +13,7 @@ from vllm.distributed import (
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -185,8 +186,31 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             inplace=True,
         )

-    def forward_cpu(self, *args, **kwargs):
-        raise NotImplementedError("The CPU backend currently does not support MoE.")
+    def forward_cpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return moe_forward_native(
+            layer,
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            custom_routing_function,
+            correction_bias,
+        )

     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
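Why a full forward_cpu signature matters: CustomOp-style layers pick a backend-specific forward (forward_cuda, forward_cpu, ...) at call time, so the CPU entry point must accept the same arguments as the CUDA one instead of raising. The toy class below is a simplified stand-in for that dispatch idea, not vLLM's actual CustomOp implementation.

import torch

class DispatchByDevice(torch.nn.Module):
    """Toy illustration: route to a device-specific forward, as CustomOp-style layers do."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.is_cuda:
            return self.forward_cuda(x)
        return self.forward_cpu(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2  # placeholder for a fused GPU kernel

    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2  # placeholder for the torch-native path (moe_forward_native here)

op = DispatchByDevice()
print(op(torch.ones(3)))  # a CPU tensor runs forward_cpu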
@@ -15,6 +15,15 @@
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
+from vllm.model_executor.layers.rotary_embedding import (
+    RotaryEmbedding,
+    _rotate_gptj,
+    _rotate_neox,
+    _yarn_find_correction_range,
+    _yarn_linear_ramp_mask,
+    get_rope,
+    yarn_get_mscale,
+)


 class MRotaryEmbedding:
@@ -110,3 +119,242 @@ class MRotaryEmbedding:
             )
             for _ in range(3)
         ]
+
+
+# TODO: in the DeepseekScalingRotaryEmbedding class defined in vllm,
+# the device has been hard-coded to "cuda" in these two places:
+# https://github.com/vllm-project/vllm/blob/8a1f938e6f02052df0f4953c149410605a2d56d8/vllm/model_executor/layers/rotary_embedding.py#L646
+# https://github.com/vllm-project/vllm/blob/8a1f938e6f02052df0f4953c149410605a2d56d8/vllm/model_executor/layers/rotary_embedding.py#L665
+# We port the related code to this file to make it compatible with the CPU version.
+# We will add an optimized rotary embedding kernel for CPU and will remove the ported code then.
+class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with YaRN method.
+
+    Credits to Peng et al. github.com/jquesnelle/yarn
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+        mscale: float = 1,
+        mscale_all_dim: float = 0,
+        device: Optional[str] = None,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation.
+        self.mscale = float(
+            yarn_get_mscale(self.scaling_factor, float(mscale))
+            / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
+            * attn_factor
+        )
+        self.device = device
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base ** (
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
+            / self.rotary_dim
+        )
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            self.rotary_dim,
+            self.base,
+            self.max_position_embeddings,
+        )
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (
+            1
+            - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
+        ) * self.extrapolation_factor
+        inv_freq = (
+            inv_freq_interpolation * (1 - inv_freq_mask)
+            + inv_freq_extrapolation * inv_freq_mask
+        )
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(
+            self.max_position_embeddings * self.scaling_factor,
+            device=self.device,
+            dtype=torch.float32,
+        )
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos() * self.mscale
+        sin = freqs.sin() * self.mscale
+        cache = torch.cat((cos, sin), dim=-1)
+        print("Cache shape", cache.shape)
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward()."""
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        cos_sin = self.cos_sin_cache[
+            torch.add(positions, offsets) if offsets is not None else positions
+        ]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
+
+_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
+
+
+def get_rope_cpu(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool = True,
+    rope_scaling: Optional[Dict[str, Any]] = None,
+    dtype: Optional[torch.dtype] = None,
+    partial_rotary_factor: float = 1.0,
+    device: Optional[str] = None,
+) -> RotaryEmbedding:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+    if rope_scaling is not None:
+        # Transforms every value that is a list into a tuple for caching calls
+        rope_scaling_tuple = {
+            k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
+        }
+        rope_scaling_args = tuple(rope_scaling_tuple.items())
+    else:
+        rope_scaling_args = None
+    if partial_rotary_factor < 1.0:
+        rotary_dim = int(rotary_dim * partial_rotary_factor)
+    key = (
+        head_size,
+        rotary_dim,
+        max_position,
+        base,
+        is_neox_style,
+        rope_scaling_args,
+        dtype,
+    )
+    if key in _ROPE_DICT:
+        return _ROPE_DICT[key]
+
+    assert rope_scaling is not None
+    scaling_type = rope_scaling["rope_type"]
+    assert (
+        scaling_type == "deepseek_yarn"
+    ), "Only deepseek_yarn is supported for CPU for now"
+
+    scaling_factor = rope_scaling["factor"]
+    original_max_position = rope_scaling["original_max_position_embeddings"]
+    # assert max_position == original_max_position * scaling_factor
+    extra_kwargs = {
+        k: v
+        for k, v in rope_scaling.items()
+        if k
+        in (
+            "extrapolation_factor",
+            "attn_factor",
+            "beta_fast",
+            "beta_slow",
+            "mscale",
+            "mscale_all_dim",
+        )
+    }
+    extra_kwargs["device"] = device
+    rotary_emb = DeepseekScalingRotaryEmbedding(
+        head_size,
+        rotary_dim,
+        original_max_position,
+        base,
+        is_neox_style,
+        scaling_factor,
+        dtype,
+        **extra_kwargs,
+    )
+
+    _ROPE_DICT[key] = rotary_emb
+    return rotary_emb
+
+
+def get_rope_wrapper(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool = True,
+    rope_scaling: Optional[Dict[str, Any]] = None,
+    dtype: Optional[torch.dtype] = None,
+    partial_rotary_factor: float = 1.0,
+    device: Optional[str] = None,
+):
+    if device != "cpu":
+        return get_rope(
+            head_size,
+            rotary_dim,
+            max_position,
+            base,
+            is_neox_style,
+            rope_scaling,
+            dtype,
+            partial_rotary_factor,
+        )
+
+    return get_rope_cpu(
+        head_size,
+        rotary_dim,
+        max_position,
+        base,
+        is_neox_style,
+        rope_scaling,
+        dtype,
+        partial_rotary_factor,
+        device,
+    )
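A usage sketch (not from the commit) of the new wrapper: with device="cpu" it routes to get_rope_cpu and the ported DeepseekScalingRotaryEmbedding instead of vLLM's CUDA-bound get_rope. It assumes SGLang with this patch plus a vLLM build importable on CPU; all numeric values are illustrative DeepSeek-V2-style YaRN settings.

import torch
from sglang.srt.layers.rotary_embedding import get_rope_wrapper

rope = get_rope_wrapper(
    head_size=64,
    rotary_dim=64,
    max_position=4096,                      # original_max_position * factor
    base=10000,
    is_neox_style=False,
    rope_scaling={
        "rope_type": "deepseek_yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 1024,
        "beta_fast": 32,
        "beta_slow": 1,
        "mscale": 1.0,
        "mscale_all_dim": 0.0,
    },
    device="cpu",                           # any other device falls through to vLLM's get_rope
)

positions = torch.arange(8)
q = torch.randn(8, 1, 64)
k = torch.randn(8, 1, 64)
q, k = rope.forward(positions, q, k)        # the ported PyTorch-native forward
print(q.shape, k.shape)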
@@ -65,6 +65,7 @@ global_server_args_dict = {
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "device": ServerArgs.device,
 }


@@ -317,6 +317,8 @@ class Scheduler:
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
         self.current_stream = torch.get_device_module(self.device).current_stream()
+        if self.device == "cpu":
+            self.current_stream.synchronize = lambda: None  # No-op for CPU

         # Session info
         self.sessions: Dict[str, Session] = {}
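The pattern used here and in TpModelWorkerClient below: take whatever current_stream() the device module exposes and neutralize synchronize() on CPU, so device-agnostic scheduler code can keep calling it. A standalone sketch, assuming a torch version that provides torch.get_device_module and CPU stream stubs (as the patched code itself does):

import torch

device = "cpu"
stream = torch.get_device_module(device).current_stream()
if device == "cpu":
    stream.synchronize = lambda: None  # nothing to synchronize on CPU

stream.synchronize()  # now safe on both the CPU and CUDA paths
print(type(stream).__name__)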
@@ -82,6 +82,8 @@ class TpModelWorkerClient:
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
         self.scheduler_stream = torch.get_device_module(self.device).current_stream()
+        if self.device == "cpu":
+            self.scheduler_stream.synchronize = lambda: None  # No-op for CPU

     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -106,8 +106,10 @@ class ModelRunner:
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
         ):
-            logger.info("MLA optimization is turned on. Use triton backend.")
-            self.server_args.attention_backend = "triton"
+            # TODO: add MLA optimization on CPU
+            if self.server_args.device != "cpu":
+                logger.info("MLA optimization is turned on. Use triton backend.")
+                self.server_args.attention_backend = "triton"

         if self.server_args.enable_double_sparsity:
             logger.info(
@@ -164,6 +166,7 @@ class ModelRunner:
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
+                "device": server_args.device,
             }
         )

@@ -221,6 +224,8 @@ class ModelRunner:
             backend = "gloo"
         elif self.device == "hpu":
             backend = "hccl"
+        elif self.device == "cpu":
+            backend = "gloo"

         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
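What the gloo choice buys on CPU: collectives over host tensors without NCCL or a GPU. A self-contained single-process illustration (not SGLang code; the real initialization happens inside ModelRunner with the model-parallel world size):

import torch
import torch.distributed as dist

dist.init_process_group(
    backend="gloo",                      # CPU-capable backend, mirroring the branch above
    init_method="tcp://127.0.0.1:29500",
    rank=0,
    world_size=1,
)
t = torch.ones(4)
dist.all_reduce(t)                       # works on CPU tensors under gloo
print(t)
dist.destroy_process_group()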
@@ -269,7 +274,8 @@ class ModelRunner:
         )

         # This can reduce thread conflicts and speed up weight loading.
-        torch.set_num_threads(1)
+        if self.device != "cpu":
+            torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
                 logger.info(
@@ -49,6 +49,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -271,13 +272,14 @@ class DeepseekV2Attention(nn.Module):
             quant_config=quant_config,
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
-        self.rotary_emb = get_rope(
+        self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
+            device=global_server_args_dict["device"],
         )

         if rope_scaling:
@@ -392,7 +392,7 @@ class ServerArgs:
             "--device",
             type=str,
             default="cuda",
-            choices=["cuda", "xpu", "hpu"],
+            choices=["cuda", "xpu", "hpu", "cpu"],
             help="The device type.",
         )
         parser.add_argument(
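With "cpu" accepted by the argument parser, the device can be selected from the CLI (--device cpu) or programmatically. A sketch (not from the commit); only the device field is taken from this diff, while model_path and the import path are assumed from the SGLang tree:

from sglang.srt.server_args import ServerArgs  # module path assumed

args = ServerArgs(model_path="deepseek-ai/DeepSeek-V2-Lite", device="cpu")
print(args.device)  # -> "cpu"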
@@ -223,6 +223,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True

         free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()

+    elif device == "cpu":
+        # TODO: rename the variables in the current function to be not GPU specific
+        free_gpu_memory = psutil.virtual_memory().available
+
     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
             torch.device(device, gpu_id)
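What the CPU branch reports: available host RAM from psutil, optionally reduced with MIN across ranks so every worker budgets against the most constrained node. A standalone sketch (not SGLang code); the all_reduce only applies when a process group, such as the gloo group shown earlier, is already initialized:

import psutil
import torch
import torch.distributed as dist

free_bytes = psutil.virtual_memory().available  # host RAM standing in for "GPU memory"

if dist.is_initialized():
    t = torch.tensor(free_bytes, dtype=torch.float32)
    dist.all_reduce(t, op=dist.ReduceOp.MIN)     # take the most constrained rank
    free_bytes = t.item()

print(f"available memory: {free_bytes / (1 << 30):.2f} GiB")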