Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc]add coding benchmark for speculative decoding #15303

Merged
merged 10 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions benchmarks/benchmark_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,3 +715,66 @@ def sample(
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests


# -----------------------------------------------------------------------------
# Instruct Coder Dataset Implementation
# -----------------------------------------------------------------------------


class InstructCoderDataset(HuggingFaceDataset):
"""
InstructCoder Dataset.
https://huggingface.co/datasets/likaixin/InstructCoder

InstructCoder is the dataset designed for general code editing.
It consists of 114,239 instruction-input-output triplets,
and covers multiple distinct code editing scenario.
"""

DEFAULT_OUTPUT_LEN = 200 # this is the average default output length
DEFAULT_NUM_REQUESTS = 1000
INSTRUCT_CODER_DATASET_PATH = "likaixin/InstructCoder"

def __init__(
self,
**kwargs,
) -> None:
super().__init__(**kwargs)
if self.dataset_path != self.INSTRUCT_CODER_DATASET_PATH:
raise ValueError(f"Only support likaixin/InstructCoder dataset.\
This data path {self.dataset_path} is not valid.")
if self.dataset_subset is None and self.dataset_split != "train":
raise ValueError("Dataset split must be 'train'.")

def load_data(self) -> None:
dataset = load_dataset(
self.dataset_path,
name=self.dataset_subset,
split=self.dataset_split,
streaming=True,
)
self.data = dataset.shuffle(seed=self.random_seed)

def sample(self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
prompt = f"{item['instruction']}:\n{item['input']}"
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
16 changes: 11 additions & 5 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@
from argparse import ArgumentParser as FlexibleArgumentParser

from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
RandomDataset, SampleRequest, ShareGPTDataset,
SonnetDataset, VisionArenaDataset)
InstructCoderDataset, RandomDataset,
SampleRequest, ShareGPTDataset, SonnetDataset,
VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json

MILLISECONDS_TO_SECONDS_CONVERSION = 1000
Expand Down Expand Up @@ -588,9 +589,14 @@ def main(args: argparse.Namespace):
elif args.dataset_name == "hf":
# Choose between VisionArenaDataset
# and HuggingFaceDataset based on provided parameters.
dataset_class = (VisionArenaDataset if args.dataset_path
== VisionArenaDataset.VISION_ARENA_DATASET_PATH
and args.hf_subset is None else HuggingFaceDataset)
dataset_class = HuggingFaceDataset
if args.dataset_path == VisionArenaDataset.VISION_ARENA_DATASET_PATH:
assert args.hf_subset is None, "VisionArenaDataset needs hf_subset to be None." #noqa: E501
dataset_class = VisionArenaDataset
elif args.dataset_path == "likaixin/InstructCoder":
dataset_class = InstructCoderDataset
args.hf_split = "train"

input_requests = dataset_class(
dataset_path=args.dataset_path,
dataset_subset=args.hf_subset,
Expand Down
43 changes: 27 additions & 16 deletions benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import torch
import uvloop
from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
RandomDataset, SampleRequest, ShareGPTDataset,
SonnetDataset, VisionArenaDataset)
InstructCoderDataset, RandomDataset,
SampleRequest, ShareGPTDataset, SonnetDataset,
VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
Expand Down Expand Up @@ -300,6 +301,7 @@ def get_requests(args, tokenizer):
"input_len": args.input_len,
"output_len": args.output_len,
}

if args.dataset_path is None or args.dataset_name == "random":
sample_kwargs["range_ratio"] = args.random_range_ratio
sample_kwargs["prefix_len"] = args.prefix_len
Expand All @@ -317,17 +319,21 @@ def get_requests(args, tokenizer):
elif args.dataset_name == "burstgpt":
dataset_cls = BurstGPTDataset
elif args.dataset_name == "hf":
if args.backend != "vllm-chat":
raise ValueError(
"hf datasets only are supported by vllm-chat backend")
# Choose between VisionArenaDataset and HuggingFaceDataset based on
# provided parameters.
dataset_cls = (VisionArenaDataset if args.dataset_path
== VisionArenaDataset.VISION_ARENA_DATASET_PATH
and args.hf_subset is None else HuggingFaceDataset)
common_kwargs['dataset_subset'] = args.hf_subset
common_kwargs['dataset_split'] = args.hf_split
sample_kwargs["enable_multimodal_chat"] = True
if args.dataset_path == VisionArenaDataset.VISION_ARENA_DATASET_PATH:
if args.args.backend == "vllm-chat":
raise ValueError(
"hf datasets only are supported by vllm-chat backend")
# Choose between VisionArenaDataset and HuggingFaceDataset based on
# provided parameters.
dataset_cls = (VisionArenaDataset if args.dataset_path
== VisionArenaDataset.VISION_ARENA_DATASET_PATH
and args.hf_subset is None else HuggingFaceDataset)
common_kwargs['dataset_subset'] = args.hf_subset
common_kwargs['dataset_split'] = args.hf_split
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path == "likaixin/InstructCoder":
dataset_cls = InstructCoderDataset
common_kwargs['dataset_split'] = "train"

else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
Expand Down Expand Up @@ -462,9 +468,14 @@ def validate_args(args):
warnings.warn("--hf-subset and --hf-split will be ignored \
since --dataset-name is not 'hf'.",
stacklevel=2)
elif args.dataset_name == "hf" and args.backend != "vllm-chat":
raise ValueError(
"When --dataset-name is 'hf', backend must be 'vllm-chat'")
elif args.dataset_name == "hf":
if args.dataset_path == VisionArenaDataset.VISION_ARENA_DATASET_PATH:
assert args.backend == "vllm-chat", "VisionArenaDataset needs to use vllm-chat as the backend." #noqa: E501
elif args.dataset_path == "likaixin/InstructCoder":
assert args.backend == "vllm", "InstructCoder dataset needs to use vllm as the backend." #noqa: E501
else:
raise ValueError(
f"{args.dataset_path} is not supported by hf dataset.")

# --random-range-ratio: only used when dataset_name is 'random'
if args.dataset_name != 'random' and args.random_range_ratio is not None:
Expand Down