[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE (#17211)
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
This commit is contained in:
@ -36,6 +36,10 @@ def parse_args():
|
||||
help="downloaded from the eagle repo " \
|
||||
"https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
|
||||
)
|
||||
parser.add_argument("--method",
|
||||
type=str,
|
||||
default='eagle',
|
||||
choices=['eagle', 'eagle3'])
|
||||
parser.add_argument("--max_num_seqs", type=int, default=8)
|
||||
parser.add_argument("--num_prompts", type=int, default=80)
|
||||
parser.add_argument("--num_spec_tokens", type=int, default=2)
|
||||
@ -53,7 +57,13 @@ def main():
|
||||
args = parse_args()
|
||||
|
||||
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||
|
||||
if args.method == 'eagle':
|
||||
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
||||
elif args.method == 'eagle3':
|
||||
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||
else:
|
||||
raise ValueError(f"unknown method: {args.method}")
|
||||
|
||||
max_model_len = 2048
|
||||
|
||||
@ -81,7 +91,7 @@ def main():
|
||||
max_num_seqs=args.max_num_seqs,
|
||||
gpu_memory_utilization=0.8,
|
||||
speculative_config={
|
||||
"method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle",
|
||||
"method": args.method,
|
||||
"model": eagle_dir,
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
"draft_tensor_parallel_size": args.draft_tp,
|
||||
|
||||
Reference in New Issue
Block a user