from __future__ import annotations

from typing import Tuple

import torch

from ._device import (
    current_device,
    device,
    device_count,
    get_device_capability,
    get_device_name,
    get_device_properties,
    is_available,
    is_bf16_supported,
    set_device,
    synchronize,
)
from .amp import (
    get_amp_supported_dtype,
    get_autocast_dtype,
    is_autocast_enabled,
    set_autocast_dtype,
    set_autocast_enabled,
)
from .lazy_initialize import _is_in_bad_fork, _lazy_call, _lazy_init
from .memory import (  # caching_allocator_alloc,; caching_allocator_delete,
    empty_cache,
    get_allocator_backend,
    max_memory_allocated,
    max_memory_cached,
    max_memory_reserved,
    mem_get_info,
    memory_allocated,
    memory_cached,
    memory_reserved,
    memory_snapshot,
    memory_stats,
    memory_stats_as_nested_dict,
    memory_summary,
    reset_accumulated_memory_stats,
    reset_max_memory_allocated,
    reset_max_memory_cached,
    reset_peak_memory_stats,
    set_per_process_memory_fraction,
)
from .streams import Event, Stream, current_stream, default_stream, set_stream, stream

def init():
    r"""Initialize PyTorch's VACC state. You may need to call
    this explicitly if you are interacting with PyTorch via
    its C API, as Python bindings for VACC functionality will not
    be available until this initialization takes place. Ordinary users
    should not need this, as all of PyTorch's VACC methods
    automatically initialize VACC state on-demand.

    Does nothing if the VACC state is already initialized.
    """
    _lazy_init()
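
# Example (a minimal sketch, not executed here): explicit initialization only
# matters when PyTorch is driven through its C API; ordinary Python code can
# rely on the lazy on-demand path. The top-level import name `torch_vacc`
# below is an assumption used purely for illustration.
#
#     import torch_vacc
#     torch_vacc.init()                 # no-op if VACC is already initialized
#     if torch_vacc.is_available():
#         torch_vacc.set_device(0)
#         torch_vacc.synchronize()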

# default_generators is empty until _lazy_init() is called
default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]
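# Example (sketch; assumes the lazy-init path repopulates this tuple with one
# generator per device, in the spirit of torch.cuda.default_generators):
#
#     init()
#     if is_available():
#         default_generators[current_device()].manual_seed(1234)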

from .custom_ops import *
from .custom_qwen3_ops import *
from .random import *  # noqa: F403

__all__ = [
    "device",
    "is_available",
    "is_bf16_supported",
    "current_device",
    "set_device",
    "device_count",
    "get_device_properties",
    "get_device_name",
    "get_device_capability",
    "synchronize",
    "amp",
    "get_amp_supported_dtype",
    "is_autocast_enabled",
    "set_autocast_enabled",
    "get_autocast_dtype",
    "set_autocast_dtype",
    "_is_in_bad_fork",
    "_lazy_call",
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
    "set_stream",
    "current_stream",
    "default_generators",
    "default_stream",
    "stream",
    "Stream",
    "Event",
    "mem_get_info",
    "set_per_process_memory_fraction",
    "empty_cache",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
    "reset_max_memory_allocated",
    "reset_max_memory_cached",
    "memory_allocated",
    "max_memory_allocated",
    "memory_reserved",
    "max_memory_reserved",
    "memory_cached",
    "max_memory_cached",
    "memory_snapshot",
    "memory_summary",
    "get_allocator_backend",
    "rms_norm",
    "RotaryPosEmbedding",
    "scaled_dot_product_attention",
    "scaled_dot_product_attention_cp_forward",
    "scaled_dot_product_attention_cp_backward",
    "swiglu",
    "paged_attention",
    "reshape_and_cache_attention",
    "concat_and_cache_attention",
    "w8a8_block_fp8_matmul",
    "moe_expert_token_group_reassign",
    "fused_mlp_mm_fp8",
    "fused_mlp_fp8",
    "fused_moe_preprocess",
    "fused_residual_rmsnorm",
    "parallel_embedding",
    "all_reduce",
    "all_gather",
    "broadcast",
    "fused_mlp_moe_with_rmsnorm",
    "fuse_moe_decode_v2_allreduce",
    "topk_topp",
    "fused_mla",
    "fused_mla_allreduce",
    "fused_mlp_with_rmsnorm",
    "fused_mlp_allreduce",
    "ds3_sampler",
    "sampler_v1",
    "rejection_sampler",
    "rejection_sampler_update_hidden_states",
    "rejection_sampler_v1",
    "fused_matmul_allgather",
    "fused_mla_v2",
    "fused_mla_allreduce_v2",
    "mla_matmul_scale",
    "mla_matmul",
    "fused_mla_prefill_stage0",
    "fused_mla_prefill_stage1",
    "fused_mla_prefill_stage0_allreduce",
    "fuse_moe_prefill_stage0",
    "fuse_mla_mlp_v2_allreduce_decode",
    "fuse_mla_moe_v2_allreduce_decode",
    "fuse_mla_mlp_v2_allreduce_decode_layers",
    "fuse_mla_moe_v2_allreduce_decode_layers",
    "fuse_mla_mlp_v2_allreduce_decode_layers_v2",
    "fuse_mla_moe_v2_allreduce_decode_layers_v2",
    "fuse_mlp_qwen_int4",
    "fuse_mlp_qwen_int4_reduce",
    "w4a8_block_int4_matmul",
    "fuse_atten_qwen3",
    "fuse_atten_qwen2",
    "qwen3_fuse_attention_moe_decode",
    "fuse_mtp_stage0",
    "fuse_mtp_allreduce",
    "roll_out",
    "fused_experts_int4_prefill",
    "fuse_bge_embedding_stage1",
    "l2_norm",
    "fuse_mlp_vision",
    "patch_merger_vision",
    "fuse_atten_vit",
    "apply_penalties",
]