Support cuda graph for DP attention (#2061)

This commit is contained in:
Ke Bao
2024-11-18 08:29:20 +08:00
committed by GitHub
parent 11f881d173
commit 62832bb272
9 changed files with 88 additions and 26 deletions

View File

@@ -31,7 +31,7 @@ from transformers import AutoModelForCausalLM
from sglang.srt.hf_transformers_utils import get_tokenizer
@torch.inference_mode()
@torch.no_grad()
def normal_text(args):
t = get_tokenizer(args.model_path, trust_remote_code=True)
m = AutoModelForCausalLM.from_pretrained(
@@ -69,7 +69,7 @@ def normal_text(args):
print(output_str)
@torch.inference_mode()
@torch.no_grad()
def synthetic_tokens(args):
m = AutoModelForCausalLM.from_pretrained(
args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True