This commit is contained in:
starkwj
2026-02-05 19:36:06 +08:00
parent 070bfa4a73
commit e273ef01b8
131 changed files with 28539 additions and 2 deletions

View File

@@ -8,6 +8,9 @@ import vllm.envs as envs
from vllm.logger import init_logger
# fix bfloat16 double size issue
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
logger = init_logger(__name__)
class KunlunPlatform(Platform):
@@ -152,8 +155,10 @@ class KunlunPlatform(Platform):
if parallel_config.worker_cls == "auto":
if vllm_config.speculative_config:
if envs.VLLM_USE_V1:
# parallel_config.worker_cls = \
# "vllm.v1.worker.gpu_worker.Worker"
parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker"
"vllm_kunlun.v1.worker.worker_v1.KunlunWorker"
else:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
@@ -162,8 +167,10 @@ class KunlunPlatform(Platform):
else:
print(f"envs.VLLM_USE_V1 = {envs.VLLM_USE_V1}")
if envs.VLLM_USE_V1:
# parallel_config.worker_cls = \
# "vllm.v1.worker.gpu_worker.Worker"
parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker"
"vllm_kunlun.v1.worker.worker_v1.KunlunWorker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"